当前位置: 代码迷 >> 综合 >> Pytorch 中的 eval 模式,train 模式 和 梯度上下文管理器 torch.no_grad
  详细解决方案

Pytorch 中的 eval 模式,train 模式 和 梯度上下文管理器 torch.no_grad

热度:36   发布时间:2023-11-14 12:34:37.0

前言

本文将简要说明下 Pytorch 中 model.eval() (模型评估模式),model.train() (模型训练模式) 和 torch.no_grad() (取消梯度计算上下文管理器) 的作用与用法。

model.train() 和 model.eval()

这里的 model 指的是在 Pytorch 中定义的模型,需要继承自 torch.nn.Module,下同。

model.train() 会将 model 设置成训练模式,这只会影响 model 中特定的一些模块,比如:Dropout、BatchNorm 等,因为这些模块在训练阶段和验证阶段(或测试阶段)有着不同的行为。而其他大部分模块(如,nn.Linear、nn.Embedding、nn.Conv1d 等)在训练阶段和验证以及测试阶段都具有同样的行为,所以不会受此模式的影响。

关于 Dropout 在不同阶段有着不同行为的解释可参考我的另一博文:神经网络正则化方法总结——Dropout

model.eval() 会将 model 设置成评估模式,同上,这只会影响 model 中特定的一些模块,比如:Dropout、BatchNorm 等。

若是在模型的非训练阶段(如 evaluation 阶段)未使用 model.eval()model 设置成评估模式,有可能会造成同一样本的多次推断结果不一致的情况(这可是一个很大的问题…)

torch.no_grad()

torch.no_grad 是一个上下文管理器,在其管理范围内 Pytorch 不再计算模型各参数的梯度,即使参数的 requires_grad 属性为 True,这能有效减少模型计算时所需的内存/显存(因为保存参数梯度需要大量的内存/显存)。这在模型的 evaluation 和 test 以及 predict 这些非训练阶段中很有用,常见用法有两种,一种是使用 with 语句,另一种是使用修饰器:

x = torch.tensor([1], requires_grad=True)# 使用 with 语句
with torch.no_grad():y = x * 2
print(y.requires_grad) # 输出为 False# 使用 torch.no_grad() 修饰器
@torch.no_grad()
def doubler(x):return x * 2
z = doubler(x)
print(z.requires_grad) # 输出为 False

在 Pytorch 模型的非训练阶段,往往需要同时使用 torch.no_grad()model.eval().

在从非训练阶段跳转到训练阶段(即 train 阶段)时,别忘了使用 mode.train() 命令。

参考源

  • Pytorch 官方文档 torch.nn.Module
  • 神经网络正则化方法总结——Dropout
  相关解决方案