背景
在使用 tf.train.Saver
时遇到这样一种情况,明明设置了 max_to_keep=5,但是保存的模型仍然超过了5个,那这是为什么呢?
举例
原因1:
假设跑10个 step,每个 step 保存一次模型。
with tf.Session() as sess:saver = tf.train.Saver(max_to_keep=5)for step in range(10):init_checkpoint = 'model.ckpt-%d' % step # 不同部分saver.save(sess, init_checkpoint)
以上代码是会使 max_to_keep=5 失效的,原因就在于 init_checkpoint = 'model.ckpt-%d' % step
其 step 放在了 .ckpt 之后,改为 init_checkpoint = 'model-%d.ckpt' % step
的话即可。
正确代码:
with tf.Session() as sess:saver = tf.train.Saver(max_to_keep=5)for step in range(10):init_checkpoint = 'model-%d.ckpt' % step # 不同部分saver.save(sess, init_checkpoint)
原因2:
假设我们定义了两个 saver,一个用来加载预训练模型的参数,一个用来保存模型。
with tf.Session() as sess:# 加载预训练模型参数checkpoint = 'bert_model.ckpt'assignment_map = ...loader = tf.train.Saver(assignment_map)loader.restore(sess, checkpoint)# 保存模型saver = tf.train.Saver(max_to_keep=5)for step in range(10):# 模型保存的路径及名称, 如下为在当前目录保存模型 model-0.ckpt、model-1.ckpt、...init_checkpoint = 'model-%d.ckpt' % stepsaver.save(sess, init_checkpoint)
以上代码即便 init_checkpoint
正确了,但仍会使 max_to_keep = 5
失效,这是为什么呢?
因为我们创建了两个 tf.train.Saver
对象,只有在只创建一个 tf.train.Saver
对象时,max_to_keep = 5
才会起作用。
正确代码:
with tf.Session() as sess:# 加载预训练模型参数checkpoint = 'bert_model.ckpt'assignment_map = ...saver = tf.train.Saver(assignment_map, max_to_keep=5)saver.restore(sess, checkpoint)# 保存模型for step in range(10):# 模型保存的路径及名称, 如下为在当前目录保存模型 model-0.ckpt、model-1.ckpt、...init_checkpoint = 'model-%d.ckpt' % stepsaver.save(sess, init_checkpoint)
总结:
- init_checkpoint 参数,不要在.ckpt后拼接其他字符串了;
- tf.train.Saver 对象只能创建一个;