前言
有时候会面对需要把数据进行维度转换的情况,
比如本来512*N*W*H(BNWH)的维度需要转换为512*(N*W*H)的一个output和(N*W*H)*512的一个output,然后将两者进行矩阵乘法。
即(NHW)*512 X 512*(NWH) = (NHW)*(NHW)
,
然后再和初始的512*N*W*H进行矩阵乘法,结果仍旧是512*N*W*H,常用在一些non-local conv block中。
代码
import torch
import numpy as npinput = torch.randn(2,3,4,4)
# 将从第二个维度开始进行压缩
# 可以根据自己需要选择从哪里开始压缩
out = input.flatten(start_dim=1,end_dim=3)
out.shape
得到:
torch.Size([2, 48]) # 3*4*4=48