当前位置: 代码迷 >> 综合 >> keras mnist图像分类任务(服饰)
  详细解决方案

keras mnist图像分类任务(服饰)

热度:82   发布时间:2023-11-22 07:31:41.0

  刚刚入手tensorflow,有错请指正。tensorflow自2.0以来,上手程度就很欧克了。这个项目主要涉及的是使用keras做分类任务,使用的是mnist数据集,比较适合新手,刚入坑的小伙伴们。

  首先导入要使用的模块,并且加载数据集。

'''
制作人:追天一方
功能:keras图像分类。
错误之处请留言,抱拳。
'''
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt#查看tensorflow版本
print(tf.__version__)#导入mnist服饰数据集
fashion_mnist=tf.keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels)=fashion_mnist.load_data()#分类标签
class_name=['T-shirt/top','Trouser','Pullover','Dress','Coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']

  随后查看一个图像,读者可以根据自己的喜好查看任何一张图片,只需要修改索引即可。

#查看一个图像
plt.figure()
plt.imshow(train_images[2])
plt.colorbar()
plt.grid(False)
plt.show()

  当然也可以查看多张图片,代码如下。

#展示训练集前25张图像
plt.figure(figsize=(10,10))
for i in range(25):plt.subplot(5,5,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i],cmap=plt.cm.binary)plt.xlabel(class_name[train_labels[i]])
plt.show()

 

接下来就是构建模型了,不得不说keras构建模型还是比较简单了,调用

tf.keras.Sequential函数即可。
#构建模型
model=tf.keras.Sequential([tf.keras.layers.Flatten(input_shape=(28,28)),tf.keras.layers.Dense(128,activation='relu'),tf.keras.layers.Dense(10)
])

  设置训练时候的损失函数和优化器,并且训练模型。

  

#设置损失函数和优化器
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
#训练模型
model.fit(train_images,train_labels,epochs=10)

  最后就是评估模型了

#评估模型准确率
test_loss,test_acc=model.evaluate(test_images,test_labels,verbose=2)
print('\nTest accuracy:',test_acc)

  模型也训练完了,下面就是使用模型进行预测了,当然为了方便笔者用的是训练集进行预测的。

probability_model=tf.keras.Sequential([model,tf.keras.layers.Softmax()])#预测,使用的是测试集
predictions=probability_model.predict(test_images)print(np.argmax(predictions[0]))

  接下来就是一些可视化操作了,废话不多说,直接上代码。

#定义一个可视化图片函数,如果预测错误则是红色标签
def plot_image(i,predictions_array,true_label,img):true_label,img=true_label[i],img[i]plt.grid(False)plt.xticks([])plt.yticks([])plt.imshow(img,cmap=plt.cm.binary)predicted_label=np.argmax(predictions_array)if predicted_label==true_label:color='blue'else:color='red'plt.xlabel("{} {:2.0f}% ({})".format(class_name[predicted_label],100*np.max(predictions_array),class_name[true_label]),color=color)#可视化概率分布
def plot_value_array(i,predictions_array,true_label):true_label=true_label[i]plt.grid(False)plt.xticks(range(10))plt.yticks([])thisplot=plt.bar(range(10),predictions_array,color="#777777")plt.ylim([0,1])predicted_label=np.argmax(predictions_array)thisplot[predicted_label].set_color('red')thisplot[true_label].set_color('blue')#验证多个图像
num_rows=5
num_cols=3
num_images=num_rows*num_cols
plt.figure(figsize=(2*2*num_cols,2*num_rows))
for i in range(num_images):plt.subplot(num_rows,2*num_cols,2*i+1)plot_image(i,predictions[i],test_labels,test_images)plt.subplot(num_rows,2*num_cols,2*i+2)plot_value_array(i,predictions[i],test_labels)
plt.tight_layout()
plt.show()

  一个简单的分类任务就完成了,如果像了解训练好的模型保存方式,请关注我的下一篇博客。

  相关解决方案