当前位置: 代码迷 >> 综合 >> inception-v3模型神经网络图片识别系统搭建详细流程(2)
  详细解决方案

inception-v3模型神经网络图片识别系统搭建详细流程(2)

热度:70   发布时间:2024-01-24 10:25:00.0

阅读前提示:代码复制过来时带有行号,运行本文程序需要自行删除行号并检查是否存在缩进错误。本文整理了该模型的运行经验,经过验证可行。
本文详细介绍了基于inception-v3模型的神经网络图片识别系统搭建过程。

接上文

2.将ckpt转pb文件

需要将TensorFlow的模型导出为单个文件(同时包含模型结构的定义与权重),方便在其他地方使用(如在Android中部署网络)。convert_pb.py代码如下:

1.	# -*-coding: utf-8 -*-
2.	"""
3.	    @Project: tensorflow_models_nets
4.	    @File   : convert_pb.py
5.	    @Author : panjq
6.	    @E-mail : pan_jinquan@163.com
7.	    @Date   : 2018-08-29 17:46:50
8.	    @info   :
9.	    -通过传入 CKPT 模型的路径得到模型的图和变量数据
10.	    -通过 import_meta_graph 导入模型中的图
11.	    -通过 saver.restore 从模型中恢复图中各个变量的数据
12.	    -通过 graph_util.convert_variables_to_constants 将模型持久化
13.	"""
14.	
15.	import tensorflow as tf
16.	from create_tf_record import *
17.	from tensorflow.python.framework import graph_util
18.	
19.	resize_height = 224  # 指定图片高度
20.	resize_width = 224  # 指定图片宽度
21.	depths = 3
22.	
23.	def freeze_graph_test(pb_path, image_path):
24.	    '''
25.	    :param pb_path:pb文件的路径
26.	    :param image_path:测试图片的路径
27.	    :return:
28.	    '''
29.	    with tf.Graph().as_default():
30.	        output_graph_def = tf.GraphDef()
31.	        with open(pb_path, "rb") as f:
32.	            output_graph_def.ParseFromString(f.read())
33.	            tf.import_graph_def(output_graph_def, name="")
34.	        with tf.Session() as sess:
35.	            sess.run(tf.global_variables_initializer())
36.	
37.	            # 定义输入的张量名称,对应网络结构的输入张量
38.	            # input:0作为输入图像,keep_prob:0作为dropout的参数,测试时值为1,is_training:0训练参数
39.	            input_image_tensor = sess.graph.get_tensor_by_name("input:0")
40.	            input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
41.	            input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
42.	
43.	            # 定义输出的张量名称
44.	            output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
45.	
46.	            # 读取测试图片
47.	            im=read_image(image_path,resize_height,resize_width,normalization=True)
48.	            im=im[np.newaxis,:]
49.	            # 测试读出来的模型是否正确,注意这里传入的是输出和输入节点的tensor的名字,不是操作节点的名字
50.	            # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0", feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
51.	            out=sess.run(output_tensor_name, feed_dict={input_image_tensor: im,
52.	                                                        input_keep_prob_tensor:1.0,
53.	                                                        input_is_training_tensor:False})
54.	            print("out:{}".format(out))
55.	            score = tf.nn.softmax(out, name='pre')
56.	            class_id = tf.argmax(score, 1)
57.	            print("pre class_id:{}".format(sess.run(class_id)))
58.	
59.	
60.	def freeze_graph(input_checkpoint,output_graph):
61.	    '''
62.	
63.	    :param input_checkpoint:
64.	    :param output_graph: PB模型保存路径
65.	    :return:
66.	    '''
67.	    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
68.	    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
69.	
70.	    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
71.	    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
72.	    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
73.	
74.	    with tf.Session() as sess:
75.	        saver.restore(sess, input_checkpoint) #恢复图并得到数据
76.	        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
77.	            sess=sess,
78.	            input_graph_def=sess.graph_def,# 等于:sess.graph_def
79.	            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
80.	
81.	        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
82.	            f.write(output_graph_def.SerializeToString()) #序列化输出
83.	        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
84.	
85.	        # for op in sess.graph.get_operations():
86.	        #     print(op.name, op.values())
87.	
88.	def freeze_graph2(input_checkpoint,output_graph):
89.	    '''
90.	
91.	    :param input_checkpoint:
92.	    :param output_graph: PB模型保存路径
93.	    :return:
94.	    '''
95.	    # checkpoint = tf.train.get_checkpoint_state(model_folder) #检查目录下ckpt文件状态是否可用
96.	    # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt文件路径
97.	
98.	    # 指定输出的节点名称,该节点名称必须是原模型中存在的节点
99.	    output_node_names = "InceptionV3/Logits/SpatialSqueeze"
100.	    saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
101.	    graph = tf.get_default_graph() # 获得默认的图
102.	    input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
103.	
104.	    with tf.Session() as sess:
105.	        saver.restore(sess, input_checkpoint) #恢复图并得到数据
106.	        output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定
107.	            sess=sess,
108.	            input_graph_def=input_graph_def,# 等于:sess.graph_def
109.	            output_node_names=output_node_names.split(","))# 如果有多个输出节点,以逗号隔开
110.	
111.	        with tf.gfile.GFile(output_graph, "wb") as f: #保存模型
112.	            f.write(output_graph_def.SerializeToString()) #序列化输出
113.	        print("%d ops in the final graph." % len(output_graph_def.node)) #得到当前图有几个操作节点
114.	
115.	        # for op in graph.get_operations():
116.	        #     print(op.name, op.values())
117.	
118.	
119.	if __name__ == '__main__':
120.	    # 输入ckpt模型路径
121.	    input_checkpoint='models/model.ckpt-200'
122.	    # 输出pb模型的路径
123.	    out_pb_path="models/pb/frozen_model.pb"
124.	    # 调用freeze_graph将ckpt转为pb
125.	    freeze_graph(input_checkpoint,out_pb_path)
126.	
127.	    # 测试pb模型
128.	    image_path = 'test_image/guitar.jpg'
129.	    freeze_graph_test(pb_path=out_pb_path, image_path=image_path)