当前位置: 代码迷 >> 综合 >> TF2.0图像分类实战(一)Lenet
  详细解决方案

TF2.0图像分类实战(一)Lenet

热度:75   发布时间:2024-02-12 05:11:10.0

关于Lenet的理论方面,请参考:https://cuijiahua.com/blog/2018/01/dl_3.html

本专栏所采用的数据集为猴子数据集,数据集获取:https://pan.baidu.com/s/1p1lG_AhsrMVu7N3BEm6YQw
提取码:hqjp

# Import the required libraries.
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import tensorflow as tf
import os

# Image height and width fed to the network, samples per batch, and epochs.
im_height = 32
im_width = 32
batch_size = 64
epochs = 30

# Root directory of the monkey-species dataset (see the download link above).
image_path = "C:/Users/HaibinZhao/Desktop/untitled/monkey/"
train_dir = image_path + "training/training"            # training-set directory
validation_dir = image_path + "validation/validation"   # validation-set directory
# Training-set image generator with on-the-fly augmentation.
train_image_generator = ImageDataGenerator(
    rescale=1. / 255,        # normalize pixel values to [0, 1]
    rotation_range=40,       # random rotation range, degrees
    width_shift_range=0.2,   # random horizontal shift, fraction of width
    height_shift_range=0.2,  # random vertical shift, fraction of height
    shear_range=0.2,         # shear transform intensity
    zoom_range=0.2,          # random zoom range
    horizontal_flip=True,    # random horizontal flips
    fill_mode='nearest')     # fill strategy for pixels created by transforms

# Read samples from train_dir; class_mode='categorical' one-hot encodes labels.
train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    batch_size=batch_size,
    shuffle=True,  # shuffle the training data
    target_size=(im_height, im_width),
    class_mode='categorical')
print(train_data_gen)
# Number of training samples found by the generator.
total_train = train_data_gen.n

# Validation-set generator: rescaling only, no augmentation.
validation_image_generator = ImageDataGenerator(rescale=1. / 255)

# Read samples from validation_dir; labels are one-hot encoded.
val_data_gen = validation_image_generator.flow_from_directory(
    directory=validation_dir,
    batch_size=batch_size,
    shuffle=False,  # keep validation order deterministic
    target_size=(im_height, im_width),
    class_mode='categorical')
print(val_data_gen)

# Number of validation samples found by the generator.
total_val = val_data_gen.n
# LeNet-5-style network for 32x32 RGB inputs with 10 output classes.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=6, kernel_size=(5, 5), activation='relu',
                           input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),
    tf.keras.layers.Conv2D(filters=16, kernel_size=(5, 5), activation='relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')])

# Print the architecture summary.
model.summary()

# Compile: Adam optimizer; categorical cross-entropy matches one-hot labels.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Resume from a previous checkpoint if one exists.
checkpoint_save_path = "./checkpoint/LeNet5.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------load the model--------------')
    model.load_weights(checkpoint_save_path)

# Callback must exist even on a fresh run (model.fit below references it),
# so it is created outside the if-branch. Saves only the best weights.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True)
# Train the model, validating after each epoch and checkpointing best weights.
history = model.fit(
    x=train_data_gen,
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size,
    callbacks=[cp_callback])

# Per-epoch loss/accuracy curves for both splits, for plotting below.
history_dict = history.history
train_loss = history_dict["loss"]
train_accuracy = history_dict["accuracy"]
val_loss = history_dict["val_loss"]
val_accuracy = history_dict["val_accuracy"]
# Plot the loss curves for the training and validation splits.
epoch_axis = range(epochs)
plt.figure()
plt.plot(epoch_axis, train_loss, label='train_loss')
plt.plot(epoch_axis, val_loss, label='val_loss')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('loss')

# Plot the accuracy curves for the training and validation splits.
plt.figure()
plt.plot(epoch_axis, train_accuracy, label='train_accuracy')
plt.plot(epoch_axis, val_accuracy, label='val_accuracy')
plt.legend()
plt.xlabel('epochs')
plt.ylabel('accuracy')

# Display both figures.
plt.show()