文章目录
-
- 一、保存网络权重
-
-
- .save_weights()
- .load_weights()
- del net_name : 删除网络 net_name
-
- 二、保存网络模型
-
-
- .save()
- .load_model()
-
- 三、工业部署保存模型
-
-
- tf.saved_model.save()
- tf.saved_model.load()
-
一、保存网络权重
.save_weights()
保存网络中所有的 w1,b1,w2,b2…,一些其他细节并不保存
.load_weights()
加载权重文件
del net_name : 删除网络 net_name
删除网络后,要想恢复原始模型,必须建立与之前相同的Sequential
network.save_weights('weights.ckpt')
print('saved weights.')
del networknetwork = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(64, activation='relu'),layers.Dense(32, activation='relu'),layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
network.load_weights('weights.ckpt')
print('loaded weights!')
二、保存网络模型
.save()
保存模型的所有细节,删掉网络后,只需要恢复保存的模型就可使用,不用新建squential。
.load_model()
network.save('model.h5')
print('saved total model.')
del networkprint('loaded model from file.')
network = tf.keras.models.load_model('model.h5', compile=False)
network.compile(optimizer=optimizers.Adam(lr=0.01),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])