当前位置: 代码迷 >> 综合 >> PyTorch中torch.nn.functional.pad函数使用详解
  详细解决方案

PyTorch中torch.nn.functional.pad函数使用详解

热度:41   发布时间:2023-10-20 18:34:22.0

顾明思义,这个函数是用来扩充张量数据的边界的。但是PyTorch中,pad的函数和numpy以及tensorflow的pad用法都不一样。今天就带来这个函数简明的用法解释。

首先跳到函数定义中,看一下有哪些参数。

def pad(input, pad, mode=‘constant’, value=0)

  • input : 输入张量
  • pad: 指定padding的维度和数目,形式是元组,稍后讲。
  • mode: 填充模式,不一样的模式,填充的值也不一样,
  • value: 仅当mode为‘constant’时有效,意思是填充的值是常亮,且值为value

重点就是讲一下这个pad参数。

假设现在有一个tensor的shape为[3,3,32,40][3,3,32,40][3,3,32,40],四维张量。
假设pad为:

(2,2
3,4,
1,2,
1,1 )

第一行的(2,2),意义是对最低的维度(dim=-1)前面填充2个单位,后面填充2个单位。
第二行的(3,4),意义是对
倒数第二个维度(dim=-2)
,前面填充3个单位,后面填充4个单位

第三行第四行的意义以此类推。重点就是pad里面每两个元素为1组,指定了由低维到高维,每一维度,前面填充和后面填充的数值单位。

如果对于一个四维张量,pad里面有4个元素,又是啥情况?
当然是只对最后两个维度pading了。
下面就看一个例子。

import torch
from torch.nn import functional as Fa = torch.randn([2,3,4,5])  # torch.Size([2, 3, 4, 5])
padding = (1,2,   # 前面填充1个单位,后面填充两个单位,输入的最后一个维度则增加1+2个单位,成为82,3,3,4
)
print(a.shape)
b = F.pad(a, padding)
print(b.shape)  # torch.Size([2, 10, 9, 8]) 

从上面的例子看出,之后后三个维度发生了扩增,因为我们输入的padding长度为6,只能影响后三个维度。