当前位置: 代码迷 >> 综合 >> Pytorch学习(十五)squeeze()和unsqueeze()的用法
  详细解决方案

Pytorch学习(十五)squeeze()和unsqueeze()的用法

热度:52   发布时间:2023-11-04 16:04:28.0

squeeze()函数用于压缩维度,unsqueeze()用于扩充维度。

1. unsqueeze()用法介绍

unsqueeze()函数用于扩充维度,它有一个参数unsqueeze(dim),表示在第dim维上扩充维度。
下面的代码中arr维度是(2,3,4),在第0维进行扩充,代码第三行维度是(1,2,3,4),通过第四行代码的输出结果可以看出,输出张量与原张量不共享内存,可通过第五行arr = arr.unsqueeze(0)来改变arr维度。

arr = torch.arange(1, 25).view(2,3,4)
print(arr)
print(arr.unsqueeze(0))
print(arr)
arr = arr.unsqueeze(0)
print(arr)
三次输出结果依次是:
tensor([[[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]])
tensor([[[[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]]])
tensor([[[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]])
tensor([[[[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]],[[13, 14, 15, 16],[17, 18, 19, 20],[21, 22, 23, 24]]]])

2. squeeze()函数的用法

对维度为1的维进行压缩,对其他大小的维不起作用。代码第一行a的维度是(1,3,4),对第一维进行压缩,维度变成(3, 4),由于a和a.squeeze(0)对应的tensor不同享内存,所以要真正改变a的话,需要通过a = a.squeeze(0)

a = torch.tensor([[[1,5,62,54], [2,6,2,6], [2,65,2,6]]])
print(a.shape)
print(a.squeeze(0).shape)
print(a.shape)

输出结果如下:

torch.Size([1, 3, 4])
torch.Size([3, 4])
torch.Size([1, 3, 4])

3. unsqueeze()和expand()组合使用,经常后边接一个操作——拼接(cat)

a和b的维度可以看作(batch_size, channel, height, width),相当于a是batch size为3的3张图像,b是batch size为1的一张图像,对a和b按照channel维度进行拼接。首先需要扩充维度,然后再拼接。

import torch 
a = torch.arange(1, 25).view(3,2,2,2)
b = torch.arange(25, 33).view(1,2,2,2)
a = a.unsqueeze(0).expand(1,-1,-1,-1,-1)
print(a)
b = b.unsqueeze(1).expand(-1, 3, -1, -1, -1)
print(b)
print(torch.cat([a,b], dim=2))

结果如下:

torch.Size([1, 3, 2, 2, 2])
torch.Size([1, 3, 2, 2, 2])
torch.Size([1, 3, 4, 2, 2])