当前位置: 代码迷 >> 综合 >> tensorflow ckpt 格式的model转pb 固化模型
  详细解决方案

tensorflow ckpt 格式的model转pb 固化模型

热度:91   发布时间:2024-02-02 06:32:23.0

程序一:ckpt转pb

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile# 模型参数固化ckpt转pb
def freeze_graph(input_meta,input_checkpoint, output_graph):''':param input_checkpoint::param output_graph: PB模型保存路径:return:'''# 指定输出的节点名称,该节点名称必须是原模型中存在的节点output_node_names = "XXXXX"saver = tf.train.import_meta_graph(input_meta, clear_devices=True) # + '.meta'graph = tf.get_default_graph()  # 获得默认的图input_graph_def = graph.as_graph_def()  # 返回一个序列化的图代表当前的图with tf.Session() as sess:saver.restore(sess, input_checkpoint)  # 恢复图并得到数据output_graph_def = graph_util.convert_variables_to_constants(  # 模型持久化,将变量值固定sess=sess,input_graph_def=input_graph_def,  # 等于:sess.graph_defoutput_node_names=output_node_names.split(","))  # 如果有多个输出节点,以逗号隔开with gfile.GFile(output_graph, "wb") as f:  # 保存模型f.write(output_graph_def.SerializeToString())  # 序列化输出print("%d ops in the final graph." % len(output_graph_def.node))  # 得到当前图有几个操作节点

程序二:测试是否转对了

# 测试
def testPb():''':param pb_path:pb文件的路径:param image_path:测试图片的路径:return:'''pb_path = "XXXXX.pb"with tf.Graph().as_default():output_graph_def =  tf.GraphDef()if (os.path.isfile(pb_path)):with open(pb_path, "rb") as f:output_graph_def.ParseFromString(f.read())tf.import_graph_def(output_graph_def, name = "")with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 定义输入的张量名称,对应网络结构的输入张量input= tf.get_default_graph().get_tensor_by_name("input:0")is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")# 定义输出的张量名称output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")out = sess.run(output_tensor_name, feed_dict={input: XXX,is_train : False})print("output:{}".format(out))

其他:

可能会出现错误:

ValueError: Input 0 of node XXXXXXXXXXX/Switch was passed float from XXXXXXXXXXXXXxBathNormalXXXXXXX:0 incompatible with expected float_ref.

原因,转pb的时候BN层是float_ref,而转pb后为float

程序上可以做如下修改

程序二


# 测试
def testPb():''':param pb_path:pb文件的路径:param image_path:测试图片的路径:return:'''pb_path = "XXXXX.pb"with tf.Graph().as_default():output_graph_def =  tf.GraphDef()if (os.path.isfile(pb_path)):with open(pb_path, "rb") as f:output_graph_def.ParseFromString(f.read())for node in output_graph_def.node:if node.op == 'RefSwitch':node.op = 'Switch'for index in range(len(node.input)):if 'moving_' in node.input[index]:node.input[index] = node.input[index] + '/read'elif node.op == 'AssignSub':node.op = 'Sub'if 'use_locking' in node.attr:del node.attr['use_locking']tf.import_graph_def(output_graph_def, name = "")with tf.Session() as sess:sess.run(tf.global_variables_initializer())# 定义输入的张量名称,对应网络结构的输入张量input= tf.get_default_graph().get_tensor_by_name("input:0")is_train = tf.get_default_graph().get_tensor_by_name("is_train:0")# 定义输出的张量名称output_tensor_name = sess.graph.get_tensor_by_name("XXXXXXX:0")out = sess.run(output_tensor_name, feed_dict={input: XXX,is_train : False})print("output:{}".format(out))

测试输出与未转化前完全一致,end!

  相关解决方案