当前位置: 代码迷 >> 综合 >> RuntimeError: ONNX export failed: Couldn‘t export Python operator SwishImplementation
  详细解决方案

RuntimeError: ONNX export failed: Couldn‘t export Python operator SwishImplementation

热度:57   发布时间:2024-01-29 19:55:59.0

将pytorch模型转为onnx模型,代码如下

weights_path = '模型权重的路径'
net = EfficientNet.from_name('efficientnet-b2', {'num_classes': 2})
net.load_state_dict(torch.load(weights_path))
net.eval()torch.onnx.export(net, input, 'efficientnet_ss.onnx')

报错:RuntimeError: ONNX export failed: Couldn't export Python operator SwishImplementation

原因:onnx不支持函数Swish

解决办法:

weights_path = '模型权重的路径'
net = EfficientNet.from_name('efficientnet-b2', {'num_classes': 2})
net.set_swish(memory_efficient=False)
net.load_state_dict(torch.load(weights_path))
net.eval()torch.onnx.export(net, input, 'efficientnet_ss.onnx')

net.set_swish()函数会将swish转换为x*sigmoid(x),从而使其满足onnx的支持的操作,那么pytorch模型就可以顺利的转换为onnx模型了

  相关解决方案