当前位置: 代码迷 >> 综合 >> pytorch 广播语义(Broadcasting semantics)
  详细解决方案

pytorch 广播语义(Broadcasting semantics)

热度:130   发布时间:2023-10-25 19:48:58.0

刚开始入坑pytorch,还不知道这个Broadcasting semantics是什么意思,现在仔细一看,发现背后的东西还不少。

简而言之,pytorch的广播语义有如下的几条规则:

一般语义

如果遵守以下规则,则两个张量是“可播放的”:

  • 每个张量至少有一个维度。
  • 迭代尺寸大小时,从尾随尺寸开始,尺寸大小必须相等,其中一个为1,或者其中一个不存在。

那么怎么来理解这两条规则呢???

个人认为就是两个张量(每个张量维度不能为0,这是符合第一个规则),我从末尾的尺寸的来看,它要么相等,要么有一个为1,那么我就说他们是可以“广播”的(有种数学不讲道理下定义的感觉...)。那么会有人问?为什么要使他们“广播兼容”呢?刚开始我也没“追究”,后来发现, 两个广播兼容的张量可以在不同尺寸的情况下进行运算,这真是非常方便啊,不用人为的进行扩充了。它们可以自动扩展为相同的类型大小。

接下来看例子:

    a.shape       +     b.shape      c.shape
    (4, 1)       +         (1)      -->       (4, 1)
      (4, 1)       +         (3,)      -->       (4, 3)
    (2, 3, 4)       +        (1, 4)      -->     (2, 3, 4)
    (2, 3, 4)       +        (3, 1)      -->     (2, 3, 4)
    (2, 3, 4)       +     (2, 1, 1)      -->     (2, 3, 4)
    (2, 3, 4)       +         (3, )       X  
      (4, 3)       +         (4,)       X  
      (4, 3)       +         (3,)      -->       (4, 3)
      (4, 3)       +         (3)      -->       (4, 3)

具体做法就是:

  • 如果尺寸的数量xy不相等,则在尺寸较小的张量的前面加1,使它们的长度相等。
  • 然后,对于每个维度大小,生成的维度大小是该维度的大小xy沿该维度的最大值 。

再看下官方文档的例子:

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同形状的张量可以被广播(上述规则总是成立的)>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x和y不能被广播,因为x没有维度# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x和y能够广播.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist# 但是:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x和y不能被广播  ( 因为 2 != 3  ) 

就地语义

一个复杂因素是就地操作不允许就地张量由于广播而改变形状。

Example:

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(3,1,1)
>>> (x.add_(y)).size()
torch.Size([5, 3, 4, 1])# but:
>>> x=torch.empty(1,3,1)
>>> y=torch.empty(3,1,7)
>>> (x.add_(y)).size()
RuntimeError: The expanded size of the tensor (1) must match the existing size (7) at non-singleton dimension 2.

向后兼容性

PyTorch的早期版本允许某些逐点函数在具有不同形状的张量上执行,只要每个张量中的元素数量相等即可。然后通过将每个张量视为1维来执行逐点运算。PyTorch现在支持广播,并且“1维”逐点行为被认为已弃用,并且在张量不可播放但具有相同数量的元素的情况下将生成Python警告。

注意,在两个张量不具有相同形状但是可广播并且具有相同数量的元素的情况下,广播的引入可能导致向后不兼容的改变。例如:

>>> torch.add(torch.ones(4,1), torch.randn(4))

之前会产生一个尺寸为Tensor的尺寸:torch.Size([4,1]),但现在产生尺寸为Tensor:torch.Size([4,4])。为了帮助识别代码中可能存在广播引起的向后不兼容性的情况,您可以将torch.utils.backcompat.broadcast_warning.enabled设置为True,这将在这种情况下生成python警告。

例如:

>>> torch.utils.backcompat.broadcast_warning.enabled=True
>>> torch.add(torch.ones(4,1), torch.ones(4))
__main__:1: UserWarning: self and other do not have the same shape, but are broadcastable, and have the same number of elements.
Changing behavior in a backwards incompatible manner to broadcasting rather than viewing as 1-dimensional.