当前位置: 代码迷 >> 综合 >> tf.keras 模型的保存与加载(六)
  详细解决方案

tf.keras 模型的保存与加载(六)

热度:51   发布时间:2023-10-28 12:16:27.0

文章目录

    • 一、保存网络权重
        • .save_weights()
        • .load_weights()
        • del net_name : 删除网络 net_name
    • 二、保存网络模型
        • .save()
        • .load_model()
    • 三、工业部署保存模型
        • tf.saved_model.save()
        • tf.saved_model.load()

一、保存网络权重

.save_weights()

保存网络中所有的 w1,b1,w2,b2…,一些其他细节并不保存

.load_weights()

加载权重文件

del net_name : 删除网络 net_name

删除网络后,要想恢复原始模型,必须建立与之前相同的Sequential

network.save_weights('weights.ckpt')
print('saved weights.')
del networknetwork = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
network.load_weights('weights.ckpt')
print('loaded weights!')

二、保存网络模型

.save()

保存模型的所有细节,删掉网络后,只需要恢复保存的模型就可使用,不用新建squential。

.load_model()

network.save('model.h5')
print('saved total model.')
del networkprint('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)
network.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

三、工业部署保存模型

tf.saved_model.save()

tf.saved_model.load()

  相关解决方案