当前位置: 代码迷 >> 综合 >> TensorFlow笔记(二)——用一个demo来介绍tf.train.Saver()模型参数的保存和加载
  详细解决方案

TensorFlow笔记(二)——用一个demo来介绍tf.train.Saver()模型参数的保存和加载

热度:48   发布时间:2023-11-23 22:35:15.0

1. Saver的背景介绍

        我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试。Tensorflow针对这一需求提供了Saver类。
      Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
      只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
      为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

2. Saver的实例

下面以一个例子来讲述如何使用Saver类,该例子列举了机器学习的一个通用过程:

                                        1.准备数据 -> 2.构造模型(设置求解目标函数) -> 3.求解模型

其中构造模型、初始化所用的tf.Variables()函数的说明见这里

import tensorflow as tf
import numpy as np# 1.准备数据: 
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4# 2.构造一个线性模型 
w = tf.Variable(tf.random_normal([1], -1, 1)) #创建新对象,当检测到命名冲突时,系统会自己处理
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b# 3.求解模型
# 设置损失函数:误差的均方差 
loss = tf.reduce_mean(tf.square(y - y_predict))
# 选择梯度下降的方法
optimizer = tf.train.GradientDescentOptimizer(0.5)
# 迭代的目标:最小化损失函数
train = optimizer.minimize(loss)#参数定义声明 
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))############################################################
# 以下是用 tf 来解决上面的任务
# 1.初始化变量:tf 的必备步骤,主要声明了变量,就必须初始化才能用
# init = tf.global_variables_initializer() # 设置tensorflow对GPU的使用按需分配
#config  = tf.ConfigProto()
#config.gpu_options.allow_growth = True# 2.启动图 (graph)
with tf.Session() as sess:sess.run(tf.initialize_all_variables()) #判断当前工作状态if isTrain: #isTrain:True表示训练;False:表示测试# 3.迭代,反复执行上面的最小化损失函数这一操作(train op),拟合平面for i in xrange(train_steps): #train_steps表示训练的次数,例子中使用100sess.run(train, feed_dict={x: x_data})if (i + 1) % checkpoint_steps == 0: #表示训练多少次保存一下checkpoints,例子中使用50print ('step: {}  train_acc: {}  loss: {}'.format(step, sess.run(W), sess.run(b)))saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) #表示checkpoints文件的保存路径,例子中使用当前路径else: #如果isTrain=False,则进行测试ckpt = tf.train.get_checkpoint_state(checkpoint_dir)if ckpt and ckpt.model_checkpoint_path:saver.restore(sess, ckpt.model_checkpoint_path) #恢复变量else:passprint(sess.run(w),sess.run(b))

isTrain:用来区分训练阶段和测试阶段,True表示训练,False表示测试
train_steps:表示训练的次数,例子中使用100
checkpoint_steps:表示训练多少次保存一下checkpoints,例子中使用50
checkpoint_dir:表示checkpoints文件的保存路径,例子中使用当前路径

训练结果;

                                       

三、参数保存和加载

    最简单的保存和加载模型的方法是使用tf.train.Saver 对象。它的构造器将在计算图上添加save和restore节点,针对图上所有或者指定的变量。saver对象提供了运行这些节点的方法,只要指定用于读写的checkpoint的文件。结合上面demo说:

3.1 训练阶段

使用Saver.save()方法保存模型:

  1. sess:表示当前会话,当前会话记录了当前的变量值
  2. checkpoint_dir + 'model.ckpt':表示存储的文件名
  3. global_step:表示当前是第几步

3.1.1 checkpoint文件

变量以二进制文件的形式保存在checkpoint文件中,粗略地来说就是变量名与tensor数值的一个映射 
当你创建一个Saver对象是,你可以选择变量在checkpoint文件中名字。默认情况下,它会使用Variable.name作为变量名。 
为了理解什么变量在checkpoint文件中,你可以使用inspect_checkpoint库,更加详细地,使用print_tensors_in_checkpoint_file函数。

3.1.2 保存变量

使用tf.train.Saver()创建一个Saver对象,然后用它来管理模型中的所有变量。

# 创建一些变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加用于初始化变量的节点
init_op = tf.global_variables_initializer()# 添加用于保存和加载所有变量的节点
saver = tf.train.Saver()# 然后,加载模型,初始化所有变量,完成一些操作后,把变量保存到磁盘上
with tf.Session() as sess:sess.run(init_op)# 进行一些操作..# 将变量保存到磁盘上save_path = saver.save(sess, "/tmp/model.ckpt")print("Model saved in file: %s" % save_path)

训练完成后,当前目录底下会多出5个文件。

          

    打开名为“checkpoint”的文件,可以看到保存记录,和最新的模型存储位置。

                                        

3.2测试阶段

    测试阶段使用saver.restore()方法恢复变量:

  1. sess:表示当前会话,之前保存的结果将被加载入这个会话
  2. ckpt.model_checkpoint_path:表示模型存储的位置,不需要提供模型的名字,它会去查看checkpoint文件,看看最新的是谁,叫做什么。

加载变量

Saver对象还可以用于加载变量。注意当你从文件中加载变量是,你不用实现初始化它们。

# 创建两个变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加用于保存和加载所有变量的节点
saver = tf.train.Saver()# 然后,加载模型,使用saver对象从磁盘上加载变量,之后再使用模型进行一些操作
with tf.Session() as sess:# 从磁盘上加载对象saver.restore(sess, "/tmp/model.ckpt")print("Model restored.")# 使用模型进行一些操作...

 运行结果如下图所示,加载了之前训练的参数w和b的结果           
                               

四、做迁移学习时

选择变量进行保存和加载

      如果你不传递任何参数给tf.train.Saver(),Saver对象将处理图中的所有变量。每一个变量使用创建时传递给它的名字保存在磁盘上。 
      有时候,我们需要显示地指定变量保存在checkpoint文件中的名字。例如,你可能使用名为“weights”的变量训练模型;在保存的时候,你希望用“params”为名字保存。 
      有时候,我们只保存和加载模型的部分参数。例如,你已经训练了一个5层的神经网络;现在你想训练一个新的神经网络,它有6层。加载旧模型的参数作为新神经网络前5层的参数。 
      通过传递给tf.train.Saver()一个Python字典,你可以简单地指定名字和想要保存的变量。字典的keys是保存在磁盘上的名字,values是变量的值。 
注意: 
      如果你需要保存和加载不同子集的变量,你可以随心所欲地创建任意多的saver对象。同一个变量可以被多个saver对象保存。它的值仅仅在restore()方法运行之后发生改变。 
      如果在会话开始之初,你仅加载了部分变量,你还需要为其他变量运行初始化操作。参见tf.initialize_variables() 查询更多的信息。

# 创建一些对象
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加一个节点用于保存和加载变量v2,使用名字“my_v2”
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...

 

转自:

Variables: 创建、初始化、保存和加载:https://blog.csdn.net/u011500062/article/details/53414195

Tensorflow系列——Saver的用法:https://blog.csdn.net/u011500062/article/details/51728830

TensorFlow入门(一)基本用法:https://blog.csdn.net/Jerr__y/article/details/57084008

  相关解决方案