前言
今天查了一下午加晚上的bug…记录一下…
解决办法
- 减小学习率
- 加大batchsize(相当于减小学习率)
- 关闭fp16
减小学习率或加大batchsize的原因
原因很简单,过大的学习率学飞了…
关闭fp16
今天就是被这个坑的…
FP16的作用
- 减少显存占用,FP16 的内存占用只有 FP32 的一半,自然地就可以帮助训练过程节省一半的显存空间。
- 加快训练和推断的计算,在大部分的测试中,基于 FP16 的加速方法能够给模型训练带来多一倍的加速体验。
但是这样减少内存的前提是损失了浮点数的精度…我实验的代码是T2T-ViT,把backbone搬到业务上用,然后loss一直nan,然后怀疑了很多,从参数加载错了,到模型参数初始化错了,到网络参数错了等等等…都不是。
无奈使用祖传的debug技巧按行debug…
记录下用的小技巧,找到tensor进入网络前用torch.ones(shape)构建一个全1的tensor送进去!这样非常好使,还可以直接发现是不是网络参数nan了…然后就是一行行print,最后发现这个github的代码用1e-8来防止除0…
罪魁祸首在这:
github链接
github链接
self.epsilon = 1e-8 # for stable in division
y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag
提了个issue给作者…可能后面会改了吧。
但是有个疑问,16位的浮点数的范围是5.96×10^ -8 ~ 6.55×10 ^ 4但是我改了1e-4还在在别的地方nan了…所以最后索性把fp16关掉了…就解决了问题。
所以所以!!!loss nan基本就是出现了除0的情况!建议检查模型参数,或者一步步检查output,肯定能解决问题!