当前位置: 代码迷 >> 综合 >> PyTorch中torch.nn.functional.unfold函数使用详解
  详细解决方案

PyTorch中torch.nn.functional.unfold函数使用详解

热度:74   发布时间:2023-10-20 18:29:01.0

首先跳到函数定义中,看一下有哪些参数。

 

def unfold(input, kernel_size, dilation=1, padding=0, stride=1):"""input: tensor数据,四维, Batchsize, channel, height, widthkernel_size: 核大小,决定输出tensor的数目。稍微详细讲dilation: 输出形式是否有间隔,稍后详细讲。padding:一般是没有用的必要stride: 核的滑动步长。稍后详细讲"""

我觉得没有一张图很难说清楚这个函数想做啥!

假设我们现在有一个张量特征图,其size为[ 1, C, H, W]

PyTorch中torch.nn.functional.unfold函数使用详解

我们想将这个特征图连续的在分辨率维度(H和W)维度取出特征。就像下面这样:

PyTorch中torch.nn.functional.unfold函数使用详解

就是想把输入tensor数据,按照一定的区域(由核的长宽),不断沿着通道维度取出来,由步长指定核滑动的步长,由dilation指定核内区域哪些被跳过。

这里要说明一下,unfold函数的输入数据是四维,但输出是三维的。假设输入数据是[B, C, H, W], 那么输出数据是 [B, C* kH * kW, L], 其中kH是核的高,kW是核宽。 L则是这个高kH宽kW的核能在H*W区域按照指定stride滑动的次数。

PyTorch中torch.nn.functional.unfold函数使用详解

上面公式中第一项是指核高kH的情况下,能在高H的特征图上滑动的次数,后一项则是在宽这个维度上。当然默认stride=1

得到的这三维tensor,还需要reshape一下,才能得到上图右边的形式。

B, C_kh_kw, L = data.size()
data = data.permute(0, 2, 1)
data = data.view(B, L, C, kh, kw)

 

下面就进入代码实践环节。假设B等于1。

import torch
from torch.nn import functional as fx = torch.arange(0, 1*3*15*15).float()
x = x.view(1,3,15,15)
print(x)
x1 = f.unfold(x, kernel_size=3, dilation=1, stride=1)
print(x1.shape)
B, C_kh_kw, L = x1.size()
x1 = x1.permute(0, 2, 1)
x1 = x1.view(B, L, -1, 3, 3)
print(x1)'''
x的打印的一部分
tensor([[[[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,11.,  12.,  13.,  14.],[ 15.,  16.,  17.,  18.,  19.,  20.,  21.,  22.,  23.,  24.,  25.,26.,  27.,  28.,  29.],...[[225., 226., 227., 228., 229., 230., 231., 232., 233., 234., 235.,236., 237., 238., 239.],[240., 241., 242., 243., 244., 245., 246., 247., 248., 249., 250.,251., 252., 253., 254.],...[[450., 451., 452., 453., 454., 455., 456., 457., 458., 459., 460.,461., 462., 463., 464.],[465., 466., 467., 468., 469., 470., 471., 472., 473., 474., 475.,476., 477., 478., 479.],...]]])X1 的一部分tensor([[[[[  0.,   1.,   2.],[ 15.,  16.,  17.],[ 30.,  31.,  32.]],[[225., 226., 227.],[240., 241., 242.],[255., 256., 257.]],[[450., 451., 452.],[465., 466., 467.],[480., 481., 482.]]],[[[  1.,   2.,   3.],[ 16.,  17.,  18.],[ 31.,  32.,  33.]],[[226., 227., 228.],[241., 242., 243.],[256., 257., 258.]],[[451., 452., 453.],[466., 467., 468.],[481., 482., 483.]]],
'''

 

首先X就是15*15,通道是3的特征图,同时这些值是从底到高按顺序reshape的。相当于0-15*15-1 是最上面一层,中间那层的数值是从15*15 到15*15*2-1. 最后一层的数值是从 15*15*2 到 15*15*3-1

现在对x1观察。

x1 就像是把x沿着分辨率维度切开了,而且是隔着一个元素单位就切(stride=1))。切出来的大小是3*3的(kernel size=3),和核高宽一致。

大家可以自行测试stride为2和dilation为2的情况。相信大家一定可以更深刻的理解这个函数。