当前位置: 代码迷 >> 综合 >> PyTorch:torch.Tensor.repeat()、expand()
  详细解决方案

PyTorch:torch.Tensor.repeat()、expand()

热度:50   发布时间:2024-03-05 21:28:26.0

目录

1、torch.Tensor.repeat()

2、torch.Tensor.expand()


1、torch.Tensor.repeat()

函数定义:

repeat(*sizes) → Tensor

作用:

在指定的维度上重复这个张量,即把这个维度的张量复制*sizes次。同时可以通过复制的形式扩展维度的数量。

注意:torch.Tensor.repeat方法与numpy.tile方法作用相似,而不是numpy.repeat!torch中与numpy.repeat类似的方法是torch.repeat_interleave!

区别:与expand的不同之处在于,repeat函数传入的参数直接就是对应维度要扩充的倍数,而不是最后的shape。

举例分析:

例1——对应(已存在的)维度的拓展。

import torcha = torch.tensor([[1], [2], [3]])  # 3 * 1
b = a.repeat(3, 2)  # torch.linspace(0, 10, 5)
print('a:\n', a)
print('shape of a', a.size())  # 原始shape = (3,1)
print('b:\n', b)
print('shape of b', b.size())  # 新的shape = (3*3,1*2),新增加的数据通过复制得到'''   运行结果   '''
a:tensor([[1],[2],[3]])
shape of a torch.Size([3, 1])  注: 原始shape = (3,1)
b:tensor([[1, 1],[2, 2],[3, 3],[1, 1],[2, 2],[3, 3],[1, 1],[2, 2],[3, 3]])
shape of b torch.Size([9, 2])  注: 新的shape = (3*3,1*2),新增加的数据通过复制得到

例2——带有(原始不存在的)维度数量拓展的用法:

import torch
a = torch.tensor([[1, 2], [3, 4], [5, 6]])  # 3 * 2
b = a.repeat(3, 2, 1)   # 在原始tensor的0维前拓展一个维度,并把原始tensor的第1维扩张2倍,都是通过复制来完成的
print('a:\n', a)
print('shape of a', a.size())  # 原始维度为 (3,2)
print('b:\n', b)
print('shape of b', b.size())  # 新的维度为 (3,2*2,2*1)=(3,4,2)'''   运行结果   '''
a:tensor([[1, 2],[3, 4],[5, 6]])
shape of a torch.Size([3, 2])   注:原始维度为 (3,2)
b:tensor([[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]],[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]],[[1, 2],[3, 4],[5, 6],[1, 2],[3, 4],[5, 6]]])
shape of b torch.Size([3, 6, 2])   注:新的维度为 (3,2*2,2*1)=(3,4,2)

2、torch.Tensor.expand()

函数定义:

expand(*sizes) → Tensor

作用:不仅可以对tensor指定的(已存在的)维度进行扩大(复制型扩大),扩大后的shape为*(size)。而且还有类似于unsqueeze的维度扩充功能,新增加的维度将会加在前面。

区别:与repeat不同之处在于,expand传入的参数直接就是将tensor扩大后的shape。

举例说明

a = torch.ones(3, 1)   # 创建3*1的全为1的tensor
b = a.expand(3, 2)     # 对a的维度1进行扩充
c = a.expand(2, 3, 2)  # 对a的维度1进行扩充,并在维度0前加一个维度 
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
print('c:', c)
print('shape of c:', c.shape)'''   运行结果   '''
a: tensor([[1.],[1.],[1.]])
shape of a: torch.Size([3, 1])b: tensor([[1., 1.],[1., 1.],[1., 1.]])
shape of b: torch.Size([3, 2])c: tensor([[[1., 1.],[1., 1.],[1., 1.]],[[1., 1.],[1., 1.],[1., 1.]]])
shape of c: torch.Size([2, 3, 2])