当前位置: 代码迷 >> 综合 >> pytorch中的torch.cat()和torch.chunk()
  详细解决方案

pytorch中的torch.cat()和torch.chunk()

热度:60   发布时间:2023-12-17 22:36:55.0

用法介绍

?pytorch中张量进行拼接和分割的函数分别是torch.cat()torch.chunk()torch.cat()是将多个张量组成的元组按照指定的维度进行拼接。torch.chunk()是对一个张量按照某个维度分割成多个子张量块。它们具体的用法如下所示

torch.cat(tensors, dim=0, *, out=None)?\longrightarrow?Tensor

  • tensors (tuple of tensor):张量组成的元组
  • dim (int):按照某个维度对多个张量进行拼接

注意: 如果多个张量按照某个维度进行拼接,那么其它的维度要一致。

torch.chunk(input, chunks, dim=0)?\longrightarrow?List of Tensors

  • input (Tensor):要被分割的张量
  • chunks (int):被分割的张量数
  • dim (int):按照某个维度对张量进行分割

代码示例

?torch.cat()的代码示例如下所示

>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5654,  0.7048,  0.5851],[-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),0)
tensor([[-0.5654,  0.7048,  0.5851],[-1.3871,  0.5481,  0.3028],[-0.5654,  0.7048,  0.5851],[-1.3871,  0.5481,  0.3028],[-0.5654,  0.7048,  0.5851],[-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),1)
tensor([[-0.5654,  0.7048,  0.5851, -0.5654,  0.7048,  0.5851, -0.5654,  0.7048,0.5851],[-1.3871,  0.5481,  0.3028, -1.3871,  0.5481,  0.3028, -1.3871,  0.5481,0.3028]])

?torch.chunk()的代码示例如下所示

>>> import torch
>>> x = torch.randn(8,8)
>>> x
tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],[ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],[-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],[ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343],[-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],[ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],[-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],[ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]])
>>> x.chunk(chunks=2, dim=0)
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],[ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],[-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],[ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343]]), tensor([[-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],[ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],[-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],[ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]]))
>>> x.chunk(chunks=2, dim=1)
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435],[ 0.8614, -0.3212,  0.4715,  0.1476],[-0.7407,  0.5510,  0.1284,  0.1485],[ 0.7217,  0.5292,  0.2469,  0.1823],[-2.0195, -2.3613, -0.6441, -1.7863],[ 0.4582, -1.6445, -0.6813, -0.8802],[-1.6415, -0.9834,  0.1687,  0.0159],[ 0.8765,  0.8214,  1.0971, -0.4150]]), tensor([[-0.1774,  0.7908,  0.6920,  1.0908],[ 1.7950,  1.8308, -0.1419, -0.1448],[ 0.2997, -0.8133,  1.5608,  0.0682],[-0.6200,  0.9436, -0.5221, -0.9343],[ 1.4207,  0.4124,  0.5508, -0.2569],[ 0.9870, -0.6599, -0.4719,  0.3088],[ 0.4456, -0.1823,  0.9652, -0.2785],[-0.9499, -0.5875, -1.3902, -0.9129]]))