返回主目录
返回神经网络目录
上一章:深度篇——神经网络(五) 细说 优化器
下一章:深度篇——神经网络(七) 细说 DNN神经网络手写数字代码演示
本小节,细说 数据增强与fine-tuning,下一小节细说 神经网络手写数字代码演示
本小节的数据增强与fine-tuning,还是属于对神经网络的调优过程。
5. 调优神经网络
(9). 数据增强
数据增强是深度学习中常用的技巧之一,主要用于增加训练数据集,让数据尽可能的多样化,是的训练的偶像具有更强的泛化能力。现有的各大深度学习框架都已经自带了数据增强,平时在使用的时候直接调用对应的接口函数,但是只要的话,缺少对数据进行详细的分析。在实际应用中,并非所有的数据增强方式都使用于当前的训练数据。这需要用户根据自己的数据集特征来确定应该使用哪几种数据增强方式。数据增强方式常用有以下3种:
原图:
①. 空间几何变换类
a. 翻转
翻转包括水平翻转和垂直翻转,其变换公式如下:
水平翻转:
垂直翻转:
矩阵变换公式如下:
水平翻转:
垂直翻转:
b. 旋转
对图像做一定角度旋转操作,其变换公式如下:
顺时针旋转:
逆时针旋转:
矩阵变换公式如下:
顺时针旋转:
逆时针旋转:
c. 平移
平移是指所有的图像在 x 轴 和 y 轴 方向各平移和。其变换公式如下:
矩阵的变换公式如下:
d. crop 裁剪
裁剪图片的感兴趣区域 (ROI)
e. 图像缩放
图像缩放是指对当前图像进行任意的缩放,其变换公式如下:
: 为倍数,如 0.8 倍 或 1.1 倍
矩阵变换公式如下:
f. 错切
错切变换是将坐标点沿 x 和 y 轴发生不等量的变换,得到点的过程。其数学公式如下:
矩阵变换公式如下:
g. 仿射
(a). 仿射变换
同时对图片做裁剪、旋转、转换、模式调整等多重操作。其变换公式如下:
其矩阵变换公式如下:
(b). 视觉变换
对图像应用一个随机的四点透视变换
(c). 分段仿射 (Piecewise Affine)
分段仿射在图像上放置一个规则的点网格,根据正太分布的样本数量移动这些点及周围的图像区域。
②. 像素颜色变换类
a. 噪声类
随机噪声是在原来的图片的基础上,随机叠加一些噪声。
(a). 高斯噪声
图片上叠加高斯噪声
(b) 椒盐噪声
椒盐噪声(salt-and-pepper noise)是指两种噪声,一种是盐噪声(salt noise),另一种是胡椒噪声(pepper noise)。盐=白色(0),椒=黑色(255)。前者是高灰度噪声,后者属于低灰度噪声。一般两种噪声同时出现,呈现在图像上就是黑白杂点。
(c). Coarse Dropout
在面积大小可选,位置随机的矩形区域上丢失信息实现转换,所有通道的信息丢失产生黑色的矩形块,部分通道的信息丢失产生彩色噪声。
(d). Simplex Noise Alpha
产生连续单一噪声的掩模后,将掩模与原图混合。
(e). Frequency Noise Alpha
在频域中用随机指数对噪声映射进行加权,再转换到空间域。在不同图像中,随着指数值逐渐增大,依次出现平滑的大斑点、多云模式、重复出现的小斑块。
b. 模糊类
减少各像素值的差异,实现图片模糊,实现像素的平滑化。
(a). 高斯模糊
(b). Elastic Transformation
根据扭曲场的平滑度与强度逐一地移动局部像素点实现模糊效果。
c. HSV 对比度变换
通过想 HSV 空间中的每个像素添加或减少 V 值,修改色调和饱和度实现对比度转换
d. RGB 颜色扰动
将图片从 RGB 颜色空间转换到另一颜色空间,增加或减少颜色参数后返回 RGB 颜色空间。
这个和噪声的做法相似,只是需要做颜色空间转换步骤而已
e. 随机擦除法
对图片上随机选取一块区域,随机地擦除图像信息
f. 超像素法 (Super pixels)
在最大分辨率处生成图像的若干个超像素,并将其调整到原始大小,再将原始图像中所有超像素区域按一定比例替换为超像素,其他区域不改变。
h. 转换法 (Invert)
按给定的概率值将部分通道的像素值从 V 设置为 255 - V。
i. 边界检测 (Edge Detect)
检测图像中的所有边缘,将它们标记为黑白图像,再将结果与原始图像叠加
j. Gray Scale
将图像从 RGB 颜色空间转换为灰度空间,通过某一通道与原始图像混合。
k. 锐化 (Sharpen) 与 浮雕 (Emboss)
对图像执行某一程度的锐化或浮雕操作,通过某一通道结果与原始图像融合。
l. 颜色抖动
颜色抖动包括:图像的饱和度,量度,对比度,锐度 等等
③. 多样本合成类
a. SMOTE (Synthetic Minority Over-sampling Technique, SMOTE) 综合少数过采样技术
(a). SMOTE 通过人工合成新样本处理样本不平衡问题,提升分类器性能。
(b). 类不平衡现象是数据集中各类别数量不近似相等。如果样本类别之间相差很大,会影响分类器的分类效果。假设小样本数据量极少,仅占总体的 1%,所能提取的相应特征也极少,即使小样本被错误地全部识别为大样本,在经验风险最小化策略下的分类器识别准确率仍能达到 99%,但在验证环节类分类器效果并不佳。
(c). 基于插值的 SMOTE 方法为小样本类合成新的样本,主要思路:
I. 定义好特征空间,将每个样本对应到特征空间中的某一点,根据样本不平衡比例确定采样倍率 N。
II. 对每一个小样本类样本 ,按欧式距离找 k 个最邻近样本,从中随机选取一个样本点,假设选取的近邻点为 。在特征空间中,样本点与最近邻样本点的连线段上随机选取一点作为新样本点。满足以下公式:
III. 重复选取取样,直到大、小样本数量平衡。
由于,数据增强方法很多,图像和代码就不一一展示了,花点时间,都是可以写出来的。下面,是部分数据增强的代码:
#!/usr/bin/env python
# _*_ coding:utf-8 _*_
# ============================================
# @Time : 2020/02/03 15:15
# @Author : WanDaoYi
# @FileName : data_augment.py
# ============================================from datetime import datetime
import os
import numpy as np
import cv2
import copy
from PIL import Image, ImageEnhance# 数据增强
class DataAugment(object):def __init__(self, show_flag=False):self.image_show_flag = show_flagself.input_image_file_path = "C:/Users/Administrator/Desktop/pic_test/input_data/"pass# cv2 转 pildef cv2_pil(self, cv2_image):pil_image = Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))return pil_imagepass# pil 转 cv2def pil_cv2(self, pil_image):cv2_image = cv2.cvtColor(np.asarray(pil_image), cv2.COLOR_RGB2BGR)return cv2_imagepass# 基础变换矩阵def basic_matrix(self, translation):return np.array([[1, 0, translation[0]],[0, 1, translation[1]],[0, 0, 1]])# 生成范围矩阵def random_vector(self, min_translation, max_translation):""":param min_translation: 最小范围值:param max_translation: 最大范围值:return: 返回 最小~最大 范围之间的值。"""min_translation = np.array(min_translation)max_translation = np.array(max_translation)print(min_translation.shape, max_translation.shape)assert min_translation.shape == max_translation.shapeassert len(min_translation.shape) == 1return np.random.uniform(min_translation, max_translation)# 根据图像调整当前变换矩阵,仿射变换def adjust_transform_for_image(self, image_info, trans_matrix, pic_name="image"):""":param image_info: 图片信息:param trans_matrix: 变换值:return:"""transform_matrix = copy.deepcopy(trans_matrix)height, width, channels = image_info.shapetransform_matrix[0:2, 2] *= [width, height]center = np.array((0.5 * width, 0.5 * height))transform_matrix = np.linalg.multi_dot([self.basic_matrix(center),transform_matrix,self.basic_matrix(-center)])# 仿射变换; cv2.BORDER_REPLICATE, cv2.BORDER_TRANSPARENToutput_image = cv2.warpAffine(image_info, transform_matrix[:2, :],dsize=(image_info.shape[1], image_info.shape[0]),flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT,borderValue=0, )# 是否展示图像if self.image_show_flag:self.image_show(output_image, pic_name)return output_image# 图像显示def image_show(self, image_info, pic_name="image"):cv2.imshow(pic_name, image_info)cv2.waitKey()pass# 平移变换def level_transform(self, image_info, min_translation, max_translation):""":param image_info: 图像信息:param min_translation: 最小移动范围值 是一个二元元组 如: (0.1, 0.1):param max_translation: 最大移动范围值 是一个二元元组 如: (0.2, 0.2):return: 返回变换值 和 变换后的图像"""pic_name = "level_transform"factor = self.random_vector(min_translation, max_translation)print("平移变换:{}".format(factor))trans_matrix = np.array([[1, 0, factor[0]], [0, 1, factor[1]], [0, 0, 1]])output_image = self.adjust_transform_for_image(image_info, trans_matrix, pic_name)return trans_matrix, output_imagepass# 水平或垂直翻转def flip_transform(self, image_info, level_flip_flag=True):""":param image_info: 图像信息:param level_flip_flag: 是否水平翻转,True 为水平翻转,False 为垂直翻转:return: 返回变换值 和 变换后的图像"""# 水平翻转if level_flip_flag:factor = (-1.0, 1.0)pic_name = "level_flip_transform"# 垂直翻转else:factor = (1.0, -1.0)pic_name = "vertical_flip_transform"print("水平或垂直翻转:{}".format(factor))trans_matrix = np.array([[factor[0], 0, 0], [0, factor[1], 0], [0, 0, 1]])output_image = self.adjust_transform_for_image(image_info, trans_matrix, pic_name)return trans_matrix, output_image# 旋转def rotate_transform(self, image_info, factor, clockwise_flag=False):""":param image_info: 图像信息:param factor: 最小、最大 弧度 范围值 是一个二元元组 如: (0.5, 0.8):param clockwise_flag: 是否顺时针旋转,True 为顺时针,False 为逆时针旋转:return: 返回变换值 和 变换后的图像"""# 获取角度angle = np.random.uniform(factor[0], factor[1])print("随机旋转:{}".format(angle))# 顺时针旋转if clockwise_flag:rotate_matrix = np.array([[np.cos(angle), -np.sin(angle), 0],[np.sin(angle), np.cos(angle), 0],[0, 0, 1]])pic_name = "clockwise_rotate_transform"# 逆时针旋转else:rotate_matrix = np.array([[np.cos(angle), np.sin(angle), 0],[-np.sin(angle), np.cos(angle), 0],[0, 0, 1]])pic_name = "anti_clockwise_rotate_transform"output_image = self.adjust_transform_for_image(image_info, rotate_matrix, pic_name)return rotate_matrix, output_image# 随机错切,包括横向和众向错切def shear_transform(self, image_info, factor):pic_name = "shear_transform"print("随机错切:{}".format(factor))crop_matrix = np.array([[1, factor[0], 0], [factor[1], 1, 0], [0, 0, 1]])output_image = self.adjust_transform_for_image(image_info, crop_matrix, pic_name)return crop_matrix, output_imagepass# 随机裁剪def crop_transform(self, image_info, min_translation, max_translation):""":param image_info: 图像信息:param min_translation: 最小 范围值 是一个二元元组 如: (0.0, 0.2):param max_translation: 最大 范围值 是一个二元元组 如: (0.8, 0.9):return: 返回裁剪的左上角和右下角坐标 和 裁剪后的图"""min_w_per = np.random.uniform(min_translation[0], min_translation[1])min_h_per = np.random.uniform(min_translation[0], min_translation[1])max_w_per = np.random.uniform(max_translation[0], max_translation[1])max_h_per = np.random.uniform(max_translation[0], max_translation[1])h, w, c = image_info.shapemin_h = int(h * min_h_per)min_w = int(w * min_w_per)max_h = int(h * max_h_per)max_w = int(w * max_w_per)print("裁剪范围:左上角={}, 右下角={}".format((min_w, min_h), (max_w, max_h)))image_crop = image_info[min_h: max_h, min_w: max_w, :]if self.image_show_flag:cv2.imshow("crop", image_crop)cv2.waitKey()return ((min_w, min_h), (max_w, max_h)), image_croppass# 随机缩放def scale_transform(self, image_info, min_translation, max_translation):""":param image_info: 图像信息:param min_translation: 最小缩放范围值 是一个二元元组 如: (0.5, 0.5):param max_translation: 最大缩放范围值 是一个二元元组 如: (1.1, 1.1):return: 返回变换值 和 变换后的图像"""pic_name = "scale_transform"factor = self.random_vector(min_translation, max_translation)print("随机缩放:{}".format(factor))scale_matrix = np.array([[factor[0], 0, 0],[0, factor[1], 0],[0, 0, 1]])output_image = self.adjust_transform_for_image(image_info, scale_matrix, pic_name)return scale_matrix, output_imagepass# 仿射def affine_transform(self, image_info, flip_flag=False, level_flip_flag=True,level_param=(), rotate_param=(),scale_param=(), shear_param=()):""":param image_info: 图片信息:param flip_flag: 是否翻转操作,True 为翻转操作,False 为不翻转操作:param level_flip_flag: 是否水平翻转,True 为水平翻转,False 为垂直翻转:param level_param: 随机平移参数,结构如: ((0.1, 0.1), (0.2, 0.2)):param rotate_param: 随机翻转参数,结构如: (0.2, 0.3):param scale_param: 随机缩放参数,结构如: ((0.8, 0.8), (1.2, 1.2)):param shear_param: 随机错切参数,结构如: ((-0.1, -0.1), (0.5, 0.5)):return:"""matrix_list = []if flip_flag:if level_flip_flag:factor = (-1.0, 1.0)passelse:factor = (1.0, -1.0)passflip_matrix = np.array([[factor[0], 0, 0], [0, factor[1], 0], [0, 0, 1]])matrix_list.append(flip_matrix)if len(level_param) == 2:factor = self.random_vector(level_param[0], level_param[1])level_matrix = np.array([[1, 0, factor[0]],[0, 1, factor[1]],[0, 0, 1]])matrix_list.append(level_matrix)if len(rotate_param) == 2:angle = np.random.uniform(rotate_param[0], rotate_param[1])rotate_matrix = np.array([[np.cos(angle), -np.sin(angle), 0],[np.sin(angle), np.cos(angle), 0],[0, 0, 1]])matrix_list.append(rotate_matrix)if len(scale_param) == 2:factor = self.random_vector(scale_param[0], scale_param[1])scale_matrix = np.array([[factor[0], 0, 0],[0, factor[1], 0],[0, 0, 1]])matrix_list.append(scale_matrix)passif len(shear_param) == 2:factor = self.random_vector(scale_param[0], scale_param[1])crop_matrix = np.array([[1, factor[0], 0],[factor[1], 1, 0],[0, 0, 1]])matrix_list.append(crop_matrix)pass# 仿射affine_matrix = np.linalg.multi_dot(matrix_list)affine_image = self.adjust_transform_for_image(image_info, affine_matrix, "affine_transform")return affine_matrix, affine_imagepass# 透视def perspective_transform(self, image_info):h, w = image_info.shape[: 2]# 图像的四个角点image_corner = np.float32([[0, 0], # 左上角[0, h - 1], # 左下角[w - 1, h - 1], # 右下角[w - 1, 0] # 右上角])# 目标角坐标# obj_corner = np.float32([[0, 0], # 左上角# [100, h - 18], # 左下角# [w - 18, h - 18], # 右下角# [w - 1, 0] # 右上角# ])# 目标角坐标# obj_corner = np.float32([[0, 0], # 左上角# [100, h - 1], # 左下角# [200, h - 1], # 右下角# [w - 1, 0] # 右上角# ])# 目标角坐标# obj_corner = np.float32([[0, 0], # 左上角# [0, h - 1], # 左下角# [300, 300], # 右下角# [w - 1, 0] # 右上角# ])# 目标角坐标obj_corner = np.float32([[100, 100], # 左上角[0, h - 1], # 左下角[w - 100, h - 100], # 右下角[w - 1, 0] # 右上角])# 透视变换的系数corner_trans = cv2.getPerspectiveTransform(image_corner, obj_corner)# 生成透视图像corner_trans_image = cv2.warpPerspective(image_info, corner_trans, (w, h))self.image_show(corner_trans_image, pic_name="perspective_transform")return corner_trans, corner_trans_imagepass# 高斯噪声def gauss_transform(self, image_info, noise_sigma=20):""":param image_info: 图像信息:param noise_sigma: 噪声参数:return: 返回噪声信息,和 噪声图像"""h, w = image_info.shape[: 2]noise_info = np.random.randn(h, w) * noise_sigma# 创建一个 0 矩阵作为 后面相加 得到 的噪声矩阵noisy_image = np.zeros(image_info.shape, np.float64)if len(image_info.shape) == 2:noisy_image = image_info + noise_infoelse:noisy_image[:, :, 0] = image_info[:, :, 0] + noise_infonoisy_image[:, :, 1] = image_info[:, :, 1] + noise_infonoisy_image[:, :, 2] = image_info[:, :, 2] + noise_infoself.image_show(noisy_image / 255, pic_name="gauss_transform")return noise_info, noisy_image# 椒盐噪声def salt_pepper_transform(self, image_info, noise_sigma=20):""":param image_info: 图像信息:param noise_sigma: 椒盐噪声参数:return: 返回噪声图像"""h, w = image_info.shape[: 2]noisy_image = np.zeros(image_info.shape, np.int16)for i in range(w):for j in range(h):noise_num = np.random.randint(0, noise_sigma)if noise_num == 0:noisy_image[j][i] = 0elif noise_num == 1:noisy_image[j][i] = 255else:noisy_image[j][i] = image_info[j][i]self.image_show(noisy_image / 255, pic_name="salt_pepper_transform")return noisy_image# 高斯模糊def gaussian_blur_transform(self, image_info, noise_ksize=(9, 9)):noisy_image = cv2.GaussianBlur(image_info, noise_ksize, sigmaX=0)self.image_show(noisy_image / 255, pic_name="gaussian_blur_transform")return noisy_imagepass# 颜色抖动def random_color_transform(self, image_info):pil_image = self.cv2_pil(image_info)# 随机因子random_factor = np.random.randint(0, 31) / 10.# 调整图像的饱和度color_image = ImageEnhance.Color(pil_image).enhance(random_factor)# 随机因子random_factor = np.random.randint(10, 12) / 10.# 调整图像的亮度brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)# 随机因1子random_factor = np.random.randint(10, 12) / 10.# 调整图像对比度contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)# 随机因子random_factor = np.random.randint(0, 31) / 10.sharpness_image = ImageEnhance.Sharpness(contrast_image).enhance(random_factor)cv_image = self.pil_cv2(sharpness_image)self.image_show(cv_image / 255, pic_name="random_color_transform")return cv_imagedef deal_data(self):image_name_list = os.listdir(self.input_image_file_path)for image_name in image_name_list:image_path = os.path.join(self.input_image_file_path, image_name)image_info = cv2.imread(image_path)pic_name = image_name.split(".")[0]# 展示原图# self.image_show(image_info, pic_name)# # 水平翻转# self.flip_transform(image_info)# # 垂直翻转# self.flip_transform(image_info, level_flip_flag=False)# # 随机平移# self.level_transform(image_info, (0.1, 0.1), (0.2, 0.2))# # 随机旋转# self.rotate_transform(image_info, (0.2, 0.3))# # 随机缩放# self.scale_transform(image_info, (0.8, 0.8), (1.2, 1.2))# 随机错切# self.shear_transform(image_info, (-0.4, 0.4))# 随机裁剪# self.crop_transform(image_info, (0.1, 0.2), (0.8, 1.0))# # 仿射# self.affine_transform(image_info,# level_param=((0.1, 0.1), (0.2, 0.2)),# rotate_param=(0.2, 0.3),# scale_param=((0.8, 0.8), (1.2, 1.2)))# 颜色抖动# self.perspective_transform(image_info)# 高斯噪声# self.gauss_transform(image_info)# 椒盐噪声# self.salt_pepper_transform(image_info)# 高斯模糊# self.gaussian_blur_transform(image_info)# 颜色抖动self.random_color_transform(image_info)# 至于后面,是否保存,看自己的需求了。保存代码也很简单 # image_save_path 为保存路径,obj_image 为目标图像# cv2.imwrite(image_save_path, obj_image)passif __name__ == "__main__":# 代码开始时间start_time = datetime.now()print("开始时间: {}".format(start_time))demo = DataAugment(show_flag=True)demo.deal_data()# 代码结束时间end_time = datetime.now()print("结束时间: {}, 训练模型耗时: {}".format(end_time, end_time - start_time))
b. Sample Pairing 样本配对
(a). Sample Pairing 的处理方法是从训练集中随机抽取两张图片,分别经过基础数据增强操作(如随机翻转等) 处理后对像素取平均值的形式叠加合成一个新的样本,标签为原样本标签中的一种。流程如下图所示:
经过 Sample Pairing 处理后可使训练集的规模从 扩增到 ,在 CPU 上也能完成处理。
(b). 训练过程是交替禁用与使用 Sample Pairing 处理操作的结合:
I. 使用传统的数据增强训练网络,不使用 Sample Pairing 数据增强训练。
II. 在 ILSVRC 数据集上完成一个 epoch 或在其他数据集上完成 100 个 epoch 后,加入 Sample Pairing 数据增强训练。
III. 间歇性禁用 Sample Pairing。对于 ILSVRC 数据集,为其中的 300,000 个图像启用 Sample Pairing,然后在接下来的 100,000 个图像中禁用它。对于其他数据集,在开始的 8 个 epoch 中启用,在接下来的 2 个 epoch 中禁用。
IV. 在训练损失函数和精度文档后进行微调,禁用 Sample Pairing。
(c). 实验结果表明,因 Sample Pairing 数据增强操可能引入不同标签的训练样本,导致在各数据集上使用 Sample Pairing 训练的误差明显增加,而在检测误差方面使用 Sample Pairing 训练的验证误差有较大幅度降低。尽管 Sample Pairing 思路简单,性能上提升效果可观,符合奥卡姆剃刀原理,遗憾的是可解释性不强,目前尚缺理论支撑。目前仅有图片数据的实验,而需要下一步的实验与解读。
(d). Sample Pairing 与 对抗神经网络 思想类似。读者可以到网上找一下 对抗神经网络的 styleGan 代码看看,如果合适自己的需求,一言不合就 gan 起来。
c. Mixup 混淆
(a). Mixup 是基于邻域风险最小化(VRM) 原则的数据增强方法,使用线性插值得到新样本数据。在邻域风险最小化原则下,根据特征向量线性插值将导致相关目标线性插值的先验知识,可得出简单且与数据无关的 Mixup 公式:
其中 是插值生成的新数据, 和 是训练集中随机选取的两个数据, 的取值满足 贝塔分布,取值范围介于 0 到 1,超参数 控制特征目标之间的插值强度。
(b). Mixup 的实验丰富,实验结果表明可以改进深度学习模型在 ImageNet 数据集、CIFAR 数据集、语音数据集和表格数据集中的泛化误差,降低模型对已损坏标签的记忆,增强模型对对抗样本的鲁棒性和训练对抗生成网络的稳定性。
(c). Mixup 处理实现了边界模糊化,提供平滑的预测效果,增强模型在训练数据范围之外的预测能力。随着超参数 增大,实际数据的训练误差就会增加,而泛化误差会减少。说明 Mixup 隐式地控制着模型的复杂性。随着模型容量与超参数的增加,训练误差随之降低。
(d). 尽管 Mixup 有着可观的效果改进,但在偏差——方差平衡方面尚未有较好的解释。在其他类型的有监督、无监督、半监督和强化学习中,mixup 还有很大的发展空间。
d. 总结
SMOTE、Sample Pairing、Mixup 三者思路上有相同之处,都是试图将离散样本点连续化来拟合真实样本分布,但所增加的样本点在特征空间中仍位于已知小样本点所围成的区域内。但在特征空间中,小样本数据的真实分布可能并不限于该区域中,在给定范围之外适当插值,也许能实现更好的数据增强效果。
(10). fine-tuning
①. fine-tuning 的理解
fine-tuning 是微调的意思,是利用 预训练模型 (即 pre-trained model),运用到自己的数据上来训练得到新的模型。
②. fine-tuning 的三种模式
a. 只预测,不训练
针对那些已经训练好,而且验证集准确率很高的模型。使用的时候,只需将测试集的数据集灌入到该模型网络中,便可得到很好的预测结果。该模式快速而简单。
b. 需要训练,但只序列最后的分类层。
这种模式需要修改最后的分类层,并使分类结果符合要求。是在 pre-trained model 的基础上进行分类降维,一般训练后面几层全连接层。
c. 完全训练,分类层 + 之前的卷积层都训练
该模式是对一个较好的 pre-trained model 进行再训练,以提高模型验证集的准确率。如在 early-stopping 中得到的 pre-trained model 后,在该模型的基础上再进行微调。
③. fine-tuning 使用参考
a. 新的数据集较小,并且和 pre-trained model 所用的训练数据集相似度较高。
由于新的数据集较小,并且相似度较高,所以可以选择只预测不训练模式 和 只训练最后分类层模式进行比较,选择验证集准确率更高的模式所对应的模型。如果训练模型,需要做数据增强处理,因为新数据集较小,容易过拟合。
b. 新的数据集较大,并且和 pre-trained model 所使用的训练集相似度较高
由于新的数据集较大,并且相似度较高,所以可以选择只训练最后分类层模式 和 完全训练模式进行比较,因为新数据集够大,不用担心过拟合,选择验证集准确率更高的模式所对应的模型。
c. 新的数据集较小,并且和 pre-trained model 所使用的训练集差异很大
由于新的数据集较小,并且差异很大,不适合 fine-tuning。如果真要用 fine-tuning,则固定前面特征提取层的权值,修改并训练后面的分类层。这种情况下,就算做了数据增强处理,因为新的数据较小,也是很容易过拟合的。在实际中,这种问题下较好的解决方案一般是从网络的某层开始取出特征,然后训练 SVM 分类器。
d. 新的数据集较大,并且和 pre-trained model 所使用的训练数据集差异很大
由于新的数据集较大,完全可以从头开始训练,单在实际中更偏向于训练整个 pre-trained model 的网络。
④. 使用 fine-tuning 的注意事项
a. 不要随意移除原始结构中的层或更改其参数。
因为网络结构传导是一层接着一层的,修改了某层参数后,在往后传导的过程中可能就得不到预想的结果。
b. learning rate 学习率不应设置得太大。因为 fine-tuning 的前提就是这个模型的权重很多是有意义的,但是如果学习率过大的话,就会存在更新权值过快,破坏了原来好的权重信息。在 fine-tuning 中,学习率一般设置在 1e-5,如果是从头开始训练的话,可以设置得大点:1e-1 ~ 1e-3 都可以,具体看效果而定。
c. 交叉验证,训练多个模型。
d. 对数据做数据增强处理。
对训练集、验证集数据做数据增强处理,使模型的泛化能力更好。
e. 平衡采样
很多数据集存在样本不均衡的问题,有些类别特别多,有些类别特别少。训练模型时,从一个数据集中依次读取样本训练。这样的话,小类样本参与训练的机会就比大类样本少。训练出来的模型会偏向于大类样本,即大类样本的性能好,而小类样本的性能差。
平衡采样策略就是把样本按类别分组,每个类别生成一个样本列表。训练过程中先随机选择一个或几个类别,然后从各个类别所对应的样本列表中随机选择样本。这样可以保证每个类别参与训练的机会比较均衡。
⑤. fine-tuning 的流程
a. 准备训练集、验证集和测试集数据
b. 数据预处理
(a). 将准备的数据集处理成 pre-trained model input 数据模型一样。
(b). 计算数据集的均值文件
因为数据集中特定领域的图像均值文件会跟 ImageNet 上比较 General 的数据的均值不太一样
c. 选择模型文件,修改分类层的参数,修改网络最后一层的输出类别
d. 调整 Solver 中的部分配置参数。通常学习率、迭代次数都要适当减少。
e. 启动训练,加载 pre-trained model 参数,对模型进行微调。
⑥. model zoo 中有大量预训练好的模型供使用
返回主目录
返回神经网络目录
上一章:深度篇——神经网络(五) 细说 优化器
下一章:深度篇——神经网络(七) 细说 DNN神经网络手写数字代码演示