当前位置: 代码迷 >> 综合 >> pytorch dataloader
  详细解决方案

pytorch dataloader

热度:9   发布时间:2024-01-26 23:22:28.0

预训练模型使用

    model = ResNet50()model = nn.DataParallel(model).cuda()


加载预训练模型是使用

model = ResNet50().cuda()checkpoint = torch.load(args.net_cache)model.load_state_dict(checkpoint['state_dict'])//

报错:RuntimeError: Error(s) in loading state_dict for Resnet50   Unexpected key(s) in state_dict:““

说明加载模型和预训练模型环境不一致,修改如下(对我的代码可行):

model = ResNet50().cuda()checkpoint = torch.load(args.net_cache)model.load_state_dict(checkpoint['state_dict'],False)//

model.load_state_dict(state_dict, strict=True)

或者是由于用DataParallel训练的模型数据并行方式的,key中会多个module,加载时直接用

model = ResNet50().cuda()
model = nn.DataParallel(model)checkpoint = torch.load(args.net_cache)model.load_state_dict(checkpoint['state_dict'])//