
mxnet autograd.record() source code analysis


Introduction

While working on a lesion-detection project using YOLOv3 on the DeepLesion dataset under the mxnet framework, training failed with an "access violation reading 0xFFFFF…" error. Googling suggested the cause was that backward() was being called inside the autograd.record() scope, as shown below. This prompted me to study mxnet's automatic differentiation and how autograd.record() controls gradient computation at the lower levels.

with autograd.record(mode == "train"):
    losses.backward()
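A commonly suggested fix is to keep only the forward pass inside the recording scope and issue backward() after it closes. A minimal sketch of that pattern (net, loss_fn, x and y are hypothetical placeholders for the training loop):

with autograd.record(train_mode=(mode == "train")):
    outputs = net(x)                  # forward pass: recorded
    losses = loss_fn(outputs, y)      # loss computation: recorded
losses.backward()                     # backward issued outside the scope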

1. autograd.record()

record() relies on a Python language feature, the context manager. The record function simply returns:

return _RecordingStateScope(True, train_mode)
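For comparison, the sibling context managers in the same module are equally thin wrappers over the same scope class; reproduced here from memory, so minor details may differ between versions:

def record(train_mode=True):
    return _RecordingStateScope(True, train_mode)

def pause(train_mode=False):
    return _RecordingStateScope(False, train_mode)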

2. _RecordingStateScope()

class _RecordingStateScope(object):
    """Scope for managing training state.

    Example::

        with _RecordingStateScope(True, True):
            y = model(x)
            backward([y])
    """
    def __init__(self, is_record, train_mode):  # pylint: disable=redefined-outer-name
        self._enter_is_record = is_record
        self._enter_train_mode = train_mode
        self._prev_is_record = None
        self._prev_train_mode = None

    def __enter__(self):
        if self._enter_is_record is not None:
            self._prev_is_record = set_recording(self._enter_is_record)
        if self._enter_train_mode is not None:
            self._prev_train_mode = set_training(self._enter_train_mode)

    def __exit__(self, ptype, value, trace):
        if self._enter_is_record is not None and self._prev_is_record != self._enter_is_record:
            set_recording(self._prev_is_record)
        if self._enter_train_mode is not None and self._prev_train_mode != self._enter_train_mode:
            set_training(self._prev_train_mode)
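The save-on-enter, restore-on-exit logic above is a standard Python idiom. Here is a minimal standalone sketch of the same pattern with a toy module-level flag standing in for is_recording_ (not mxnet code):

_state = False  # toy flag playing the role of is_recording_

def set_state(value):
    """Set the flag and return the previous value, mirroring set_recording()."""
    global _state
    prev = _state
    _state = value
    return prev

class _StateScope(object):
    """Toy scope: save the flag on enter, restore it on exit."""
    def __init__(self, enter_value):
        self._enter_value = enter_value
        self._prev = None

    def __enter__(self):
        self._prev = set_state(self._enter_value)

    def __exit__(self, ptype, value, trace):
        if self._prev != self._enter_value:
            set_state(self._prev)

with _StateScope(True):
    assert _state is True   # flag flipped inside the scope
assert _state is False      # flag restored on exit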

__enter__ and __exit__ are the methods called when entering and exiting the context manager. During training, __enter__ runs first; self._enter_is_record defaults to True, so set_recording(True) is called:

def set_recording(is_recording):  # pylint: disable=redefined-outer-name
    """Set status to recording/not recording. When recording, graph will be constructed
    for gradient computation.

    Parameters
    ----------
    is_recording: bool

    Returns
    -------
    previous state before this set.
    """
    prev = ctypes.c_int()
    check_call(_LIB.MXAutogradSetIsRecording(
        ctypes.c_int(is_recording), ctypes.byref(prev)))
    return bool(prev.value)
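The same module also exposes the read-only counterpart, which queries the flag through a parallel C API entry point (quoted from memory; the exact body may vary across versions):

def is_recording():
    """Get status on recording/not recording.

    Returns
    -------
    Current state of recording.
    """
    curr = ctypes.c_bool()
    check_call(_LIB.MXAutogradIsRecording(ctypes.byref(curr)))
    return curr.value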

From the above, the call ultimately lands in the MXAutogradSetIsRecording function inside libmxnet.dll (the dynamic link library built from source).

int MXAutogradSetIsRecording(int is_recording, int* prev) {
  API_BEGIN();
  *prev = Imperative::Get()->set_is_recording(static_cast<bool>(is_recording));
  API_END();
}
bool set_is_recording(bool is_recording) {
  bool old = is_recording_;
  is_recording_ = is_recording;
  return old;
}

As we can see, what record() ultimately changes is the value of is_recording_, which is defined as follows:

namespace mxnet {
/*! \brief runtime functions for NDArray */
class Imperative {
 public:
  /*! \brief */
  class AGInfo {
   public:
    Context ctx;
    OpReqType grad_req;
    OpStatePtr state;
    std::vector<NDArray> outputs;
    std::vector<NDArray> out_grads;
    bool fresh_out_grad;
  ...
  std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                 const std::vector<NDArray*>& ograds,
                                 const std::vector<NDArray*>& variables,
                                 bool is_train, bool retain_graph,
                                 bool create_graph);
  /*! \return AutogradRuntime singleton */
  static Imperative* Get();

 private:
  friend class NDArray;
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local bool is_train_;
  static thread_local bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static thread_local bool is_np_shape_;
#else
  static MX_THREAD_LOCAL bool is_train_;
  static MX_THREAD_LOCAL bool is_recording_;
  // TODO(junwu): Added numpy compatibility switch for backward compatibility.
  // Delete it in the next major release.
  static MX_THREAD_LOCAL bool is_np_shape_;
#endif
};

static thread_local bool is_recording_ carries two qualifiers, static and thread_local. Because it is a static class member, changing is_recording_ changes it for every object of the class. Because it is thread_local, its lifetime is that of its thread. A static thread_local variable is therefore a per-thread static: within one thread, the variable is shared by all instances, but each thread gets its own independent copy.

Experiments also confirm that the computation inside the autograd.record() scope is multi-threaded, and that calling record() modifies the values of is_recording_ and is_train_.
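A minimal sketch of such an experiment: because is_recording_ is thread_local, a worker thread spawned inside a record() scope should not inherit the main thread's recording state (the expected output assumes the thread_local semantics described above):

import threading
from mxnet import autograd

def check():
    # Each thread owns its own copy of is_recording_, so this worker
    # sees recording disabled even while the main thread is recording.
    print("worker is_recording:", autograd.is_recording())   # expected: False

with autograd.record():
    print("main is_recording:", autograd.is_recording())     # expected: True
    t = threading.Thread(target=check)
    t.start()
    t.join()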

So how does mxnet use autograd.record() to control the computation of derivatives during backpropagation?

The gradient computation in training looks like this:


a = nd.random.normal(shape=1)
a.attach_grad()
with autograd.record():
    c = f(a)
c.backward()
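Filling in a concrete f (hypothetical, chosen so the gradient is easy to verify) makes the whole flow checkable end to end:

from mxnet import nd, autograd

def f(x):
    return 2 * x * x   # hypothetical example: f(x) = 2x^2, so df/dx = 4x

a = nd.random.normal(shape=1)
a.attach_grad()              # allocate a buffer to hold a's gradient
with autograd.record():      # computations in this scope are recorded
    c = f(a)
c.backward()                 # backward issued outside the recording scope
print(a.grad, 4 * a)         # the two values should match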

The gradient is computed by the NDArray class calling backward(), and from the Imperative class above we can see that NDArray is a friend class of Imperative (Imperative is likewise a friend of NDArray; mutual friends can call each other's methods and access each other's members). The Backward used in automatic differentiation is also a method of Imperative. So when we call autograd.record(), the is_recording_ and is_train_ members of Imperative are set accordingly.

Because is_recording_ and is_train_ are static members of the friend class, setting is_recording_ to 1 means every NDArray produced inside the with scope is marked as recorded, and those marked variables take part in backpropagation and gradient computation. Variables outside the scope do not. Conversely, we can use with autograd.pause() to exclude particular computations from gradient computation.
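A minimal sketch of pause(), reusing f and a from the example above; the paused computation is excluded from the backward graph:

with autograd.record():
    c = f(a)                 # recorded: participates in the backward graph
    with autograd.pause():
        monitor = c * 10     # not recorded: pure bookkeeping, no gradient flows
c.backward()                 # a.grad is unaffected by the paused computation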
