当前位置: 代码迷 >> 综合 >> 深度学习框架_PyTorch_torch.squeeze()函数和torch.unsqueeze()函数的用法
  详细解决方案

深度学习框架_PyTorch_torch.squeeze()函数和torch.unsqueeze()函数的用法

热度:56   发布时间:2023-12-16 01:23:26.0

torch.squeeze()函数的用法主要是对数据的维度进行压缩,去掉维数为1的维度。

torch.squeeze(x)是去掉x中所有维数为1的维度;x.squeeze(n)是去掉x中指定的维数为1的维度。

接下来我们在具体代码中了解:

>>> c
tensor([[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])>>> torch.squeeze(c)
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])>>> c =c.squeeze(0)
>>> c
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])

torch.unsqueeze()函数主要对数据进行扩充。给指定位置加上维数为1的维度。

x.squeeze(n)就是在x中指定位置n加上维数为1的维度。

我们继续看代码:

>>> c
tensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])
>>> c.size()
torch.Size([6, 3])>>> c = c.unsqueeze(0)
>>> c
tensor([[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]])
>>> c.size()
torch.Size([1, 6, 3])>>> c = c.unsqueeze(1)
>>> c
tensor([[[[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]]]])
>>> c.size()
torch.Size([1, 1, 6, 3])>>> c = c.unsqueeze(3)
>>> c
tensor([[[[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]],[[1., 1., 1.]]]]])
>>> c.size()
torch.Size([1, 1, 6, 1, 3])
  相关解决方案