TensorBoard是Tensorflow框架提供的可视化工具,通过TensorBoard,我们就可以直观且方便的查看程序运行过程中生成的各种统计信息,包括:
scalar
image
audio
histogram
graph
这些信息可以通过tf.summary调用相关的方法进行写入。
启动Tensorbard服务
在控制台输入:
tensorboard --logdir 路径
这样就可以启动tensorboard服务,启动之后可以从控制台看到登录地址,tensorboard默认的端口为6006。
说明:
- 选项logdir指定的路径中,需要具有tensorboard可以读取的数据(通过tf.summary.FileWriter写入),否则无法展示。
- 为了达到最好的展示效果,最好使用Chrome浏览器进行查看(一家人)。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("data/", one_hot=True)def variable_summaries(var):with tf.name_scope("summaries"):mean = tf.reduce_mean(var)# 写入数据,供tensorboard进行展示。tf.summary.scalar("mean", mean)stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))tf.summary.scalar("stddev", stddev)tf.summary.scalar("max", tf.reduce_max(var))tf.summary.scalar("min", tf.reduce_min(var))tf.summary.histogram("histogram", var)with tf.name_scope("input"):X = tf.placeholder(dtype=tf.float32, shape=[None, 784])y = tf.placeholder(dtype=tf.float32, shape=[None, 10])variable_summaries(X)variable_summaries(y)images = tf.reshape(X, [-1, 28, 28, 1])tf.summary.image("input", images, max_outputs=3)with tf.name_scope("output"):W = tf.Variable(tf.zeros(shape=[784, 10]))variable_summaries(W)b = tf.Variable(tf.zeros(shape=[1, 10]))variable_summaries(b)z = tf.matmul(X, W) + ba = tf.nn.softmax(z)tf.summary.histogram("y_hat", a)with tf.name_scope("loss"):loss = -tf.reduce_sum(y * tf.log(a))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)variable_summaries(loss)with tf.name_scope("accuracy"):correct = tf.equal(tf.argmax(y, axis=1), tf.argmax(a, axis=1))accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))variable_summaries(accuracy)with tf.Session() as sess:# 对之前计算的统计信息进行汇总。merged = tf.summary.merge_all()# tensorflow提供的API,对文件进行操作。# 如果logs/目录存在,则删除该目录以及目录下所有的文件与子目录。# 程序每次运行,都会在logs/下写入数据文件,如果存在多个数据文件,使用tensorboard# 进行展示的时候,就会出现图像互相叠加与干扰。if tf.gfile.Exists("logs/"):tf.gfile.DeleteRecursively("logs/")# FileWriter用来在指定的路径下写入数据文件,供tensorboard进行可视化展示。train_writer = tf.summary.FileWriter("logs/", sess.graph)sess.run(tf.global_variables_initializer())for i in range(1, 101):batch_X, batch_y = mnist.train.next_batch(128)# 使用Session来运行汇总(合并)的数据(merged),则汇总的数据就可以写入到数据文件中(供tensorboard进行读取展示)。train_summary, _ = sess.run([merged, train_step], feed_dict={
X: batch_X, y: batch_y})train_writer.add_summary(train_summary, i)if i % 10 == 0:test_summary, acu = sess.run([merged, accuracy], feed_dict={
X: mnist.test.images, y: mnist.test.labels})train_writer.add_summary(test_summary, i)print(f"第{i}次训练,测试准确率为:{acu}")