- 模型保存
官网说明:https://www.tensorflow.org/guide/saved_model?hl=zh-cN
>>> import tensorflow as tf
>>> v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
>>> v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
>>> inc_v1 = v1.asign(v1+1)
Traceback (most recent call last):File "<stdin>", line 1, in <module>
AttributeError: 'RefVariable' object has no attribute 'asign'
>>> inc_v1 = v1.assign(v1+1)
>>> inc_v2 = v2.assign(v2-1)
>>> init_op = tf.global_variables_initializer()
>>> saver = tf.train.Saver()
>>> with tf.Session() as sess:
... sess.run(init_op)
... inc_v1.op.run()
... inc_v2.op.run()
... save_path = saver.save(sess,'/Users/jiweiwang/temp/model.ckpt')
... print("Model saved in path: %s" % save_path)
- 查看检查点 checkpoint 文件
print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors, all_tensor_names=False, count_exclude_pattern="")可以查看 TensorFlow 保存模型的参数名称、模型中的参数值、参数总量
官方说明:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/inspect_checkpoint.py
参数:
file_name:检查点文件的名称.
tensor_name:指定要打印检查点文件中特定张量的名称,不打印指定名称,则设为 None,需要注意如果只想查看指定张量,那么位置参数 all_tensors 与关键词参数 all_tensor_name 都需要设置 False,否则还会把所有的张量都打印出来.
all_tensors:布尔型参数,True 表示打印检查点文件中保存的所有张量.
all_tensor_names: 布尔型关键字参数,默认为 False,表示是否打印所有变量的名称,如果只想把所有的张量名称打印出来,而不打印具体的张量值,需要设置位置参数 all_tensors 为 False.
count_exclude_pattern: 正则化字符串,用于计数是排除匹配指定的张量.
(1)打印指定的张量名称及其数值
>>> from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
>>> print_tensors_in_checkpoint_file('/Users/jiweiwang/temp/model.ckpt', None, True)
tensor_name: v1
[1. 1. 1.]
tensor_name: v2
[-1. -1. -1. -1. -1.]