当前位置: 代码迷 >> 综合 >> Tensor——拼接与拆分
  详细解决方案

Tensor——拼接与拆分

热度:21   发布时间:2023-11-09 07:53:03.0

文章目录

    • 1. 拼接
      • (1). cat
      • (2). stack
    • 2. 拆分
      • (1). split
      • (2). chunk

1. 拼接

(1). cat

注意要指出在哪个维度上进行拼接:

>>> import torch
>>> a = torch.rand(4,32,8)
>>> b = torch.rand(5,32,8)
>>> torch.cat([a,b],dim=0).shape
torch.Size([9, 32, 8])

且除了要拼接的维度外,其他维度数值必须保持一致,否则会报错:

>>> import torch
>>> a = torch.rand(4,3,32,32)
>>> b = torch.rand(4,1,32,32)
>>> torch.cat([a,b],dim=0).shape
Traceback (most recent call last):File "<stdin>", line 1, in <module>
RuntimeError: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1

(2). stack

会创建新的维度,所以在旧维度上必须完全一摸一样:

>>> import torch
>>> a = torch.rand(32,8)
>>> b = torch.rand(32,8)
>>> torch.stack([a,b],dim=0).shape
torch.Size([2, 32, 8])

2. 拆分

(1). split

根据长度拆分

>>> import torch
>>> a = torch.rand(3,32,8)
>>> aa, bb = a.split([2,1],dim=0)
>>> aa.shape, bb.shape
(torch.Size([2, 32, 8]), torch.Size([1, 32, 8]))
>>> import torch
>>> a = torch.rand(2,32,8)
>>> aa,bb = a.split(1,dim=0)
>>> aa.shape,bb.shape
(torch.Size([1, 32, 8]), torch.Size([1, 32, 8]))

如果把2拆分成N块,每块的长度是2,则会报错。
在理论上就是不拆分,也就是一个拆分成一块,但在pytorch中不可以这样做。

>>> import torch
>>> a = torch.rand(2,32,8)
>>> aa,bb = a.split(2,dim=0)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
ValueError: not enough values to unpack (expected 2, got 1)

(2). chunk

按数量拆分:
就比较好理解,算除法就行。

>>> import torch
>>> a = torch.rand(8,32,8)
>>> aa,bb = a.chunk(2,dim=0)
>>> aa.shape,bb.shape
(torch.Size([4, 32, 8]), torch.Size([4, 32, 8]))
  相关解决方案