当前位置: 代码迷 >> 综合 >> tensorflow tf.train.Saver max_to_keep 参数无效
  详细解决方案

tensorflow tf.train.Saver max_to_keep 参数无效

热度:48   发布时间:2023-11-18 03:58:29.0

背景

在使用 tf.train.Saver 时遇到这样一种情况,明明设置了 max_to_keep=5,但是保存的模型仍然超过了5个,那这是为什么呢?

举例

原因1:

假设跑10个 step,每个 step 保存一次模型。

with tf.Session() as sess:saver = tf.train.Saver(max_to_keep=5)for step in range(10):init_checkpoint = 'model.ckpt-%d' % step  # 不同部分saver.save(sess, init_checkpoint)

以上代码是会使 max_to_keep=5 失效的,原因就在于 init_checkpoint = 'model.ckpt-%d' % step 其 step 放在了 .ckpt 之后,改为 init_checkpoint = 'model-%d.ckpt' % step 的话即可。

正确代码:

with tf.Session() as sess:saver = tf.train.Saver(max_to_keep=5)for step in range(10):init_checkpoint = 'model-%d.ckpt' % step  # 不同部分saver.save(sess, init_checkpoint)

原因2:

假设我们定义了两个 saver,一个用来加载预训练模型的参数,一个用来保存模型。

with tf.Session() as sess:# 加载预训练模型参数checkpoint = 'bert_model.ckpt'assignment_map = ...loader = tf.train.Saver(assignment_map)loader.restore(sess, checkpoint)# 保存模型saver = tf.train.Saver(max_to_keep=5)for step in range(10):# 模型保存的路径及名称, 如下为在当前目录保存模型 model-0.ckpt、model-1.ckpt、...init_checkpoint = 'model-%d.ckpt' % stepsaver.save(sess, init_checkpoint)

以上代码即便 init_checkpoint 正确了,但仍会使 max_to_keep = 5 失效,这是为什么呢?

因为我们创建了两个 tf.train.Saver 对象,只有在只创建一个 tf.train.Saver 对象时,max_to_keep = 5 才会起作用。

正确代码:

with tf.Session() as sess:# 加载预训练模型参数checkpoint = 'bert_model.ckpt'assignment_map = ...saver = tf.train.Saver(assignment_map, max_to_keep=5)saver.restore(sess, checkpoint)# 保存模型for step in range(10):# 模型保存的路径及名称, 如下为在当前目录保存模型 model-0.ckpt、model-1.ckpt、...init_checkpoint = 'model-%d.ckpt' % stepsaver.save(sess, init_checkpoint)

总结:

  • init_checkpoint 参数,不要在.ckpt后拼接其他字符串了;
  • tf.train.Saver 对象只能创建一个;
  相关解决方案