当前位置: 代码迷 >> 综合 >> 目标检测中的数据增强方法之——mixup:针对小样本
  详细解决方案

目标检测中的数据增强方法之——mixup:针对小样本

热度:66   发布时间:2023-10-28 11:14:09.0

在深度学习中,一般要求样本的数量要充足,样本数量越多,训练出来的模型效果越好,模型的泛化能力越强。但是实际中,样本数量不足或者样本质量不够好,这就要对样本做数据增强,来提高样本质量。
在样本量不足的情况下,采用mixup或者填鸭式(copy-paste)的方法来进行数据增强,是行之有效的增强方法。其中mixup是将随机选取的两张图片按一定比例加权融合成一张新图片,并合并两张图片的标注,从而得到一组新的样本,使得样本量翻倍。填鸭式是将原本样本里的目标抠出来,随机复制粘贴到其他地方。(个人理解)
以下是mixup方法的代码示例:

# coding:utf-8
"""Mixup data augmentation for detection datasets annotated with labelImg.

For every image in ``img_path`` a random partner image from the same folder
is chosen, resized to the first image's size if necessary, and blended with
it using a weight drawn from Beta(1.5, 1.5).  The blended image is written to
``save_path`` as ``<stem>_3.jpg`` and a Pascal-VOC style annotation holding
the boxes of both images is written to ``save_xml`` as ``<stem>_3.xml``.
"""
import os
import random
import xml.dom.minidom
import xml.etree.ElementTree as ET

import numpy as np

img_path = 'E:/jpg/'           # folder holding the original images
save_path = 'E:/mixJPG/'       # output folder for the mixed images
xml_path = 'E:/xml/'           # folder holding the XML annotations of the originals
save_xml = 'E:/mixXML/'        # output folder for the annotations of the mixed images

MIX_SUFFIX = '_3'              # appended to the stem of every generated file


def child_text(node, tags, default):
    """Return ``(tag, text)`` for the first child of *node* named in *tags*.

    Falls back to ``(tags[0], default)`` when none of the tags is present or
    the matching element has no text.
    """
    for tag in tags:
        child = node.find(tag)
        if child is not None and child.text is not None:
            return tag, child.text
    return tags[0], default


def read_objects(tree, scale_w=1.0, scale_h=1.0):
    """Collect the ``object`` entries of an annotation tree as dicts.

    Box coordinates are multiplied by *scale_w* / *scale_h* and truncated to
    int, so boxes of a resized partner image land in the target image's
    coordinate system.  Each dict has the keys ``name``, ``pose``,
    ``truncated``, ``Difficult``, ``difficult_tag`` and ``bbox``
    (``[xmin, ymin, xmax, ymax]``).
    """
    objects = []
    for obj in tree.findall("object"):
        # labelImg versions differ in the capitalisation of this tag; remember
        # which spelling the input used so the output keeps it.
        difficult_tag, difficult = child_text(obj, ("Difficult", "difficult"), "0")
        bbox = obj.find("bndbox")
        xmin, ymin, xmax, ymax = (
            int(float(bbox.find(tag).text))
            for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        objects.append({
            "name": child_text(obj, ("name",), "")[1],
            "pose": child_text(obj, ("pose",), "Unspecified")[1],
            "truncated": child_text(obj, ("truncated",), "0")[1],
            "Difficult": difficult,
            "difficult_tag": difficult_tag,
            "bbox": [
                int(xmin * scale_w),
                int(ymin * scale_h),
                int(xmax * scale_w),
                int(ymax * scale_h),
            ],
        })
    return objects


def build_annotation(tree, img_w, img_h, objects):
    """Build the output annotation document for one mixed image.

    ``folder``, ``filename`` and ``path`` are copied verbatim from *tree*
    (the first image's annotation), so they still name the original image.
    ``size`` is always written from *img_w* / *img_h* with a depth of 3.
    Returns an ``xml.dom.minidom.Document``.
    """
    doc = xml.dom.minidom.Document()
    root = doc.createElement("annotation")
    doc.appendChild(root)

    def add_text(parent, tag, text):
        node = doc.createElement(tag)
        node.appendChild(doc.createTextNode(str(text)))
        parent.appendChild(node)

    for tag in ("folder", "filename", "path"):
        for found in tree.findall(tag):
            add_text(root, tag, found.text)

    for _ in tree.findall("source"):
        source = doc.createElement("source")
        add_text(source, "database", "Unknow")
        root.appendChild(source)

    size = doc.createElement("size")
    add_text(size, "width", img_w)
    add_text(size, "height", img_h)
    add_text(size, "depth", 3)
    root.appendChild(size)

    for obj in objects:
        node = doc.createElement("object")
        add_text(node, "name", obj["name"])
        add_text(node, "pose", obj["pose"])
        add_text(node, "truncated", obj["truncated"])
        add_text(node, obj.get("difficult_tag", "Difficult"), obj["Difficult"])
        bndbox = doc.createElement("bndbox")
        for tag, value in zip(("xmin", "ymin", "xmax", "ymax"), obj["bbox"]):
            add_text(bndbox, tag, value)
        node.appendChild(bndbox)
        root.appendChild(node)
    return doc


def mix_images(img, addimg, lam):
    """Blend two equally shaped images as ``lam * img + (1 - lam) * addimg``.

    The float result is rounded and clipped to uint8 here so the pixels that
    get saved do not depend on how the image writer treats float arrays.
    """
    mixed = lam * img + (1 - lam) * addimg
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)


