当前位置: 代码迷 >> 综合 >> torch.unsqueeze与torch.squeeze
  详细解决方案

torch.unsqueeze与torch.squeeze

热度:25   发布时间:2023-11-25 15:48:21.0

unsqueeze

扩充数据维度,在从0开始的指定位置上增加一维(维度为1)

x = torch.rand(2,3)
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)print(x.shape)
print(y.shape)
>>torch.Size([2, 3])
>>torch.Size([1, 2, 1, 3])

也可以倒着数,比如torch.unsqueeze(x,-1),就是在最后添加一维

squeeze

维度压缩,在从0开始的指定位置上,去掉维数为1的的维度

  • 若不指定参数,删除所有为 1 的维度
  • 若指定参数 N
    • 如果第 N 个位置的维度为 1 ,则删除该维度
    • 否则,不受影响
x = torch.rand(2,3)#增加两个维度
y = torch.unsqueeze(x, 1)
y = torch.unsqueeze(y, 0)#若第二个位置的维度为 1,则删除。否则,不受影响
z = torch.squeeze(y, 2)#删除所有为 1 的维度
m = torch.squeeze(y)
print(x.shape)
>>torch.Size([2, 3])print(y.shape)
>>torch.Size([1, 2, 1, 3])print(z.shape)
>>torch.Size([1, 2, 3])print(m.shape)
>>torch.Size([2, 3])
  相关解决方案