当前位置: 代码迷 >> 综合 >> Tensorflow MNIST 手写体识别代码注释(4)
  详细解决方案

Tensorflow MNIST 手写体识别代码注释(4)

热度:82   发布时间:2023-12-12 16:14:17.0

Tensorflow MNIST 手写体识别代码注释(4)

  • tf.argmax
  • tf.equal()

tf.argmax

测试模型定义如下:

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y,1))       

tf.argmax(input,axis) 根据 axis 取值的不同返回每行或者每列最大值的索引。

axis 的作用从维度的角度来看,并不复杂。例如 test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]]) 的维度是 4 x 3, axis = 0,就是把 维度 0 压缩,axis = 1,就是把维度 1 压缩。

t e s t = [ 1 2 3 2 3 4 5 4 3 8 7 2 ] test = \left[ \begin{matrix} 1&2&3\\ 2&3&4\\ 5&4&3\\ 8&7&2\\ \end{matrix} \right] test=?????1258?2347?3432??????

argmax(test, axis = 0) --> shape(4 x 3) --> shape( 1 x 3 ) --> [3, 3, 1]
argmax(test, axis = 1) --> shape(4 x 3) --> shape( 4 x 1 ) --> [2, 2, 0, 0]

显然,

tf.argmax(pred, 1) --> shape(100 x 10) --> shape(100 x 1)
tf.argmax(y   , 1) --> shape(100 x 10) --> shape(100 x 1)

事实上,他们返回了 100 张图片的分类标识。而

tf.equal()

tf.equal() 比较两个张量的元素是否相等。看个例子:

import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
with tf.Session() as sess:print(sess.run(tf.equal(a,b)))

结果:

[[ True False  True][False  True False]]

tf.cast(tf.equal(a,b), tf.float32) 的结果为:

[[ 1.0  0.0  1.0 ][ 0.0  1.0  0.0 ]]

MNIST 的测试模型定义如下:

correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y,1))       
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Accuracy: ", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

注意,mnist.test.images、mnist.test.labels,表明测试结果是采用测试数据集测试的。

eval() 其实就是 tf.Tensor的Session.run() 的另外一种写法。上面些的那个代码例子,如果稍微修改一下,加上一个Session context manager:

with tf.Session() as sess:print(accuracy.eval({
    x:mnist.test.images,y_: mnist.test.labels}))

其效果和下面的代码是等价的:

with tf.Session() as sess:print(sess.run(accuracy, {
    x:mnist.test.images,y_: mnist.test.labels}))

但是要注意的是,eval() 只能用于 tf.Tensor 类对象,也就是有输出的 Operation。对于没有输出的 Operation, 可以用.run() 或者 Session.run()。Session.run() 没有这个限制。


到此为止,已经完全读懂 MNIST 的神经网络训练模型了。接下来需要做三件事情,第一件事,写一个自己的例子;第二件事,用 C++ 从头实现 MNIST 算法;第三件事,用 C++ 重构 Tensorflow 的训练结果。

  相关解决方案