def main():
    """Create one mixed image plus annotation for every image in ``img_path``."""
    # Imported here so the annotation helpers above stay importable on
    # machines without OpenCV.
    import cv2

    os.makedirs(save_path, exist_ok=True)
    os.makedirs(save_xml, exist_ok=True)

    img_names = os.listdir(img_path)
    img_num = len(img_names)
    print('img_num:', img_num)

    for imgname in img_names:
        # splitext instead of [:-4] so extensions such as ".jpeg" work too.
        stem = os.path.splitext(imgname)[0]
        img = cv2.imread(os.path.join(img_path, imgname))
        if img is None:
            print('skip, not a readable image:', imgname)
            continue
        img_h, img_w = img.shape[0], img.shape[1]
        print(img_h, img_w)

        i = random.randint(0, img_num - 1)
        print('i:', i)
        addname = img_names[i]
        addimg = cv2.imread(os.path.join(img_path, addname))
        if addimg is None:
            print('skip, partner is not a readable image:', addname)
            continue
        add_h, add_w = addimg.shape[0], addimg.shape[1]

        # Always defined (1.0 when the sizes already match) so the partner's
        # boxes can be rescaled unconditionally below.
        scale_h, scale_w = img_h / add_h, img_w / add_w
        if add_h != img_h or add_w != img_w:
            print('resize!')
            addimg = cv2.resize(addimg, (img_w, img_h), interpolation=cv2.INTER_LINEAR)

        # Parse the annotations before writing anything, so a missing or
        # broken XML does not leave an image without a matching annotation.
        print(imgname, addname)
        xmlfile1 = os.path.join(xml_path, stem + '.xml')
        try:
            tree1 = ET.parse(xmlfile1)
            objects = read_objects(tree1)
            if imgname != addname:
                xmlfile2 = os.path.join(xml_path, os.path.splitext(addname)[0] + '.xml')
                print(xmlfile1, xmlfile2)
                objects += read_objects(ET.parse(xmlfile2), scale_w, scale_h)
            else:
                # The image was paired with itself: keep its boxes only once.
                print(xmlfile1)
        except (OSError, ET.ParseError) as err:
            print('skip, annotation unreadable:', err)
            continue

        lam = np.random.beta(1.5, 1.5)
        print(lam)
        save_img = os.path.join(save_path, stem + MIX_SUFFIX + '.jpg')
        if not cv2.imwrite(save_img, mix_images(img, addimg, lam)):
            print('skip, could not write image:', save_img)
            continue
        print(save_img)

        doc = build_annotation(tree1, img_w, img_h, objects)
        # Encode the file as UTF-8 so it matches the encoding declared in the
        # XML header written by writexml.
        out_xml = os.path.join(save_xml, stem + MIX_SUFFIX + '.xml')
        with open(out_xml, "w", encoding="utf-8") as fp:
            doc.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding="utf-8")


if __name__ == "__main__":
    main()

以上代码是针对目标检测原始图片和采用标注工具labelImg标注得到的xml文件进行mixup的数据增强。
原始数据的标注样例和mixup之后的数据示例如下:
原始图片1
原始图片2
以上是两张原始图片的labelImg标注样例。
下面是mixup数据增强后的图片示例:
mixup增强的样例1
mixup增强的样例2
由上述示例图片可以看出,mixup是从原始数据中随机选取两张图片,融合成一个新的样本,由此使得样本量翻倍;同时,mixup后的每张图片中的目标物通常会比原始图片中的目标物多。