当前位置: 代码迷 >> 综合 >> torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系
  详细解决方案

torch.nn.Module中的training属性详情,与Module.train()和Module.eval()的关系

热度:42   发布时间:2023-12-01 13:39:35.0

Module类的构造函数:

    def __init__(self):"""Initializes internal Module state, shared by both nn.Module and ScriptModule."""torch._C._log_api_usage_once("python.nn_module")self.training = Trueself._parameters = OrderedDict()self._buffers = OrderedDict()self._backward_hooks = OrderedDict()self._forward_hooks = OrderedDict()self._forward_pre_hooks = OrderedDict()self._state_dict_hooks = OrderedDict()self._load_state_dict_pre_hooks = OrderedDict()self._modules = OrderedDict()

其中training属性表示BatchNorm与Dropout层在训练阶段和测试阶段中采取的策略不同,通过判断training值来决定前向传播策略。

对于一些含有BatchNorm,Dropout等层的模型,在训练和验证时使用的forward在计算上不太一样。在前向训练的过程中指定当前模型是在训练还是在验证。使用module.train()和module.eval()进行使用,其中这两个方法的实现均有training属性实现。
关于这两个方法的定义源码如下:
train():

    def train(self, mode=True):r"""Sets the module in training mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.Args:mode (bool): whether to set training mode (``True``) or evaluationmode (``False``). Default: ``True``.Returns:Module: self"""self.training = modefor module in self.children():module.train(mode)return self

eval():

    def eval(self):r"""Sets the module in evaluation mode.This has any effect only on certain modules. See documentations ofparticular modules for details of their behaviors in training/evaluationmode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,etc.This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.Returns:Module: self"""return self.train(False)

从源码中可以看出,train和eval方法将本层及子层的training属性同时设为true或false。
具体如下:
net.train() # 将本层及子层的training设定为True
net.eval() # 将本层及子层的training设定为False
net.training = True # 注意,对module的设置仅仅影响本层,子module不受影响
net.training, net.submodel1.training

关于train(),eval()函数可以查看pytorch中的model.train()和model.eval()

参考链接

『PyTorch』第十四弹_torch.nn.Module类属性

  相关解决方案