当前位置: 代码迷 >> 综合 >> keras的基本用法(四)——Fine Tuning神经网络
  详细解决方案

keras的基本用法(四)——Fine Tuning神经网络

热度:21   发布时间:2023-12-12 21:40:43.0

文章作者:Tyan
博客:noahsnail.com  |  CSDN  |  简书

本文主要介绍Keras的一些基本用法,主要涉及已有网络的fine tuning,以ResNet50为例。

  • Demo
#!/usr/bin/env python
# _*_ coding: utf-8 _*_from keras.models import Model
from keras.layers import Dense
from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import ImageDataGenerator# 训练的batch_size
batch_size = 16
# 训练的epoch
epochs = 100# 图像Generator,用来构建输入数据
train_datagen = ImageDataGenerator(width_shift_range=0.1,height_shift_range=0.1,zoom_range=0.2,horizontal_flip=True)# 从文件中读取数据,目录结构应为train下面是各个类别的子目录,每个子目录中为对应类别的图像
train_generator = train_datagen.flow_from_directory('./train', target_size = (224, 224), batch_size = batch_size)# 训练图像的数量
image_numbers = train_generator.samples# 输出类别信息
print train_generator.class_indices# 生成测试数据
test_datagen = ImageDataGenerator()
validation_generator = test_datagen.flow_from_directory('./validation', target_size = (224, 224), batch_size = batch_size)# 使用ResNet的结构,不包括最后一层,且加载ImageNet的预训练参数
base_model = ResNet50(weights = 'imagenet', include_top = False, pooling = 'avg')# 构建网络的最后一层,3是自己的数据的类别
predictions = Dense(3, activation='softmax')(base_model.output)# 定义整个模型
model = Model(inputs=base_model.input, outputs=predictions)# 编译模型,loss为交叉熵损失
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')# 训练模型
model.fit_generator(train_generator,steps_per_epoch = image_numbers // batch_size, epochs = epochs, validation_data = validation_generator, validation_steps = batch_size)# 保存训练得到的模型
model.save_weights('weights.h5')
  • 部分结果
{
   'Type_3': 2, 'Type_2': 1, 'Type_1': 0}
Found 761 images belonging to 3 classes.
Epoch 1/401/16 [>.............................] - ETA: 119s - loss: 1.33922017-06-07 10:18:48.246289: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 2521 get requests, put_count=2161 evicted_count=1000 eviction_rate=0.462749 and unsatisfied allocation rate=0.579135 2017-06-07 10:18:48.246348: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 100 to 110 16/16 [==============================] - 120s - loss: 2.3753 - val_loss: 10.8293 Epoch 2/401/16 [>.............................] - ETA: 5s - loss: 1.00542017-06-07 10:20:40.464589: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 2270 get requests, put_count=2642 evicted_count=1000 eviction_rate=0.378501 and unsatisfied allocation rate=0.286784 2017-06-07 10:20:40.464643: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 256 to 281 16/16 [==============================] - 83s - loss: 1.7988 - val_loss: 11.5219 Epoch 3/40 16/16 [==============================] - 81s - loss: 1.6640 - val_loss: 11.0043 Epoch 4/403/16 [====>.........................] - ETA: 4s - loss: 1.87452017-06-07 10:23:26.725923: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:247] PoolAllocator: After 11057 get requests, put_count=11071 evicted_count=1000 eviction_rate=0.0903261 and unsatisfied allocation rate=0.0945103 2017-06-07 10:23:26.725986: I tensorflow/core/common_runtime/gpu/pool_allocator.cc:259] Raising pool_size_limit_ from 655 to 720 16/16 [==============================] - 83s - loss: 1.7237 - val_loss: 11.7738 Epoch 5/40 16/16 [==============================] - 83s - loss: 1.6304 - val_loss: 10.6538 Epoch 6/40 16/16 [==============================] - 80s - loss: 1.2182 - val_loss: 4.5027 Epoch 7/40 16/16 [==============================] - 83s - loss: 1.3179 - val_loss: 11.5891 Epoch 8/40 16/16 [==============================] - 82s - loss: 1.1806 - val_loss: 10.5800 Epoch 9/40 16/16 [==============================] - 81s - loss: 1.1935 - val_loss: 11.1477 Epoch 10/40 16/16 [==============================] - 80s - loss: 1.1727 - val_loss: 7.0913 Epoch 11/40 16/16 [==============================] - 83s - loss: 1.2058 - val_loss: 6.4474 Epoch 12/40 16/16 [==============================] - 82s - loss: 1.2702 - val_loss: 7.7678 Epoch 13/40 16/16 [==============================] - 84s - loss: 1.2060 - val_loss: 7.9961 Epoch 14/40 16/16 [==============================] - 83s - loss: 1.0768 - val_loss: 11.2121 Epoch 15/40 16/16 [==============================] - 80s - loss: 1.1401 - val_loss: 13.2052 Epoch 16/40 16/16 [==============================] - 83s - loss: 1.1961 - val_loss: 13.0330
  相关解决方案