问题描述
我目前正在训练一个神经网络,并尝试存储训练后的模型以备将来使用。
该模型基于keras
Sequential
(见下文)。
我正在使用joblib.dump(model, output_file_gen)
来存储信息。
但是,我收到错误消息:
TypeError: can't pickle _thread.RLock objects.
我查看了一些关于此错误消息的 StackOverflow 帖子,它似乎与多线程有关。 我不确定模型中会发生什么,但也许有人可以通过采取措施消除此错误或通过建议更好的方法来存储神经网络来给我建议如何存储模型。
NN 设置如下:
model = Sequential()
model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
1楼
不建议使用 pickle 或 cPickle 来保存 Keras 模型。这是这里错误的原因(松散推理)
您可以使用model.save(filepath)
将模型保存到单个 HDF5 文件中,该文件将包含:
- 模型的架构,允许重新创建模型
- 模型的权重
- 训练配置(损失,优化器)
- 优化器的状态,允许从您停止的地方恢复训练。
然后您可以使用keras.models.load_model(filepath)
重新实例化/重新加载您的模型。
以上将使用大量磁盘空间。 因此您可以选择保存模型权重。 请参阅了解更多详情