squeeze()函数用于压缩维度,unsqueeze()用于扩充维度。
1. unsqueeze()用法介绍
unsqueeze()函数用于扩充维度,它有一个参数unsqueeze(dim),表示在第dim维上扩充维度。
下面的代码中arr维度是(2,3,4),在第0维进行扩充,代码第三行维度是(1,2,3,4),通过第四行代码的输出结果可以看出,输出张量与原张量不共享内存,可通过第五行arr = arr.unsqueeze(0)
来改变arr维度。
arr = torch.arange(1, 25).view(2,3,4)
print(arr)
print(arr.unsqueeze(0))
print(arr)
arr = arr.unsqueeze(0)
print(arr)
三次输出结果依次是:
tensor([[[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]])
tensor([[[[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]]])
tensor([[[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]])
tensor([[[[ 1, 2, 3, 4],[ 5, 6, 7, 8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]]])
2. squeeze()函数的用法
对维度为1的维进行压缩,对其他大小的维不起作用。代码第一行a的维度是(1,3,4),对第一维进行压缩,维度变成(3, 4),由于a和a.squeeze(0)对应的tensor不同享内存,所以要真正改变a的话,需要通过a = a.squeeze(0)
。
a = torch.tensor([[[1,5,62,54], [2,6,2,6], [2,65,2,6]]])
print(a.shape)
print(a.squeeze(0).shape)
print(a.shape)
输出结果如下:
torch.Size([1, 3, 4])
torch.Size([3, 4])
torch.Size([1, 3, 4])
3. unsqueeze()和expand()组合使用,经常后边接一个操作——拼接(cat)
a和b的维度可以看作(batch_size, channel, height, width),相当于a是batch size为3的3张图像,b是batch size为1的一张图像,对a和b按照channel维度进行拼接。首先需要扩充维度,然后再拼接。
import torch
a = torch.arange(1, 25).view(3,2,2,2)
b = torch.arange(25, 33).view(1,2,2,2)
a = a.unsqueeze(0).expand(1,-1,-1,-1,-1)
print(a)
b = b.unsqueeze(1).expand(-1, 3, -1, -1, -1)
print(b)
print(torch.cat([a,b], dim=2))
结果如下:
torch.Size([1, 3, 2, 2, 2])
torch.Size([1, 3, 2, 2, 2])
torch.Size([1, 3, 4, 2, 2])