
pytorch RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one


Error message:

Traceback (most recent call last):
  File "train.py", line 166, in <module>
    main(args_)
  File "train.py", line 118, in main
    f_clean_masked, f_occ_masked, fc, fc_occ = backbone(img1, img2)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/user1/miniconda3/envs/py377/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 692, in forward
    if self.reducer._rebuild_buckets():
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by (1) passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2) making sure all `forward` function outputs participate in calculating loss. If you already have done the above two steps, then the distributed data parallel module wasn't able to locate the output tensors in the return value of your module's `forward` function. Please include the loss function and the structure of the return value of `forward` of your module when reporting this issue (e.g. list, dict, iterable).

Solution 1: Many posts online suggest passing the keyword argument find_unused_parameters=True to the DistributedDataParallel constructor. I did not try this approach here.
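For reference, a minimal sketch of that suggestion, assuming a torchrun launch; the module itself is just a stand-in, not the network from this post:

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])   # set by torchrun

model = nn.Linear(128, 10).to(local_rank)    # stand-in for the real network
model = DDP(
    model,
    device_ids=[local_rank],
    find_unused_parameters=True,   # let DDP tolerate parameters that get no gradient
)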

Solution 2: Another explanation found online is that this happens when part of the computation graph is left outside the module wrapped by DistributedDataParallel, e.g. you construct the DistributedDataParallel wrapper first and only then add a few more computation steps. The fix in that case is to move those steps in front, so that the DistributedDataParallel wrapper is constructed last (see the sketch below). That was not my situation either.
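A sketch of that ordering problem and its fix, reusing the setup from the previous sketch (module names and sizes are made up):

# Problematic order: extra trainable steps created after wrapping,
# so DDP never registers their parameters for gradient reduction
#   backbone = DDP(backbone, device_ids=[local_rank])
#   head = nn.Linear(512, 10).to(local_rank)   # lives outside the wrapper
#   out = head(backbone(x))

# Suggested order: assemble the complete trainable graph first, wrap once at the end
full_model = nn.Sequential(
    nn.Linear(128, 512),   # stand-in for the backbone
    nn.ReLU(),
    nn.Linear(512, 10),    # the extra step, now inside the wrapper
).to(local_rank)
full_model = DDP(full_model, device_ids=[local_rank])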

Solution 3: A network with multiple input paths. My network takes two inputs. I first ran the two forward passes one after the other, and the error was raised when I then ran the two backward passes, the parameter update (opt.step), and the gradient reset (zero_grad) in sequence. The fix was to complete one full iteration per path: one forward pass, then its backward pass, parameter update, and zero_grad; then the other forward pass, followed by its backward pass, parameter update, and zero_grad (sketched below).
After that, the error was not raised again.
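A minimal sketch of the change that worked, with placeholder names for the model, optimizer, loss, and data:

# Pattern that raised the error: both forwards first, then both backwards
#   out_a = model(x_a); out_b = model(x_b)
#   loss_a.backward(); loss_b.backward(); opt.step(); opt.zero_grad()

# Pattern that fixed it: finish one full iteration per input path
out_a = model(x_a)                 # first forward
loss_a = criterion(out_a, y_a)
loss_a.backward()                  # first backward
opt.step()                         # parameter update
opt.zero_grad()                    # reset gradients

out_b = model(x_b)                 # second forward
loss_b = criterion(out_b, y_b)
loss_b.backward()
opt.step()
opt.zero_grad()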


A similar error also shows up when networks are cascaded, i.e. when the output of one network is fed into another.
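Presumably that refers to pipelines like the following sketch, where two separately wrapped stages are chained; if any parameter of either stage fails to contribute to the loss in some iteration, the same reducer error is raised (all names here are hypothetical):

stage1 = DDP(nn.Linear(128, 64).to(local_rank), device_ids=[local_rank])
stage2 = DDP(nn.Linear(64, 10).to(local_rank), device_ids=[local_rank])

feat = stage1(x)
out = stage2(feat)
loss = criterion(out, y)   # must depend on every parameter of both stages
loss.backward()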
