当前位置: 代码迷 >> 综合 >> tensorflow(3)Logistic regression
  详细解决方案

tensorflow(3)Logistic regression

热度:100   发布时间:2023-12-03 20:59:21.0

MNIST手写数字识别
tensorflow基本概念

# coding=utf-8
import tensorflow as tf
import input_data
# Load MNIST with one-hot encoded labels via the local `input_data` helper.
mnist = input_data.read_data_sets('data/', one_hot=True)
train_img = mnist.train.images
train_label = mnist.train.labels
# Evaluation arrays must come from the held-out test split, not the training
# split; otherwise "test" metrics silently measure training data.
test_img = mnist.test.images
test_label = mnist.test.labels
print("mnist loaded")
# To persist the model: create tf.train.Saver() only after the graph and its
# variables exist, then call saver.save(sess, path) in the session after training.

# Sanity-check the array shapes of each split.
print(train_img.shape)
print(train_label.shape)
print(test_img.shape)
print(test_label.shape)
# Sanity check: the one-hot label vector of the first training example.
print(train_label[0])

# Inputs: flattened images with 784 features per row, and 10-class one-hot labels.
x = tf.placeholder("float", [None, 784])
y = tf.placeholder("float", [None, 10])
# Softmax (multinomial logistic) regression parameters, zero-initialised.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# model: per-class probabilities
actv = tf.nn.softmax(tf.matmul(x, W) + b)
# loss function: mean cross-entropy over the batch. The probabilities are
# clipped away from 0 so that log(0) cannot produce -inf / NaN gradients when
# the softmax saturates.
cost = tf.reduce_mean(
    -tf.reduce_sum(y * tf.log(tf.clip_by_value(actv, 1e-10, 1.0)), axis=1))
# optimizer: plain gradient descent on the cross-entropy
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# prediction: True where the most probable class equals the labelled class
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# accuracy: fraction of correct predictions
accr = tf.reduce_mean(tf.cast(pred, "float"))
# initializer for all graph variables (W, b)
init = tf.global_variables_initializer()

train_epochs = 50
batch_size = 100
step = 5  # report metrics every `step` epochs

sess = tf.Session()
sess.run(init)

for epoch in range(train_epochs):
    avg_cost = 0.
    batch_num = mnist.train.num_examples // batch_size
    for _ in range(batch_num):
        # feed one mini-batch per iteration
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        feeds = {x: batch_xs, y: batch_ys}
        # A single run applies the update and fetches the batch cost (evaluated
        # on the pre-update weights), avoiding a second forward pass per batch.
        _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
        avg_cost += batch_cost / batch_num
    # display
    if epoch % step == 0:
        # train_acc is measured on the last mini-batch only, so it is noisy.
        feeds_train = {x: batch_xs, y: batch_ys}
        feeds_test = {x: mnist.test.images, y: mnist.test.labels}
        train_acc = sess.run(accr, feed_dict=feeds_train)
        test_acc = sess.run(accr, feed_dict=feeds_test)
        # Five format specifiers need five values: epoch, total, cost, train, test.
        print("Epoch: %03d/%03d cost:%.9f train_acc:%.3f test_acc:%.3f"
              % (epoch, train_epochs, avg_cost, train_acc, test_acc))
print("DONE")

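OUT (note: this log appears to come from an earlier version of the print statement whose format string contained a literal `#03d` and which was passed `(epoch, train_epochs, test_acc, test_acc)`. As a result, the first column shows the literal `#03d`, the second shows the epoch, the `cost` column actually shows `train_epochs` (50), and `train_acc` merely repeats `test_acc`):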
Epoch: #03d/000 cost:50.000000000 train_acc:0.852 test_acc:0.852
Epoch: #03d/005 cost:50.000000000 train_acc:0.895 test_acc:0.895
Epoch: #03d/010 cost:50.000000000 train_acc:0.905 test_acc:0.905
Epoch: #03d/015 cost:50.000000000 train_acc:0.909 test_acc:0.909
Epoch: #03d/020 cost:50.000000000 train_acc:0.913 test_acc:0.913
Epoch: #03d/025 cost:50.000000000 train_acc:0.914 test_acc:0.914
Epoch: #03d/030 cost:50.000000000 train_acc:0.915 test_acc:0.915
Epoch: #03d/035 cost:50.000000000 train_acc:0.917 test_acc:0.917
Epoch: #03d/040 cost:50.000000000 train_acc:0.918 test_acc:0.918
Epoch: #03d/045 cost:50.000000000 train_acc:0.918 test_acc:0.918
DONE

  相关解决方案