程序一:ckpt转pb
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile# 模型参数固化ckpt转pb
def freeze_graph(input_meta,input_checkpoint, output_graph):''':param input_checkpoint::param output_graph: PB模型保存路径:return:'''# 指定输出的节点名称,该节点名称必须是原模型中存在的节点output_node_names = "XXXXX"saver = tf.train.import_meta_graph(input_meta, clear_devices=True) # + '.meta'graph = tf.get_default_graph() # 获得默认的图input_graph_def = graph.as_graph_def() # 返回一个序列化的图代表当前的图with tf.Session() as sess:saver.restore(sess, input_checkpoint) # 恢复图并得到数据output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,将变量值固定sess=sess,input_graph_def=input_graph_def, # 等于:sess.graph_defoutput_node_names=output_node_names.split(",")) # 如果有多个输出节点,以逗号隔开with gfile.GFile(output_graph, "wb") as f: # 保存模型f.write(output_graph_def.SerializeToString()) # 序列化输出print("%d ops in the final graph." % len(output_graph_def.node)) # 得到当前图有几个操作节点
程序二:测试是否转对了
# 测试
def testPb():''':param pb_path:pb文件的路径:param image_path:测试图片的路径:return:'''pb_path = "XXXXX.pb"with tf.Graph().as_default():output_graph_def = tf.GraphDef()if (os.path.isfile(pb_path)):with open(pb_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name = "")with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 定义输入的张量名称,对应网络结构的输入张量input= tf.get_default_graph().get_tensor_by_name("input:0")is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")# 定义输出的张量名称output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")out = sess.run(output_tensor_name, feed_dict={input: XXX,is_train : False})print("output:{}".format(out))
其他:
可能会出现错误:
ValueError: Input 0 of node XXXXXXXXXXX/Switch was passed float from XXXXXXXXXXXXXxBathNormalXXXXXXX:0 incompatible with expected float_ref.
原因,转pb的时候BN层是float_ref,而转pb后为float
程序上可以做如下修改
程序二
# 测试
def testPb():''':param pb_path:pb文件的路径:param image_path:测试图片的路径:return:'''pb_path = "XXXXX.pb"with tf.Graph().as_default():output_graph_def = tf.GraphDef()if (os.path.isfile(pb_path)):with open(pb_path, "rb") as f:output_graph_def.ParseFromString(f.read())for node in output_graph_def.node:if node.op == 'RefSwitch':node.op = 'Switch'for index in range(len(node.input)):if 'moving_' in node.input[index]:node.input[index] = node.input[index] + '/read'elif node.op == 'AssignSub':node.op = 'Sub'if 'use_locking' in node.attr:del node.attr['use_locking']tf.import_graph_def(output_graph_def, name = "")with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 定义输入的张量名称,对应网络结构的输入张量input= tf.get_default_graph().get_tensor_by_name("input:0")is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")# 定义输出的张量名称output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")out = sess.run(output_tensor_name, feed_dict={input: XXX,is_train : False})print("output:{}".format(out))
测试输出与未转化前完全一致,end!