实现手写体 MNIST 数据集的识别任务，共分为三个模块文件，分别是：
描述网络结构的前向传播过程文件(mnist_forward.py)；
描述网络参数优化方法的反向传播过程文件(mnist_backward.py)；
验证模型准确率的测试过程文件(mnist_test.py)。
描述网络结构的前向传播过程文件(mnist_forward.py)
# coding: utf-8
"""Forward propagation (network structure) for the MNIST classifier.

Builds a fully connected INPUT_NODE -> LAYER1_NODE -> OUTPUT_NODE network:
one ReLU hidden layer followed by a linear output layer that returns
unnormalized logits (no softmax is applied here).
"""
import tensorflow as tf

INPUT_NODE = 784    # features per sample (presumably a flattened 28x28 image)
OUTPUT_NODE = 10    # number of output classes
LAYER1_NODE = 500   # width of the hidden layer


def get_weight(shape, regularizer):
    """Create a weight variable and optionally register its L2 penalty.

    Args:
        shape: Shape of the weight tensor.
        regularizer: L2 regularization scale, or None to skip regularization.

    Returns:
        A ``tf.Variable`` initialized from a truncated normal (stddev=0.1).
    """
    w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
    if regularizer is not None:
        # Collected under 'losses' so the training script can add every
        # penalty term to its total loss.
        tf.add_to_collection(
            'losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w


def get_bias(shape):
    """Create a bias variable of the given shape, initialized to zeros."""
    b = tf.Variable(tf.zeros(shape))
    return b


def forward(x, regularizer):
    """Build the network graph and return the output logits.

    Args:
        x: Input tensor with INPUT_NODE features per row.
        regularizer: L2 scale passed to ``get_weight`` (None disables it,
            as done at evaluation time).

    Returns:
        Logits tensor with OUTPUT_NODE columns.
    """
    w1 = get_weight([INPUT_NODE, LAYER1_NODE], regularizer)
    b1 = get_bias([LAYER1_NODE])
    y1 = tf.nn.relu(tf.matmul(x, w1) + b1)

    w2 = get_weight([LAYER1_NODE, OUTPUT_NODE], regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1, w2) + b2  # raw logits; the loss applies softmax
    return y
描述网络参数优化方法的反向传播过程文件(mnist_backward.py)
# coding: utf-8
"""Backward propagation (training) for the MNIST classifier.

Trains the network defined in ``mnist_forward`` with mini-batch gradient
descent, an exponentially decaying learning rate, L2 regularization and an
exponential moving average (EMA) of the trainable variables, and
periodically writes checkpoints to MODEL_SAVE_PATH.
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_forward

BATCH_SIZE = 200
REGULARIZER = 0.0001
LR = 0.1
LR_DECAY_RATE = 0.99
EMA_DECAY = 0.99
STEPS = 50000
LOG_INTERVAL = 1000  # iterations between loss reports / checkpoints
MODEL_SAVE_PATH = './model/'
MODEL_NAME = 'mnist_model'


def backward(mnist):
    """Build the training graph and run STEPS training iterations.

    Args:
        mnist: Dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``; its ``train``
            split supplies ``num_examples`` and ``next_batch``.
    """
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    y = mnist_forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # decay_steps is the number of batches per pass over the training set,
    # so with staircase=True the rate drops roughly once per epoch.
    lr = tf.train.exponential_decay(
        learning_rate=LR,
        global_step=global_step,
        decay_steps=mnist.train.num_examples / BATCH_SIZE,
        decay_rate=LR_DECAY_RATE,
        staircase=True)

    # Labels are one-hot, so argmax recovers the class index expected by
    # the sparse cross-entropy op.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.argmax(y_, 1), logits=y)
    cem = tf.reduce_mean(ce)
    reg_losses = tf.get_collection('losses')
    # tf.add_n rejects an empty list, which happens when REGULARIZER is None.
    loss = (cem + tf.add_n(reg_losses)) if reg_losses else cem

    train_step = tf.train.GradientDescentOptimizer(lr).minimize(
        loss, global_step=global_step)

    ema = tf.train.ExponentialMovingAverage(
        decay=EMA_DECAY, num_updates=global_step)
    # Build the EMA update after the optimizer step so the shadow variables
    # always average the freshly updated weights.
    with tf.control_dependencies([train_step]):
        ema_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([ema_op]):
        train_op = tf.no_op(name='train')

    # Saver.save fails if the parent directory does not exist.
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_v, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            # Also checkpoint the final iteration so the last steps are kept.
            if i % LOG_INTERVAL == 0 or i == STEPS - 1:
                print('After %d training steps, loss on training batch is %g.'
                      % (step, loss_v))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)


def main():
    """Load the dataset from ./data/ (one-hot labels) and train."""
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
运行 mnist_backward.py 后，训练得到的模型保存在 model 文件夹中：
验证模型准确率的测试过程文件(mnist_test.py)
# coding: utf-8
"""Test process: periodically evaluate the latest checkpoint's accuracy.

Rebuilds the forward graph, restores the EMA (shadow) values of the trained
variables from the newest checkpoint in ``mnist_backward.MODEL_SAVE_PATH``,
and prints test-set accuracy every TEST_INTERVAL_SECS seconds.
"""
import os
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_backward
import mnist_forward

TEST_INTERVAL_SECS = 5


def test(mnist):
    """Evaluate checkpoints in a loop until none is found.

    Args:
        mnist: Dataset object from ``input_data.read_data_sets(...,
            one_hot=True)``; its ``test`` split is used for evaluation.

    Returns when no checkpoint exists; otherwise loops forever.
    """
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
        y = mnist_forward.forward(x, None)  # no regularization at test time

        # Saver that maps each variable to its moving-average shadow, so
        # restored parameters take their EMA values.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.EMA_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # Accuracy: fraction of samples whose predicted class matches the label.
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # The test set never changes, so build the feed dict once.
        test_feed = {x: mnist.test.images, y_: mnist.test.labels}

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(
                    mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # Checkpoint files are named "<MODEL_NAME>-<step>";
                    # basename handles the platform's path separator.
                    global_step = os.path.basename(
                        ckpt.model_checkpoint_path).split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict=test_feed)
                    print('After %s training steps, test accuracy = %g'
                          % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found.')
                    return
            time.sleep(TEST_INTERVAL_SECS)


def main():
    """Load the dataset from ./data/ (one-hot labels) and start evaluating."""
    mnist = input_data.read_data_sets('./data/', one_hot=True)
    test(mnist)


if __name__ == '__main__':
    main()