记录一下,避免后面踩坑
单步调了半小时,幸好这个问题出现很频繁
发现是函数与导函数定义域问题。。。
sqrt(x) 函数的定义域为 [0, 无穷大)
sqrt(x) 的导函数的定义域 却是 (0, 无穷大)
这些函数定义域跟导函数的定义域不一样,正向传播可以得到正常结果,但是一旦backward就会得到Nan。。。
问题重现
import torch
a = torch.zeros(1)
a.requireds_grad = True
b = torch.sqrt(a)
b.backward()
print(a.grad)
# 得到nan
如何解决
让输入的值符合sqrt的导函数定义域就可以解决该问题了。举个例子:设 x 的定义域为 [0, 无穷大) ,给 x 加个很小的数,例如1e-8,使其输入值的定义域略微往右偏移,就可以避开 0 这个未定义值了;y = sqrt(x + 1e-8)