当前位置: 代码迷 >> 综合 >> PyTorch中in-place
  详细解决方案

PyTorch中in-place

热度:91   发布时间:2023-12-12 09:06:00.0

in-place operation在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。可以把它成为原地操作符。

在pytorch中经常加后缀“_”来代表原地 in-place operation,比如说.add_() 或者.scatter_()。python里面的+=,*=也是in-place operation。

下面是正常的加操作,执行结束加操作之后x的值没有发生变化:

import torch
x=torch.rand(2) #tensor([0.8284, 0.5539])
print(x)
y=torch.rand(2)
print(x+y)      #tensor([1.0250, 0.7891])
print(x)        #tensor([0.8284, 0.5539])

下面是原地操作,执行之后改变了原来变量的值:

import torch
x=torch.rand(2) #tensor([0.8284, 0.5539])
print(x)
y=torch.rand(2)
x.add_(y)
print(x)        #tensor([1.1610, 1.3789])

在官方文档中有这一段话:

如果你使用了in-place operation而没有报错的话,那么你可以确定你的梯度计算是正确的。

Variable上的In-place操作

在自动求导中支持in-place操作是一件很困难的事情,我们在大多数情况下都不鼓励使用它们。Autograd的缓冲区释放和重用非常高效,并且很少场合下in-place操作能实际上明显降低内存的使用量。除非您在内存压力很大的情况下,否则您可能永远不需要使用它们。

限制in-place操作适用性主要有两个原因:

1.覆盖梯度计算所需的值。这就是为什么变量不支持log_。它的梯度公式需要原始输入,而虽然通过计算反向操作可以重新创建它,但在数值上是不稳定的,并且需要额外的工作,这往往会与使用这些功能的目的相悖。

2.每个in-place操作实际上需要实现重写计算图。不合适的版本只需分配新对象并保留对旧图的引用,而in-place操作则需要将所有输入的creator更改为表示此操作的Function。这就比较棘手,特别是如果有许多变量引用相同的存储(例如通过索引或转置创建的),并且如果被修改输入的存储被任何其他Variable引用,则in-place函数实际上会抛出错误。

In-place正确性检查

每个变量保留有version counter,它每次都会递增,当在任何操作中被使用时。当Function保存任何用于后向的tensor时,还会保存其包含变量的version counter。一旦访问self.saved_tensors,它将被检查,如果它大于保存的值,则会引起错误。