当前位置: 代码迷 >> 综合 >> Pytorch:flatten()函数,压缩tensor的维度
  详细解决方案

Pytorch:flatten()函数,压缩tensor的维度

热度:80   发布时间:2023-12-17 04:48:31.0

前言

有时候会面对需要把数据进行维度转换的情况,
比如本来512*N*W*H(BNWH)的维度需要转换为512*(N*W*H)的一个output和(N*W*H)*512的一个output,然后将两者进行矩阵乘法。
(NHW)*512 X 512*(NWH) = (NHW)*(NHW)
然后再和初始的512*N*W*H进行矩阵乘法,结果仍旧是512*N*W*H,常用在一些non-local conv block中。

代码

import torch
import numpy as npinput = torch.randn(2,3,4,4)
# 将从第二个维度开始进行压缩
# 可以根据自己需要选择从哪里开始压缩
out = input.flatten(start_dim=1,end_dim=3) 
out.shape

得到:

torch.Size([2, 48])  # 3*4*4=48