1.model.train()与model.eval()的用法
看别人的面经时,浏览到一题,问的就是这个。自己刚接触pytorch时套用别人的框架,会在训练开始之前写上model.trian(),在测试时写上model.eval()。然后自己写的时候也就保留了这个习惯,没有去想其中原因。
在经过一番查阅之后,总结如下:
如果模型中有BN层(Batch Normalization)和Dropout,需要在训练时添加model.train(),在测试时添加model.eval()。其中model.train()是保证BN层用每一批数据的均值和方差,而model.eval()是保证BN用全部训练数据的均值和方差;而对于Dropout,model.train()是随机取一部分网络连接来训练更新参数,而model.eval()是利用到了所有网络连接。
联系Batch Normalization和Dropout的原理之后就不难理解为何要这么做了。
不过我还是有疑惑,到底放在下面训练的里面还是在外面?
for img,label in train_loder:
img,label = img.to(device),label.to(device)
...
参考:https://www.cnblogs.com/luckyplj/p/13424561.html