问题描述
在win下保存模型报错,Can‘t pickle local object‘get_linear_schedule_with_warmup’<locals> lr_lamda’
问题分析
保存模型源码如下:
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
使用到了lambda表达式,但是pickle在序列化的时候是不支持的!
解决方案
1.仅保存模型参数
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
但我原本就仅保存了模型参数,行不通~
2.使用第三方包dill
dill作用和pickle相同~支持lambda
使用方法:
2.1 pip下载该包
pip install dill
2.2 修改保存命令
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"),pickle_module=dill)
正常保存,不再报错啦~~