torch.squeeze()函数的用法主要是对数据的维度进行压缩,去掉维数为1的维度。
torch.squeeze(x)是去掉x中所有维数为1的维度;x.squeeze(n)是去掉x中指定的维数为1的维度。
接下来我们在具体代码中了解:
>>> c
tensor([[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])>>> torch.squeeze(c)
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])>>> c =c.squeeze(0)
>>> c
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])
torch.unsqueeze()函数主要对数据进行扩充。给指定位置加上维数为1的维度。
x.squeeze(n)就是在x中指定位置n加上维数为1的维度。
我们继续看代码:
>>> c
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])>>> c = c.unsqueeze(0)
>>> c
tensor([[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])>>> c = c.unsqueeze(1)
>>> c
tensor([[[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]]])
>>> c.size()
torch.Size([1, 1, 6, 3])>>> c = c.unsqueeze(3)
>>> c
tensor([[[[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]]]]])
>>> c.size()
torch.Size([1, 1, 6, 1, 3])