当前位置: 代码迷 >> 综合 >> loss nan解决办法
  详细解决方案

loss nan解决办法

热度:84   发布时间:2023-11-21 13:13:08.0

前言

今天查了一下午加晚上的bug…记录一下…

解决办法

  • 减小学习率
  • 加大batchsize(相当于减小学习率)
  • 关闭fp16

减小学习率或加大batchsize的原因

原因很简单,过大的学习率学飞了…

关闭fp16

今天就是被这个坑的…
FP16的作用

  • 减少显存占用,FP16 的内存占用只有 FP32 的一半,自然地就可以帮助训练过程节省一半的显存空间。
  • 加快训练和推断的计算,在大部分的测试中,基于 FP16 的加速方法能够给模型训练带来多一倍的加速体验。

但是这样减少内存的前提是损失了浮点数的精度…我实验的代码是T2T-ViT,把backbone搬到业务上用,然后loss一直nan,然后怀疑了很多,从参数加载错了,到模型参数初始化错了,到网络参数错了等等等…都不是。

无奈使用祖传的debug技巧按行debug…

记录下用的小技巧,找到tensor进入网络前用torch.ones(shape)构建一个全1的tensor送进去!这样非常好使,还可以直接发现是不是网络参数nan了…然后就是一行行print,最后发现这个github的代码用1e-8来防止除0…

罪魁祸首在这:
github链接
github链接

self.epsilon = 1e-8  # for stable in division
y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon)  # (B, T, emb)/Diag

提了个issue给作者…可能后面会改了吧。

但是有个疑问,16位的浮点数的范围是5.96×10^ -8 ~ 6.55×10 ^ 4但是我改了1e-4还在在别的地方nan了…所以最后索性把fp16关掉了…就解决了问题。

所以所以!!!loss nan基本就是出现了除0的情况!建议检查模型参数,或者一步步检查output,肯定能解决问题!

  相关解决方案