用法介绍
?pytorch中张量进行拼接和分割的函数分别是torch.cat()和torch.chunk()。torch.cat()是将多个张量组成的元组按照指定的维度进行拼接。torch.chunk()是对一个张量按照某个维度分割成多个子张量块。它们具体的用法如下所示
torch.cat(tensors, dim=0, *, out=None)?\longrightarrow?Tensor
- tensors (tuple of tensor):张量组成的元组
- dim (int):按照某个维度对多个张量进行拼接
注意: 如果多个张量按照某个维度进行拼接,那么其它的维度要一致。
torch.chunk(input, chunks, dim=0)?\longrightarrow?List of Tensors
- input (Tensor):要被分割的张量
- chunks (int):被分割的张量数
- dim (int):按照某个维度对张量进行分割
代码示例
?torch.cat()的代码示例如下所示
>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5654, 0.7048, 0.5851],[-1.3871, 0.5481, 0.3028]])
>>> torch.cat((x,x,x),0)
tensor([[-0.5654, 0.7048, 0.5851],[-1.3871, 0.5481, 0.3028],[-0.5654, 0.7048, 0.5851],[-1.3871, 0.5481, 0.3028],[-0.5654, 0.7048, 0.5851],[-1.3871, 0.5481, 0.3028]])
>>> torch.cat((x,x,x),1)
tensor([[-0.5654, 0.7048, 0.5851, -0.5654, 0.7048, 0.5851, -0.5654, 0.7048,0.5851],[-1.3871, 0.5481, 0.3028, -1.3871, 0.5481, 0.3028, -1.3871, 0.5481,0.3028]])
?torch.chunk()的代码示例如下所示
>>> import torch
>>> x = torch.randn(8,8)
>>> x
tensor([[ 1.0272, 1.5964, 0.1502, 1.3435, -0.1774, 0.7908, 0.6920, 1.0908],[ 0.8614, -0.3212, 0.4715, 0.1476, 1.7950, 1.8308, -0.1419, -0.1448],[-0.7407, 0.5510, 0.1284, 0.1485, 0.2997, -0.8133, 1.5608, 0.0682],[ 0.7217, 0.5292, 0.2469, 0.1823, -0.6200, 0.9436, -0.5221, -0.9343],[-2.0195, -2.3613, -0.6441, -1.7863, 1.4207, 0.4124, 0.5508, -0.2569],[ 0.4582, -1.6445, -0.6813, -0.8802, 0.9870, -0.6599, -0.4719, 0.3088],[-1.6415, -0.9834, 0.1687, 0.0159, 0.4456, -0.1823, 0.9652, -0.2785],[ 0.8765, 0.8214, 1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]])
>>> x.chunk(chunks=2, dim=0)
(tensor([[ 1.0272, 1.5964, 0.1502, 1.3435, -0.1774, 0.7908, 0.6920, 1.0908],[ 0.8614, -0.3212, 0.4715, 0.1476, 1.7950, 1.8308, -0.1419, -0.1448],[-0.7407, 0.5510, 0.1284, 0.1485, 0.2997, -0.8133, 1.5608, 0.0682],[ 0.7217, 0.5292, 0.2469, 0.1823, -0.6200, 0.9436, -0.5221, -0.9343]]), tensor([[-2.0195, -2.3613, -0.6441, -1.7863, 1.4207, 0.4124, 0.5508, -0.2569],[ 0.4582, -1.6445, -0.6813, -0.8802, 0.9870, -0.6599, -0.4719, 0.3088],[-1.6415, -0.9834, 0.1687, 0.0159, 0.4456, -0.1823, 0.9652, -0.2785],[ 0.8765, 0.8214, 1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]]))
>>> x.chunk(chunks=2, dim=1)
(tensor([[ 1.0272, 1.5964, 0.1502, 1.3435],[ 0.8614, -0.3212, 0.4715, 0.1476],[-0.7407, 0.5510, 0.1284, 0.1485],[ 0.7217, 0.5292, 0.2469, 0.1823],[-2.0195, -2.3613, -0.6441, -1.7863],[ 0.4582, -1.6445, -0.6813, -0.8802],[-1.6415, -0.9834, 0.1687, 0.0159],[ 0.8765, 0.8214, 1.0971, -0.4150]]), tensor([[-0.1774, 0.7908, 0.6920, 1.0908],[ 1.7950, 1.8308, -0.1419, -0.1448],[ 0.2997, -0.8133, 1.5608, 0.0682],[-0.6200, 0.9436, -0.5221, -0.9343],[ 1.4207, 0.4124, 0.5508, -0.2569],[ 0.9870, -0.6599, -0.4719, 0.3088],[ 0.4456, -0.1823, 0.9652, -0.2785],[-0.9499, -0.5875, -1.3902, -0.9129]]))