文章目录
- 介绍
- 保存变量
- 恢复变量
- 有限制地保留 Checkpoint 文件
- 实例
-
- 1、定义模型及训练过程
- 2、保存模型参数
-
- 2.1 不限制 checkpoint 文件个数
- 2.2 限制 checkpoint 文件个数
- 3、加载模型参数
- 参考资料
介绍
很多时候,我们希望在模型训练完成后能将训练好的参数(变量)保存起来。在需要使用模型的其他地方载入模型和参数,就能直接得到训练好的模型。
TensorFlow 提供了 tf.train.Checkpoint 这一强大的变量保存与恢复类,可以使用其 save() 和 restore() 方法将 TensorFlow 中所有包含 Checkpointable State 的对象进行保存和恢复。具体而言,tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 或者 tf.keras.Model 实例都可以被保存。
保存变量
# train.py 模型训练阶段model = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model)
# ...(模型训练代码)
# 模型训练完毕后将参数保存到文件
checkpoint.save('./save/model.ckpt')
这里 tf.train.Checkpoint() 接受的初始化参数比较特殊,是一个 **kwargs 。具体而言,是一系列的键值对,键名可以随意取,值为需要保存的对象。在这里,我们取键名为 myModel,指定保存对象为 model。如果我们希望保存其他对象如 Optimizer 的参数,我们可以这样写:
checkpoint = tf.train.Checkpoint(myModel=model, myOptimizer=optimizer)
训练完后,checkpoint 文件会出现在 ‘./save/’ 文件夹下,‘model.ckpt’ 是这些文件的前缀。如果我们只调用了一次 checkpoint.save 函数,那么在 ‘./save/’ 文件夹下会出现名为 checkpoint 、 model.ckpt-1.index 、 model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save() 方法可以运行多次,每运行一次都会得到一个.index 文件和.data 文件,序号依次累加。
恢复变量
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。
# test.py 模型使用阶段model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myModel=model_to_be_restored) # 实例化Checkpoint,指定恢复对象为model
checkpoint.restore(tf.train.latest_checkpoint('./save')) # 从文件恢复模型参数
当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。例如如果 save 目录下有 model.ckpt-1.index 到 model.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint(’./save’) 即返回 ./save/model.ckpt-10 。
有限制地保留 Checkpoint 文件
在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:
-
在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;
-
Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 epoch 的编号作为文件编号)。
这时,我们可以使用 TensorFlow 的 tf.train.CheckpointManager 来实现以上需求。具体而言,在定义 Checkpoint 后接着定义一个 CheckpointManager:
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
在需要保存模型的时候,我们直接使用 manager.save() 即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number 参数。例如 manager.save(checkpoint_number=100) 。
实例
我们通过对 MNIST 数据集的训练来举例:
1、定义模型及训练过程
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layersmnist = keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0# Add a channels dimension
x_train = x_train[..., tf.newaxis].astype(np.float32)
x_test = x_test[..., tf.newaxis].astype(np.float32)train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(x_test.shape[0])class MyModel(keras.Model):# Set layers.def __init__(self):super(MyModel, self).__init__()# Convolution Layer with 32 filters and a kernel size of 5.self.conv1 = layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)# Max Pooling (down-sampling) with kernel size of 2 and strides of 2.self.maxpool1 = layers.MaxPool2D(2, strides=2)# Convolution Layer with 64 filters and a kernel size of 3.self.conv2 = layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)# Max Pooling (down-sampling) with kernel size of 2 and strides of 2.self.maxpool2 = layers.MaxPool2D(2, strides=2)# Flatten the data to a 1-D vector for the fully connected layer.self.flatten = layers.Flatten()# Fully connected layer.self.fc1 = layers.Dense(1024)# Apply Dropout (if is_training is False, dropout is not applied).self.dropout = layers.Dropout(rate=0.5)# Output layer, class prediction.self.out = layers.Dense(10)# Set forward pass.def call(self, x, is_training=False):x = tf.reshape(x, [-1, 28, 28, 1])x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.flatten(x)x = self.fc1(x)x = self.dropout(x, training=is_training)x = self.out(x)if not is_training:# tf cross entropy expect logits without softmax, so only# apply softmax when not training.x = tf.nn.softmax(x)return xmodel = MyModel()loss_object = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()@tf.function
def train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels, predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(zip(gradients, model.trainable_variables))
2、保存模型参数
2.1 不限制 checkpoint 文件个数
EPOCHS = 5checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
for epoch in range(EPOCHS):for images, labels in train_ds:train_step(images, labels)path = checkpoint.save('./save/model.ckpt')print("model saved to %s" % path)
2.2 限制 checkpoint 文件个数
EPOCHS = 5checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', max_to_keep=3)
for epoch in range(EPOCHS):for images, labels in train_ds:train_step(images, labels)path = manager.save(checkpoint_number=epoch)print("model saved to %s" % path)
3、加载模型参数
model_to_be_restored = MyModel()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
checkpoint.restore(tf.train.latest_checkpoint('./save'))
for test_images, test_labels in test_ds:y_pred = np.argmax(model_to_be_restored.predict(test_images), axis=-1)print("test accuracy: %f" % (sum(tf.cast(y_pred == test_labels, tf.float32)) / x_test.shape[0]))
test accuracy: 0.989600
参考资料
简单粗暴 TensorFlow 2