当前位置: 代码迷 >> 综合 >> 【pytorch系列】torch.unsqueeze()和torch.squeeze()
  详细解决方案

【pytorch系列】torch.unsqueeze()和torch.squeeze()

热度:54   发布时间:2024-01-04 03:26:47.0

torch.unsqueeze()

原型:

torch.unsqueeze(input, dim, out=None)	

参数:
tensor (Tensor) – 输入张量
dim (int) – 插入维度的索引
out (Tensor, optional) – 结果张量

或者:

Tensor2 = Tensor1.torch.unsqueeze( dim)	

参数:
dim (int) – 插入维度的索引

功能
返回一个新的tensor,这个tensor 在指定的位置被插入了一个大小为1的新维度
这个返回的tensor 和之前的 tensor 有着相同的数据

example

>>> x = torch.tensor([1,2,3,4])
>>> torch.unsqueeze(x,0,x1)
>>>> print(x1)
tensor([[1,2,3,4]])
>>> torch.unsqueeze(x,1,x2)
>>> print(x2)
tensor([[1],[2],[3],[4]])

torch.squeeze()

原型:

torch.squeeze(input, dim=None, out=None)

参数:
input (Tensor) – 输入张量
dim (int, optional) – 如果给定,则input只会在给定维度挤压
out (Tensor, optional) – 输出张量

或者:

Tensor2 = Tensor1.torch.squeeze(dim)	

参数:
dim (int, optional) – 如果给定,则input只会在给定维度挤压

功能
squeeze(dim)表示第dim维的维度值为1,则去掉该维度。否则tensor不变。