将pytorch模型转为onnx模型,代码如下
weights_path = '模型权重的路径'
net = EfficientNet.from_name('efficientnet-b2', {'num_classes': 2})
net.load_state_dict(torch.load(weights_path))
net.eval()torch.onnx.export(net, input, 'efficientnet_ss.onnx')
报错:RuntimeError: ONNX export failed: Couldn't export Python operator SwishImplementation
原因:onnx不支持函数Swish
解决办法:
weights_path = '模型权重的路径'
net = EfficientNet.from_name('efficientnet-b2', {'num_classes': 2})
net.set_swish(memory_efficient=False)
net.load_state_dict(torch.load(weights_path))
net.eval()torch.onnx.export(net, input, 'efficientnet_ss.onnx')
net.set_swish()函数会将swish转换为x*sigmoid(x),从而使其满足onnx的支持的操作,那么pytorch模型就可以顺利的转换为onnx模型了