mxnet src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_m

程序员文章站 2022-05-27 10:32:11

Training with a custom op (one that computes a metric) in MXNet 1.6 fails with:

src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_mask() (1 vs. 2) : Operator broadcast_add require all inputs live on the same context. But the first argument is on gpu(0) while the 2-th argument is on cpu(0)

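Cause: as with other errors of this kind, some operation inside your custom op takes two MXNet NDArrays that live on different contexts. Here the first argument is on gpu(0) and the second is on cpu(0).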

Fix: locate the operations that fail, check which two NDArrays each one uses, and move those arrays to the same context.
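
One way to find the offending pair is to print the context of every array just before the line named in the traceback. The following is a minimal sketch, not part of the original code: `context_report` and `mixed_contexts` are hypothetical helper names, and they assume only that each array exposes a `.context` attribute, as MXNet NDArrays do.

```python
def context_report(**named_arrays):
    """Map each array's name to its context, rendered as a string.

    Hypothetical helper: it only reads the `.context` attribute,
    which MXNet NDArrays expose.
    """
    return {name: str(arr.context) for name, arr in named_arrays.items()}


def mixed_contexts(**named_arrays):
    """Return True if the given arrays live on more than one context."""
    return len(set(context_report(**named_arrays).values())) > 1
```

For instance, calling `print(context_report(pred=preds[i], body=body))` immediately before `nll = preds[i] + body` in the metric's `update` method should show which operand sits on which device.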

Typical culprits are mx.nd.pick, the + operator, and similar operations. An excerpt of the corrected code:

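# Computing everything on the CPU avoids the error. To verify, set
# os.environ['CUDA_VISIBLE_DEVICES'] = '' (before mxnet is imported) and
# check whether the program runs cleanly in a CPU-only environment.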
label = labels[i].as_in_context(mx.cpu())
pred = preds[i].as_in_context(mx.cpu())
zy = mx.nd.pick(pred, label, axis=1)
nll = pred + body
...
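
As an alternative to moving everything to the CPU, the operands can be aligned to whichever context one of them already lives on. This is a sketch under stated assumptions: `align_contexts` is a hypothetical helper, and it relies only on `.context` and `.as_in_context(ctx)`, both of which MXNet NDArrays provide.

```python
def align_contexts(reference, *others):
    """Return `others` moved onto the context of `reference`.

    Hypothetical helper. It relies only on `.context` and
    `.as_in_context(ctx)`; for MXNet NDArrays, `as_in_context`
    returns the array itself when it is already on the target context.
    """
    ctx = reference.context
    return [arr.as_in_context(ctx) for arr in others]
```

In the metric above this might be used as `label, body = align_contexts(preds[i], labels[i], body)`, which keeps the computation on the device where the predictions already are. Whether that beats the CPU version depends on the metric and on how often results are copied back to the host.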

Full traceback:

Traceback (most recent call last):
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 454, in <module>
    main()
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 450, in main
    train_net(args)
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/train_1118.py", line 442, in train_net
    epoch_end_callback=epoch_cb)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/base_module.py", line 533, in fit
    self.update_metric(eval_metric, data_batch.label)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/module.py", line 775, in update_metric
    self._exec_group.update_metric(eval_metric, labels, pre_sliced)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/module/executor_group.py", line 640, in update_metric
    eval_metric.update_dict(labels_, preds)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/metric.py", line 349, in update_dict
    metric.update_dict(labels, preds)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/metric.py", line 133, in update_dict
    self.update(label, pred)
  File "/home/user1/pjs/frvt/frvt/arcface-mtl/recognition/metric_agr.py", line 128, in update
    nll = preds[i] + body
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 266, in __add__
    return add(self, other)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 3548, in add
    None)
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/ndarray/ndarray.py", line 3484, in _ufunc_helper
    return fn_array(lhs, rhs)
  File "<string>", line 58, in broadcast_add
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/_ctypes/ndarray.py", line 107, in _imperative_invoke
    ctypes.byref(out_stypes)))
  File "/home/user1/miniconda3/lib/python3.7/site-packages/mxnet/base.py", line 255, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [15:37:01] src/imperative/./imperative_utils.h:72: Check failed: inputs[i]->ctx().dev_mask() == ctx.dev_mask() (1 vs. 2) : Operator broadcast_add require all inputs live on the same context. But the first argument is on gpu(0) while the 2-th argument is on cpu(0)
Stack trace:
  [bt] (0) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x6b41eb) [0x7f9a1e29f1eb]
  [bt] (1) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::imperative::GetContext(nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::Context const&)+0x4fc) [0x7f9a2147929c]
  [bt] (2) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x1c0) [0x7f9a21483720]
  [bt] (3) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x3754a1f) [0x7f9a2133fa1f]
  [bt] (4) /home/user1/miniconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x62) [0x7f9a2133ffe2]
  [bt] (5) /home/user1/miniconda3/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call_unix64+0x4c) [0x7f9ab9188630]
  [bt] (6) /home/user1/miniconda3/lib/python3.7/lib-dynload/../../libffi.so.6(ffi_call+0x22d) [0x7f9ab9187fed]
  [bt] (7) /home/user1/miniconda3/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(_ctypes_callproc+0x2ce) [0x7f9ab919dede]
  [bt] (8) /home/user1/miniconda3/lib/python3.7/lib-dynload/_ctypes.cpython-37m-x86_64-linux-gnu.so(+0x12914) [0x7f9ab919e914]