当前位置: 代码迷 >> 综合 >> pytorch 中 torch.sqrt 的坑
  详细解决方案

pytorch 中 torch.sqrt 的坑

热度:89   发布时间:2024-01-31 00:13:04.0

记录一下,避免后面踩坑

单步调了半小时,幸好这个问题出现很频繁
发现是函数与导函数定义域问题。。。

sqrt(x) 函数的定义域为 [0, 无穷大)
sqrt(x) 的导函数的定义域 却是 (0, 无穷大)

这些函数定义域跟导函数的定义域不一样,正向传播可以得到正常结果,但是一旦backward就会得到Nan。。。

问题重现

import torch
a = torch.zeros(1)
a.requireds_grad = True
b = torch.sqrt(a)
b.backward()
print(a.grad)
# 得到nan

如何解决
让输入的值符合sqrt的导函数定义域就可以解决该问题了。举个例子:设 x 的定义域为 [0, 无穷大) ,给 x 加个很小的数,例如1e-8,使其输入值的定义域略微往右偏移,就可以避开 0 这个未定义值了;y = sqrt(x + 1e-8)