当前位置: 代码迷 >> 综合 >> Tensorflow模型(.ckpt)转.pb模型(不知道输出节点)
  详细解决方案

Tensorflow模型(.ckpt)转.pb模型(不知道输出节点)

热度:81   发布时间:2024-02-22 03:06:49.0

1、tensorflow模型的文件解读

使用tensorflow训练好的模型会自动保存为四个文件,如下
在这里插入图片描述
checkpoint:记录近几次训练好的模型结果(名称)。

xxx.data-00000-of-00001: 模型的所有变量的值(weights, biases, placeholders,gradients, hyper-parameters etc),也就是模型训练好参数和其他值。

xxx.index :模型的元数据,二进制或者其他格式,不可直接查看 。是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和一些辅助数据等。

xxx.meta:模型的meta数据 ,二进制或者其他格式,不可直接查看,保存了TensorFlow计算图的结构信息,通俗地讲就是神经网络的网络结构。

2、ckpt转pb文件

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from model import network  # network是你们自己定义的模型结构(代码结构)# egs:
# def network(input):
# return tf.layers.softmax(input)model_path  = "model.ckpt-0000" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前def main():tf.reset_default_graph()# 设置输入网络的数据维度,根据训练时的模型输入数据的维度自行修改input_node = tf.placeholder(tf.float32, shape=(None, None, 200)) output_node = network(input_node)   # 神经网络的输出# 设置输出数据类型(特别注意,这里必须要跟输出网络参数的数据格式保持一致,不然会导致模型预测 精度或者预测能力的丢失)以及重新定义输出节点的名字(这样在后面保存pb文件以及之后使用pb文件时直接使用重新定义的节点名字即可)flow = tf.cast(output_node , tf.float16, 'the_outputs') saver = tf.train.Saver()with tf.Session() as sess:saver.restore(sess, model_path)#保存模型图(结构),为一个json文件tf.train.write_graph(sess.graph_def, 'output_model/pb_model', 'model.pb')#将模型参数与模型图结合,并保存为pb文件freeze_graph.freeze_graph('output_model/pb_model/model.pb', '', False, model_path, 'the_outputs','save/restore_all', 'save/Const:0', 'output_model/pb_model/frozen_model.pb', False, "")print("done")if __name__ == '__main__':main()

3、获取.ckpt模型中节点名称

# function: get the node name of ckpt model
from tensorflow.python import pywrap_tensorflow
# checkpoint_path = 'model.ckpt-xxx'
checkpoint_path =  "model.ckpt-0000" 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:print("tensor_name: ", key)
  相关解决方案