阅读前提示:代码复制过来时带有行号,运行本文程序需要自行删除行号并检查是否存在缩进错误。本文整理了该模型的运行经验,经过验证可行。
本文详细介绍了基于inception-v3模型的神经网络图片识别系统搭建过程。
接上文
2.将ckpt转pb文件
需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。convert_pb.py代码如下:
1. # -*-coding: utf-8 -*-
2. """
3. @Project: tensorflow_models_nets
4. @File : convert_pb.py
5. @Author : panjq
6. @E-mail : pan_jinquan@163.com
7. @Date : 2018-08-29 17:46:50
8. @info :
9. -通过传入 CKPT 模型的路径得到模型的图和变量数据
10. -通过 import_meta_graph 导入模型中的图
11. -通过 saver.restore 从模型中恢复图中各个变量的数据
12. -通过 graph_util.convert_variables_to_constants 将模型持久化
13. """
14.
15. import tensorflow as tf
16. from create_tf_record import *
17. from tensorflow.python.framework import graph_util
18.
19. resize_height = 224 # 指定图片高度
20. resize_width = 224 # 指定图片宽度
21. depths = 3
22.
23. def freeze_graph_test(pb_path, image_path):
24. '''
25. :param pb_path:pb文件的路径
26. :param image_path:测试图片的路径
27. :return:
28. '''
29. with tf.Graph().as_default():
30. output_graph_def = tf.GraphDef()
31. with open(pb_path, "rb") as f:
32. output_graph_def.ParseFromString(f.read())
33. tf.import_graph_def(output_graph_def, name="")
34. with tf.Session() as sess:
35. sess.run(tf.global_variables_initializer())
36.
37. # 定义输入的张量名称,对应网络结构的输入张量
38. # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
39. input_image_tensor = sess.graph.get_tensor_by_name("input:0")
40. input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
41. input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
42.
43. # 定义输出的张量名称
44. output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
45.
46. # 读取测试图片
47. im=read_image(image_path,resize_height,resize_width,normalization=True)
48. im=im[np.newaxis,:]
49. # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
50. # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
51. out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
52. input_keep_prob_tensor:1.0,
53. input_is_training_tensor:False})
54. print("out:{}".format(out))
55. score = tf.nn.softmax(out, name='pre')
56. class_id = tf.argmax(score, 1)
57. print("pre class_id:{}".format(sess.run(class_id)))
58.
59.
60. def freeze_graph(input_checkpoint,output_graph):
61. '''
62.
63. :param input_checkpoint:
64. :param output_graph: PB模型保存路径
65. :return:
66. '''
67. # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
68. # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
69.
70. # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
71. output_node_names = "InceptionV3/Logits/SpatialSqueeze"
72. saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
73.
74. with tf.Session() as sess:
75. saver.restore(sess, input_checkpoint) #恢复图并得到数据
76. output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
77. sess=sess,
78. input_graph_def=sess.graph_def,# 等于:sess.graph_def
79. output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
80.
81. with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
82. f.write(output_graph_def.SerializeToString()) #序列化输出
83. print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
84.
85. # for op in sess.graph.get_operations():
86. # print(op.name, op.values())
87.
88. def freeze_graph2(input_checkpoint,output_graph):
89. '''
90.
91. :param input_checkpoint:
92. :param output_graph: PB模型保存路径
93. :return:
94. '''
95. # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
96. # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
97.
98. # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
99. output_node_names = "InceptionV3/Logits/SpatialSqueeze"
100. saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
101. graph = tf.get_default_graph() # 获得默认的图
102. input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图
103.
104. with tf.Session() as sess:
105. saver.restore(sess, input_checkpoint) #恢复图并得到数据
106. output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定
107. sess=sess,
108. input_graph_def=input_graph_def,# 等于:sess.graph_def
109. output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
110.
111. with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
112. f.write(output_graph_def.SerializeToString()) #序列化输出
113. print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
114.
115. # for op in graph.get_operations():
116. # print(op.name, op.values())
117.
118.
119. if __name__ == '__main__':
120. # 输入ckpt模型路径
121. input_checkpoint='models/model.ckpt-200'
122. # 输出pb模型的路径
123. out_pb_path="models/pb/frozen_model.pb"
124. # 调用freeze_graph将ckpt转为pb
125. freeze_graph(input_checkpoint,out_pb_path)
126.
127. # 测试pb模型
128. image_path = 'test_image/guitar.jpg'
129. freeze_graph_test(pb_path=out_pb_path, image_path=image_path)