前言
本文将简要说明下 Pytorch 中 model.eval()
(模型评估模式),model.train()
(模型训练模式) 和 torch.no_grad()
(取消梯度计算上下文管理器) 的作用与用法。
model.train() 和 model.eval()
这里的 model
指的是在 Pytorch 中定义的模型,需要继承自 torch.nn.Module
,下同。
model.train()
会将 model
设置成训练模式,这只会影响 model
中特定的一些模块,比如:Dropout、BatchNorm 等,因为这些模块在训练阶段和验证阶段(或测试阶段)有着不同的行为。而其他大部分模块(如,nn.Linear、nn.Embedding、nn.Conv1d 等)在训练阶段和验证以及测试阶段都具有同样的行为,所以不会受此模式的影响。
关于 Dropout 在不同阶段有着不同行为的解释可参考我的另一博文:神经网络正则化方法总结——Dropout
model.eval()
会将 model
设置成评估模式,同上,这只会影响 model
中特定的一些模块,比如:Dropout、BatchNorm 等。
若是在模型的非训练阶段(如 evaluation 阶段)未使用 model.eval()
将 model
设置成评估模式,有可能会造成同一样本的多次推断结果不一致的情况(这可是一个很大的问题…)
torch.no_grad()
torch.no_grad
是一个上下文管理器,在其管理范围内 Pytorch 不再计算模型各参数的梯度,即使参数的 requires_grad
属性为 True
,这能有效减少模型计算时所需的内存/显存(因为保存参数梯度需要大量的内存/显存)。这在模型的 evaluation 和 test 以及 predict 这些非训练阶段中很有用,常见用法有两种,一种是使用 with
语句,另一种是使用修饰器:
x = torch.tensor([1], requires_grad=True)# 使用 with 语句
with torch.no_grad():y = x * 2
print(y.requires_grad) # 输出为 False# 使用 torch.no_grad() 修饰器
@torch.no_grad()
def doubler(x):return x * 2
z = doubler(x)
print(z.requires_grad) # 输出为 False
在 Pytorch 模型的非训练阶段,往往需要同时使用 torch.no_grad()
和 model.eval()
.
在从非训练阶段跳转到训练阶段(即 train 阶段)时,别忘了使用 mode.train()
命令。
参考源
- Pytorch 官方文档 torch.nn.Module
- 神经网络正则化方法总结——Dropout