遇到的问题:
tensorflow版本:tensorflow1.13.1-gpu
今天在学习tensorflow的模型保存和载入中出现以下BUG
NotFoundError: Restoring from checkpoint failed. This is most likely
due to a Variable name or other graph key that is missing from the
checkpoint. Please ensure that you have not altered the graph expected
based on the checkpoint. Original error:
Key bias_2 not found in checkpoint
解决过程:经过自己的排查后发现在代码中少了一句
tf.reset_default_graph()
这条语句用来重置图,可能是因为图没有重置,导致检查点找不到所以导入不了模型
实际操作:在有关图的变量定义前加入tf.reset_default_graph(),问题解决,可以正常读取保存的文件
添加位置
网络结构之前都可以,例子如下:
1,网络结构中,会清楚结构中已经定义的变量,所以会报错
添加在此位置不可取不可取不可取!!!
2,添加在网络结构前,可取
代码
# -*- coding: utf-8 -*-
""" Created on Sun Mar 1 21:13:54 2020@author: HPN 这是一个y = 2x的线性回归例子的程序 注释掉session部分就可以进行直接载入模型的测试了 """import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt#自定义函数
def moving_average(a,w = 10):if len(a) < w:return a[:]return [val if idx < w else sum(a[(idx - w):idx])/w for idx,val in enumerate(a)]
# return [val if idx < w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]
plotdata = {
'batchsize':[],'loss':[]}#生成数据
train_X = np.linspace(-1,1,100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.3
plt.plot(train_X,train_Y,'ro',label = 'Original data')
plt.legend()
plt.show()
tf.reset_default_graph()
#创建占位符
X = tf.placeholder('float')
Y = tf.placeholder('float')#模型参数
W = tf.Variable(tf.random_normal([1]),name = 'weight')
b = tf.Variable(tf.zeros([1]),name = 'bias')
#前向结构
z = tf.multiply(X,W) + b
tf.summary.histogram('z',z)#反向优化
cost = tf.reduce_mean(tf.square(Y - z))
tf.summary.scalar('loss_function',cost)
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)#初始化所有变量
init = tf.global_variables_initializer()
#定义参数
training_epochs = 20
display_step = 2saver = tf.train.Saver(max_to_keep=1)
savedir = 'log/'''' 若要载入模型,只需将session1的部分完全注释掉就可以了 '''
#启动session
with tf.Session() as sess:sess.run(init)merged_summary_op = tf.summary.merge_all()summary_writer = tf.summary.FileWriter('log/mnist_with_summaries',sess.graph)plotdata = {
'batchsize':[],'loss':[]}#向模型输入数据for epoch in range(training_epochs):for(x,y) in zip(train_X,train_Y):sess.run(optimizer,feed_dict = {
X:x,Y:y})#生成summarysummary_str = sess.run(merged_summary_op,feed_dict = {
X:x,Y:y})summary_writer.add_summary(summary_str,epoch)#显示训练中的详细信息if epoch % display_step == 0:loss = sess.run(cost,feed_dict = {
X:train_X,Y:train_Y})print('Epoch:',epoch+1,'cost=',loss,'W=',sess.run(W),'b=',sess.run(b))if not (loss == 'NA'):plotdata['batchsize'].append(epoch)plotdata['loss'].append(loss)saver.save(sess,savedir + 'linermodel.cpkt',global_step = epoch)print('finish!')print('cost=',sess.run(cost,feed_dict={
X:train_X,Y:train_Y}),'W=',sess.run(W),'b=',sess.run(b))#图形显示plt.plot(train_X,train_Y,'ro',label = 'Original data')plt.plot(train_X,sess.run(W) * train_X + sess.run(b),'b--',label = 'Fittedline')plt.legend()#加上图例plt.show()plotdata['avgloss'] = moving_average(plotdata['loss'])plt.figure(1)plt.subplot(211)plt.plot(plotdata['batchsize'],plotdata['avgloss'],'b--')plt.xlabel('Minibatch number')plt.ylabel('Loss')plt.title('Minibatch run vs. Training loss')plt.show() print('x=0.2,z=',sess.run(z,feed_dict={
X:0.2}))'''注释掉以上部分'''load_epoch=18
with tf.Session() as sess2:
# sess2.run(tf.global_variables_initializer())
# saver.restore(sess2,savedir+'linermodel.cpkt')
# print('x=0.2,z=',sess2.run(z,feed_dict={X:0.2}))
# sess2.run(tf.global_variables_initializer()) saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))print ("x=0.2,z=", sess2.run(z, feed_dict={
X: 0.2}))
网盘下载地址:https://pan.baidu.com/s/1NgMSxMtchDvHvAzCvpZPEg
密码:zafj