mxnet src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_m

Published: 2023-12-15 16:21:43

Training with a custom OP (used to compute a metric) in MXNet 1.6 fails with:

src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_mask() (1 vs. 2) : Operator broadcast_add require all inputs live on the same context. But the first argument is on gpu(0) while the 2-th argument is on cpu(0)

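Cause: in errors like this one, some operation inside your custom OP takes two MXNet NDArrays that live on different contexts. Here the first argument of broadcast_add is on gpu(0) while the second is on cpu(0).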

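Fix: the key is to locate these operations, check which two NDArrays each one uses, and move both onto the same context (for example with as_in_context).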

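Typical culprits are operations such as mx.nd.pick and the + operator (which dispatches to broadcast_add). An excerpt of the corrected code: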

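import mxnet as mx

# Doing everything on the CPU avoids the error. You can confirm this by setting
# os.environ['CUDA_VISIBLE_DEVICES'] = '' and checking whether the program runs cleanly on CPU only.
label = labels[i].as_in_context(mx.cpu())
pred = preds[i].as_in_context(mx.cpu())
zy = mx.nd.pick(pred, label, axis=1)
# body is computed in code not shown here; per the error message it is on cpu(0),
# so pred has to be on the CPU too before the addition.
nll = pred + body
...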

Full traceback:

Traceback (most recent call last):
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 454, in <module>
    main()
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 450, in main
    train_net(args)
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 442, in train_net
    epoch_end_callback=epoch_cb)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/base_module.py", line 533, in fit
    self.update_metric(eval_metric, data_batch.label)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/module.py", line 775, in update_metric
    self._exec_group.update_metric(eval_metric, labels, pre_sliced)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/executor_group.py", line 640, in update_metric
    eval_metric.update_dict(labels_, preds)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/metric.py", line 349, in update_dict
    metric.update_dict(labels, preds)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/metric.py", line 133, in update_dict
    self.update(label, pred)
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/metric_agr.py", line 128, in update
    nll = preds[i] + body
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 266, in __add__
    return add(self, other)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 3548, in add
    None)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 3484, in _ufunc_helper
    return fn_array(lhs, rhs)
  File "<string>", line 58, in broadcast_add
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py", line 107, in _imperative_invoke
    ctypes.byref(out_stypes)))
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/base.py", line 255, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [15:37:01] src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_mask() (1 vs. 2) : Operator broadcast_add require all inputs live on the same context. But the first argument is on gpu(0) while the 2-th argument is on cpu(0)
Stack trace:
  [bt] (0) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x6b41eb) [0x7f9a1e29f1eb]
  [bt] (1) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::imperative::GetContext(nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::Context const&)+0x4fc) [0x7f9a2147929c]
  [bt] (2) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x1c0) [0x7f9a21483720]
  [bt] (3) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x3754a1f) [0x7f9a2133fa1f]
  [bt] (4) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x62) [0x7f9a2133ffe2]
  [bt] (5) /home/user1/miniconda3/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f9ab9188630]
  [bt] (6) /home/user1/miniconda3/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7f9ab9187fed]
  [bt] (7) /home/user1/miniconda3/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x7f9ab919dede]
  [bt] (8) /home/user1/miniconda3/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(+0x12914) [0x7f9ab919e914]