unsqueeze
扩充数据维度,在从0开始的指定位置上增加一维(维度为1)
x = torch.rand(2,3)
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)print(x.shape)
print(y.shape)
>>torch.Size([2, 3])
>>torch.Size([1, 2, 1, 3])
也可以倒着数,比如torch.unsqueeze(x,-1)
,就是在最后添加一维
squeeze
维度压缩,在从0开始的指定位置上,去掉维数为1的的维度
- 若不指定参数,删除所有为 1 的维度
- 若指定参数 N
- 如果第 N 个位置的维度为 1 ,则删除该维度
- 否则,不受影响
x = torch.rand(2,3)#增加两个维度
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)#若第二个位置的维度为 1,则删除。否则,不受影响
z = torch.squeeze(y, 2)#删除所有为 1 的维度
m = torch.squeeze(y)
print(x.shape)
>>torch.Size([2, 3])print(y.shape)
>>torch.Size([1, 2, 1, 3])print(z.shape)
>>torch.Size([1, 2, 3])print(m.shape)
>>torch.Size([2, 3])