当前位置: 代码迷 >> 综合 >> Pytorch中的 flatten() ,squeeze() 和 unsqueeze() 的区分
  详细解决方案

Pytorch中的 flatten() ,squeeze() 和 unsqueeze() 的区分

热度:84   发布时间:2023-12-08 07:14:43.0

Pytorch中的 flatten,squeeze 和 unsqueeze 的区分

  • 解释
  • 举例:
    • 原数据 T 的输出:
    • 原数据 T 的 flatten() 输出
    • 原数据 T 的 squeeze() 输出
    • 原数据 T 的 unsqueeze() 输出
  • 参考链接

解释

flatten() 用于将数据展开。

squeeze() 用于将数据进行压缩,移除某个维度

  • Compute torch.squeeze(input). It squeezes (removes) the size 1 and returns a tensor with all other dimensions of the input tensor.

unsqueeze() 用于将数据解压缩,扩充一个维度

  • Compute torch.unsqueeze(input, dim). It inserts a new dimension of size 1 at the given dim and returns the tensor.

举例:

# Python program to squeeze and unsqueeze a tensor
# import necessary library
import torch# Create a tensor of all one
T = torch.ones(2,1,2) # size 2x1x2
print("Original Tensor T:\n", T )
print("Size of T:", T.size())

原数据 T 的输出:

在这里插入图片描述

原数据 T 的 flatten() 输出

T.flatten()
torch.flatten() 

在这里插入图片描述

原数据 T 的 squeeze() 输出

T.squeeze(0)
troch.squeeze(T)

注意观察,两者的维度,(中括号个数)
在这里插入图片描述

原数据 T 的 unsqueeze() 输出

T.unsqueeze(dim=0)
torch.unsqueeze(T, dim=0)

在这里插入图片描述

参考链接

https://www.tutorialspoint.com/how-to-squeeze-and-unsqueeze-a-tensor-in-pytorch

https://pytorch.org/docs/stable/generated/torch.squeeze.html