当前位置: 代码迷 >> 综合 >> torch.unsqueeze和 torch.squeeze() 详解
  详细解决方案

torch.unsqueeze和 torch.squeeze() 详解

热度:80   发布时间:2024-01-19 11:16:06.0

1. torch.unsqueeze 详解

torch.unsqueeze(input, dim, out=None)

  • 作用:扩展维度

返回一个新的张量,对输入的既定位置插入维度 1

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果dim为负,则将会被转化dim+input.dim()+1
  • 参数:
  • tensor (Tensor) – 输入张量
  • dim (int) – 插入维度的索引
  • out (Tensor, optional) – 结果张量

A dim value within the range [-input.dim() - 1, input.dim() + 1) (左闭右开)can be used.

 为何取值范围要如此设计呢?
 原因:方便操作
 0(-2)-行扩展
 1(-1)-列扩展
 正向:我们在0,1位置上扩展
 逆向:我们在-2,-1位置上扩展
 维度扩展:1维->2维,2维->3维,...,n维->n+1维
 维度降低:n维->n-1维,n-1维->n-2维,...,2维->1维

 以 1维->2维 为例,

 从【正向】的角度思考:

 torch.Size([4])
 最初的 tensor([1., 2., 3., 4.]) 是 1维,我们想让它扩展成 2维,那么,可以有两种扩展方式:

 一种是:扩展成 1行4列 ,即 tensor([[1., 2., 3., 4.]])
 针对第一种,扩展成 [1, 4]的形式,那么,在 dim=0 的位置上添加 1

 另一种是:扩展成 4行1列,即
 tensor([[1.],
         [2.],
         [3.],
         [4.]])
 针对第二种,扩展成 [4, 1]的形式,那么,在dim=1的位置上添加 1

 从【逆向】的角度思考:
 原则:一般情况下, "-1" 是代表的是【最后一个元素】
 在上述的原则下,
 扩展成[1, 4]的形式,就变成了,在 dim=-2 的的位置上添加 1
 扩展成[4, 1]的形式,就变成了,在 dim=-1 的的位置上添加 1

torch.squeeze 详解

  • 作用:降维

torch.squeeze(input, dim=None, out=None)

将输入张量形状中的1 去除并返回。 如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

当给定dim时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
  • 参数:
  • input (Tensor) – 输入张量
  • dim (int, optional) – 如果给定,则input只会在给定维度挤压
  • out (Tensor, optional) – 输出张量
为何只去掉 1 呢?

多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度。

参考资料:https://pytorch.org/docs/stable/generated/torch.squeeze.html?highlight=torch%20squeeze#torch.squeeze

https://zhuanlan.zhihu.com/p/86763381

  相关解决方案