Pytorch 工程重构后模型参数(.pt)读取失败解决方法
问题: 近期重构工程遇到了之前训练的模型参数没法读取的问题,即用重构后的工程在测试阶段去加载原来工程训练好的参数会方向有路径问题。
通过探究发现在.pt文件中model一项的类型包含了原来工程的路径信息,需要放置在同一个路径下才可以读取成功。重构type还是比较困难的,但是在测试代码中,我发现大多数模型在加载参数的过程中会将模型参数加载成dict的形式,而忽略没有必要的路径信息。
为此我们可以建立一个新字典,将state_dict的步骤放置在前面,去除后面不用用到的路径信息,完成模型参数迁移,具体代码如下:
model = torch.load("/home/input/xxx.pt", map_location="cuda:1") # load checkpoint
new_dict = OrderedDict()
new_dict["model"] = model["model"].state_dict()
torch.save(new_dict, "/home/output/xxx.pt")