当前位置: 代码迷 >> 综合 >> 【YOLOV5-5.x 源码解读】general.py
  详细解决方案

【YOLOV5-5.x 源码解读】general.py

热度:32   发布时间:2023-12-14 09:29:49.0

目录

  • 前言
  • 0、导入需要的包和基本配置
  • 1、timeout(没用到)
  • 2、set_logging、init_seeds
    • 2.1、set_logging
    • 2.2、init_seeds
  • 3、get_latest_run
  • 4、is_docker、is_colab、is_pip
    • 4.1、is_docker
    • 4.2、is_colab
    • 4.3、is_pip(没用到)
  • 5、file_size(没用到)
  • 6、colorstr
  • 7、check_online
  • 8、emojis
  • 9、check_git_status
  • 10、check_python、check_requirements
    • 10.1、check_python
    • 10.2、check_requirements
  • 11、make_divisible、check_img_size
    • 11.1、make_divisible
    • 11.2、check_img_size
  • 12、check_imshow
  • 13、check_file
  • 14、check_dataset
  • 15、download
  • 16、clean_str
  • 17、one_cycle
  • 18、labels_to_class_weights、labels_to_image_weights
    • 18.1、labels_to_class_weights
    • 18.2、labels_to_image_weights
  • 19、coco80_to_coco91_class
  • 20、clip_coords
  • 21、scale_coords
  • 22、xyxy2xywh、xywh2xyxy
    • 22.1、xyxy2xywh
    • 22.2、xywh2xyxy
  • 23、xywhn2xyxy、xyxy2xywhn、xyn2xy
    • 23.1、xywhn2xyxy
    • 23.2、xyxy2xywhn
    • 23.3、xyn2xy
  • 24、non_max_suppression
  • 25、strip_optimizer
  • 26、print_mutation
  • 27、apply_classifier
  • 28、increment_path
  • 29、save_one_box
  • 30、resample_segments
  • 31、segment2box
  • 32、segments2boxes
  • 总结

前言

源码: YOLOv5源码.
导航: 【YOLOV5-5.x 源码讲解】整体项目文件导航.
注释版全部项目文件已上传至GitHub: yolov5-5.x-annotations.

这个文件是yolov5的通用工具类,写了一些通用的工具函数,用的很广,整个项目哪里都可能用到。这个文件的函数非常多,代码量也很大(上千行了),也都比较重要,希望大家看的时候多点耐心,都能掌握!

0、导入需要的包和基本配置

import contextlib   # python上下文管理器 执行with…as…的时候调用contextlib
import glob         # 仅支持部分通配符的文件搜索模块
import logging      # 日志模块
import math         # 数学公式模块
import os           # 与操作系统进行交互的模块
import platform     # 提供获取操作系统相关信息的模块
import random       # 生成随机数的模块
import re           # 用来匹配字符串(动态、模糊)的模块
import signal       # 信号处理模块
import time         # 时间模块 更底层
import urllib       # 用于操作网页URL, 并对网页的内容进行抓取处理 如urllib.parse: 解析url
from itertools import repeat  # 循环器模块 创建一个迭代器,重复生成object
from multiprocessing.pool import ThreadPool  # 多线程模块 线程池
from pathlib import Path  # Path将str转换为Path对象 使字符串路径易于操作的模块
from subprocess import check_output  # 创建一个子进程再命令行执行..., 最后返回执行结果(文件)import cv2  # opencv库
import numpy as np   # numpy矩阵处理函数库
import pandas as pd  # pandas矩阵操作模块
import pkg_resources as pkg  # 用于查找, 自省, 激活和使用已安装的Python发行版
import torch    # pytorch框架
import torchvision  # 为pytorch 提供一些辅助工具
import yaml  # yaml配置文件读写模块from utils.google_utils import gsutil_getsize
from utils.metrics import box_iou, fitness
from utils.torch_utils import init_torch_seeds# 设置运行相关的一些基本的配置 Settings
# 控制print打印torch.tensor格式设置 tensor精度为5(小数点后5位) 每行字符数为320个 显示方法为long
torch.set_printoptions(linewidth=320, precision=5, profile='long')
# 控制print打印np.array格式设置 精度为5 每行字符数为320个 format short g, %precision=5
np.set_printoptions(linewidth=320, formatter={
    'float_kind': '{:11.5g}'.format})
# pandas的最大显示行数是10
pd.options.display.max_columns = 10
# 阻止opencv参与多线程(与 Pytorch的 Dataloader不兼容)
cv2.setNumThreads(0)
# 确定最大的线程数 这里被限制在了8
os.environ['NUMEXPR_MAX_THREADS'] = str(min(os.cpu_count(), 8))  # NumExpr max threads

1、timeout(没用到)

这个函数是自定义的timeout超时函数,如果某个程序执行超时,就会触发超时处理函数_timeout_handler 返回超时异常信息。但是这个函数没用到,代码中都是使用库函数自己定义的timeout,没用用这个自定义的timeout函数。所以这个函数可以了解下就行,不过这种超时提示的代码还是有必要学习的。

timeout函数代码:

class timeout(contextlib.ContextDecorator):"""没用到 代码中都是使用库函数自己定义的timeout 没用用这个自定义的timeout函数设置一个超时函数 如果某个程序执行超时 就会触发超时处理函数_timeout_handler 返回超时异常信息并没有用到 这里面的timeout都是用python库函数实现的 并不需要自己另外写一个使用: with timeout(seconds): sleep(10) 或者 @timeout(seconds) decoratordealing with wandb login-options timeout issues as well as check_github() timeout issues"""def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):self.seconds = int(seconds)   # 限制时间self.timeout_message = timeout_msg  # 报错信息self.suppress = bool(suppress_timeout_errors)def _timeout_handler(self, signum, frame):# 超时处理函数 一旦超时 就在seconds后发送超时信息raise TimeoutError(self.timeout_message)def __enter__(self):# signal.signal: 设置信号处理的函数_timeout_handler# 执行流进入with中会执行__enter__方法 如果发生超时, 就会触发超时处理函数_timeout_handler 返回超时异常信息signal.signal(signal.SIGALRM, self._timeout_handler)  # Set handler for SIGALRM# signal.alarm: 设置发送SIGALRM信号的定时器signal.alarm(self.seconds)  # start countdown for SIGALRM to be raiseddef __exit__(self, exc_type, exc_val, exc_tb):# 执行流离开 with 块时(没有发生超时), 则调用这个上下文管理器的__exit__方法来清理所使用的资源signal.alarm(0)  # Cancel SIGALRM if it's scheduledif self.suppress and exc_type is TimeoutError:  # Suppress TimeoutErrorreturn True

2、set_logging、init_seeds

这两个函数是一些初始化操作。set_logging是对日志的设置(format、level)等进行初始化,init_seeds是进行一系列的随机数种子。

2.1、set_logging

这个函数是对日志的格式、等级等进行一个初始化。

def set_logging(rank=-1, verbose=True):"""广泛使用在train.py、test.py、detect.py等文件的main函数的第一步对日志的设置(format、level)等进行初始化"""logging.basicConfig(# 设置日志输出的格式和内容 只打印日志信息format="%(message)s",# 设置日志级别 rank不为-1或0时设置输出级别level为WARN 为-1或0时设置级别为INFOlevel=logging.INFO if (verbose and rank in [-1, 0]) else logging.WARN)

广泛使用在train.py、test.py、detect.py等文件的main函数的第一步:
在这里插入图片描述

2.2、init_seeds

这个函数是使用random.random()、np.random.rand()、init_torch_seeds(调用torch_utils.py中的函数)等生成一系列的随机数种子,以保证结果的可复现性。

init_seeds函数代码:

def init_seeds(seed=0):"""在train函数的一开始调用用于设置一系列的随机数种子"""# 设置随机数 针对使用random.random()生成随机数的时候相同random.seed(seed)# 设置随机数 针对使用np.random.rand()生成随机数的时候相同np.random.seed(seed)# 为CPU设置种子用于生成随机数的时候相同 并确定训练模式init_torch_seeds(seed)

这个函数在train函数的一开始调用:
在这里插入图片描述

3、get_latest_run

这个函数的作用是查找最近保存的权重文件 last*.pt,用以进行断点续训。

get_latest_run函数代码:

def get_latest_run(search_dir='.'):"""用在train.py查找最近的pt文件进行断点续训用于返回该项目中最近的模型 'last.pt'对应的路径:params search_dir: 要搜索的文件的根目录 默认是 '.' 表示搜索该项目中的文件"""# 从Python版本3.5开始, glob模块支持该"**"指令(仅当传递recursive标志时才会解析该指令)# glob.glob函数匹配所有的符合条件的文件, 并将其以list的形式返回last_list = glob.glob(f'{
      search_dir}/**/last*.pt', recursive=True)# os.path.getctime 返回路径对应文件的创建时间# 所以这里是返回路径列表中创建时间最晚(最近的last文件)的路径return max(last_list, key=os.path.getctime) if last_list else ''

函数在train.py中被调用:
在这里插入图片描述

4、is_docker、is_colab、is_pip

下面是三个检测函数,is_docker检测当前环境是否是docker环境,is_colab检查当前环境是否是Google Colab环境,is_pip检测当前文件是否在pip package(site-packages)文件里。前面两个函数在后面的函数都会用到,但是is_pip是没用到的。

4.1、is_docker

这个函数是查询当前环境是否是docker环境,会用到后面的check_git_status和check_imshow等函数中。

is_docker函数代码:

def is_docker():"""在后面的check_git_status和check_imshow等函数中被调用查询当前环境是否是docker环境 Is environment a Docker container?"""return Path('/workspace').exists()  # or Path('/.dockerenv').exists()

4.2、is_colab

这个函数是检查当前环境是否是Google Colab环境,会用到后面的check_imshow函数中。

is_colab函数代码:

def is_colab():"""用到后面的check_imshow函数中检查当前环境是否是Google Colab环境 Is environment a Google Colab instance?"""try:import google.colabreturn Trueexcept Exception as e:return False

4.3、is_pip(没用到)

这个函数是检测当前文件是否在pip package(site-packages)文件里,不过这个函数没用到。

is_pip函数代码:

def is_pip():"""没用到当前文件是否在pip package(site-packages)文件里Is file in a pip package?"""return 'site-packages' in Path(__file__).absolute().parts

5、file_size(没用到)

这个函数是返回本地文件的大小,功能和之前google_utils.py中的gsutil_getsize函数(返回网站链接对应文件的大小)很像。不过这个函数并没有用到哦,随便看看就好。

file_size函数代码:

def file_size(file):"""没用到返回本地文件的大小(MB):params file: 要查询的文件地址"""# .stat(): 返回文件相关状态 st_size: 返回文件的大小return Path(file).stat().st_size / 1e6

6、colorstr

这个函数是将输出的开头和结尾加上颜色,使命令行输出显示会更加好看。

colorstr函数代码:

def colorstr(*input):"""用到下面的check_git_status、check_requirements等函数 train.py、test.py、detect.py等文件中把输出的开头和结尾加上颜色 命令行输出显示会更加好看 如: colorstr('blue', 'hello world')Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code"""# 如果输入长度为1, 就是没有选择颜色 则选择默认颜色设置 blue + bold# args: 输入的颜色序列 string: 输入的字符串*args, string = input if len(input) > 1 else ('blue', 'bold', input[0])# 定义一些基础的颜色 和 字体设置colors = {
    'black': '\033[30m',  # basic colors'red': '\033[31m','green': '\033[32m','yellow': '\033[33m','blue': '\033[34m','magenta': '\033[35m','cyan': '\033[36m','white': '\033[37m','bright_black': '\033[90m',  # bright colors'bright_red': '\033[91m','bright_green': '\033[92m','bright_yellow': '\033[93m','bright_blue': '\033[94m','bright_magenta': '\033[95m','bright_cyan': '\033[96m','bright_white': '\033[97m','end': '\033[0m',  # misc'bold': '\033[1m','underline': '\033[4m'}# 把输出的开头和结尾加上颜色 命令行输出显示会更加好看return ''.join(colors[x] for x in args) + f'{
      string}' + colors['end']

这个函数会用到下面的check_git_status、check_requirements等函数中,而且还会广泛用在train.py、test.py、detect.py等其他文件中如:

在这里插入图片描述
在这里插入图片描述

函数效果如下(可以看到输出开头、结尾变量使用其他颜色):
在这里插入图片描述

7、check_online

这个函数是检查当前主机是否联网了。会在下面的check_git_status、check_requirements等函数中使用。

check_online函数代码:

def check_online():"""在下面的check_git_status、check_requirements等函数中使用检查当前主机网络连接是否可用"""import socket  # 导入socket模块 可解决基于tcp和ucp协议的网络传输try:# 连接到一个ip 地址addr("1.1.1.1")的TCP服务上, 端口号port=443 timeout=5 时限5秒 并返回一个新的套接字对象socket.create_connection(("1.1.1.1", 443), 5)  # check host accessibility# 没发现什么异常, 连接成功, 有网, 就返回Truereturn Trueexcept OSError:# 连接异常, 没网, 返回Falsereturn False

8、emojis

这个函数是忽略掉字符串中无法用ascii编码的内容(比如表情、图像),返回Windows系统可以安全、完整显示的字符串。会在下面的check_git_status、check_requirements等函数中使用。

emojis函数代码:

def emojis(str=''):"""在下面的check_git_status、check_requirements等函数中使用返回Windows系统可以安全、完整显示的字符串Return platform-dependent emoji-safe version of string"""# 通过.encode().decode()的组合忽略掉无法用ascii编码的内容(比如表情、图像)return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str

9、check_git_status

这个函数是检查当前的代码版本是否是最新的。如果不是最新的,会提示使用git pull命令进行升级。

函数代码:

def check_git_status(err_msg=', for updates see https://github.com/ultralytics/yolov5'):"""用在train.py的main函数的一开始部分检查当前代码版本是否是最新的 如果不是最新的 会提示使用git pull命令进行升级"""# 彩色显示github单词 github:print(colorstr('github: '), end='')try:# 检查电脑有没有安装git仓库 没有安装直接报异常并输出异常信息assert Path('.git').exists(), 'skipping check (not a git repository)'# 检查电脑系统有没有安装docker环境变量 没有直接报异常并输出异常信息assert not is_docker(), 'skipping check (Docker image)'# 检查主机是否联网assert check_online(), 'skipping check (offline)'# 创建cmd命令cmd = 'git fetch && git config --get remote.origin.url'# 并创建子进程进行执行cmd命令 返回执行结果 时限5秒url = check_output(cmd, shell=True, timeout=5).decode().strip().rstrip('.git')  # git fetchbranch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip()  # checked outn = int(check_output(f'git rev-list {
      branch}..origin/master --count', shell=True))  # commits behind# n>0 说明当前版本之后还有commit 因此当前版本不是最新的 s为输出的相关提示if n > 0:# 如果不是最新 提升字符s: WARNING...s = f"?? WARNING: code is out of date by {
      n} commit{
      's' * (n > 1)}. " \f"Use 'git pull' to update or 'git clone {
      url}' to download latest."else:# 已经是最新s = f'up to date with {
      url} ?'# 输出显示信息(最新/不是最新) emojis: 忽略掉Windows电脑无法用ascii编码的字符print(emojis(s))  # emoji-safeexcept Exception as e:# 只要报任何异常 直接输出异常信息print(f'{
      e}{
      err_msg}')

这个函数只用在train.py的main函数的一开始部分:
在这里插入图片描述

10、check_python、check_requirements

check_python是检查当前的版本号是否满足最小版本号minimum,check_requirements是检查已经安装的包是否满足requirements对应txt文件的要求。check_requirements会调用check_python。

10.1、check_python

这个函数是检查当前的版本号是否满足最小版本号minimum。会在下面的check_requirements函数被调用。

check_python函数代码:

def check_python(minimum='3.6.2', required=True):"""用在下面的函数check_requirements中检查当前的版本号是否满足最小版本号minimumCheck current python version vs. required python version"""# cuurent: 当前使用的python版本号 如3.8.10current = platform.python_version()# 对比当前版本号和输出的至少的版本号(python版本一般是向下兼容的)# 如果满足返回result=True 反正返回result=False# pkg.parse_version(版本号)用于对比两个版本号的大小result = pkg.parse_version(current) >= pkg.parse_version(minimum)if required:# 检查版本号满不满足最小版本号minimumassert result, f'Python {
      minimum} required by YOLOv5, but Python {
      current} is currently installed'return result

10.2、check_requirements

这个函数用于检查已经安装的包是否满足requirements对应txt文件的要求。会调用colorstr、check_python、check_online等函数。

check_requirements函数代码:

def check_requirements(requirements='requirements.txt', exclude=()):"""用在train.py、test.py、detect.py等文件用于检查已经安装的包是否满足requirements对应txt文件的要求Check installed dependencies meet requirements (pass *.txt file or list of packages)"""# 红色显示requirements单词 requirements:prefix = colorstr('red', 'bold', 'requirements:')# 检查当前的python版本符不符合最低版本要求 check python versioncheck_python()# 解析requirements.txt中的所有包 解析成list 里面存放着一个个的pkg_resources.Requirement类# 如: ['matplotlib>=3.2.2', 'numpy>=1.18.5', ……]if isinstance(requirements, (str, Path)):  # requirements.txt file# 将str字符串requirements转换成路径requirementsfile = Path(requirements)if not file.exists():  # requirements.txt文件不存在print(f"{
      prefix} {
      file.resolve()} not found, check failed.")return# pkg_resources.parse_requirements:可以解析file中的每一条要求# 每一行转换为pkg_resources.Requirement类并进行进一步处理# 处理形式为调用每一行对应的name和specifier属性。前者代表需要包的名称,后者代表版本# 返回list 每个元素是requirements.txt的一行 如: ['matplotlib>=3.2.2', 'numpy>=1.18.5', ……]requirements = [f'{
      x.name}{
      x.specifier}' for x in pkg.parse_requirements(file.open()) if x.name not in exclude]else:  # list or tuple of packagesrequirements = [x for x in requirements if x not in exclude]n = 0  # 统计下面程序更新包的个数 number of packages updates# 依次检查环境中安装的包(及每个包对应的依赖包)是否满足requirements中的每一个最低要求安装包for r in requirements:try:# pkg_resources.require(file) 返回对应包所需的所有依赖包 当这些包有哪个未安装或者版本不对的时候就会报错pkg.require(r)except Exception as e:# 没有找到当前包r 或者 当前包r的版本低于最低要求# 首先打印信息print(f"{
      prefix} {
      r} not found and is required by YOLOv5, attempting auto-update...")try:# 再检查当前主机是否联网assert check_online(), f"'pip install {
      r}' skipped (offline)"# 最后创建一个子进程再执行pip指令并返回执行结果print(check_output(f"pip install '{
      r}'", shell=True).decode())n += 1  # 更新包的数量加1except Exception as e:print(f'{
      prefix} {
      e}')if n:# if packages updated 打印一写更新信息source = file.resolve() if 'file' in locals() else requirementss = f"{
      prefix} {
      n} package{
      's' * (n > 1)} updated per {
      source}\n" \f"{
      prefix} ?? {
      colorstr('bold', 'Restart runtime or rerun command for updates to take effect')}\n"print(emojis(s))  # emoji-safe

用在train.py中:
在这里插入图片描述
test.py:
在这里插入图片描述

11、make_divisible、check_img_size

这两个函数主要是用来约束图像的长款或者feature map的长款,必须是divisor(等于算法的最大下采样率一般是32)的最小倍数。

11.1、make_divisible

这个函数用来取大于等于x且是divisor的最小倍数,保证输入的x(一般是长宽)是算法的最大下采样率的倍数。

def make_divisible(x, divisor):"""用在下面的make_divisible函数中 yolo.py的parse_model函数和commom.py的AutoShape函数中取大于等于x且是divisor的最小倍数Returns x evenly divisible by divisor"""# math.ceil 向上取整return math.ceil(x / divisor) * divisor

这个函数用在下面的make_divisible函数中及 yolo.py的parse_model函数和commom.py的AutoShape函数中:
在这里插入图片描述

11.2、check_img_size

这个函数是为了保证img_size是能被s(32)整除,如果不能就返回大于等于img_size且是s的最小倍数。这个函数本质是通过调用make_divisible函数实现的。

check_img_size函数代码:

def check_img_size(img_size, s=32):"""这个函数主要用于train.py中和detect.py中 用来检查图片的长宽是否符合规定检查img_size是否能被s整除,这里默认s为32 返回大于等于img_size且是s的最小倍数Verify img_size is a multiple of stride s"""# 取大于等于x的最小值且该值能被divisor整除new_size = make_divisible(img_size, int(s))  # ceil gs-multipleif new_size != img_size:print('WARNING: --img-size %g must be multiple of max stride %g, updating to %g' % (img_size, s, new_size))return new_size

用来保证img的长宽符合规定,一般用在train.py中:
在这里插入图片描述
或者detect.py中:
在这里插入图片描述

12、check_imshow

这个函数是检查一下前环境是否可以使用opencv.imshow显示图片。

def check_imshow():"""用在detect.py中 使用webcam的时候调用检查当前环境是否可以使用opencv.imshow显示图片主要有两点限制: Docker环境 + Google Colab环境"""# Check if environment supports image displaystry:# 检查当前环境是否是一个Docker环境 cv2.imshow()不能再docker环境中使用assert not is_docker(), 'cv2.imshow() is disabled in Docker environments'# 检查当前环境是否是一个Google Colab环境 cv2.imshow()不能在Google Colab环境中使用assert not is_colab(), 'cv2.imshow() is disabled in Google Colab environments'# 初始化一张图片检查下opencv是否可用cv2.imshow('test', np.zeros((1, 1, 3)))cv2.waitKey(1)cv2.destroyAllWindows()cv2.waitKey(1)return Trueexcept Exception as e:print(f'WARNING: Environment does not support cv2.imshow() or PIL Image.show() image displays\n{
      e}')return False

会在detect.py中使用webcam的时候调用:
在这里插入图片描述

13、check_file

这个函数是检查本都相关文件路径能否找到这个文件,没找到就说明文件丢失了,返回空;如果传入的是一个网络地址就直接下载这个文件;否则找到就返回本地匹配到的第一个文件名。这个函数很有用,用的很广。

check_file函数代码:

def check_file(file):"""用在train.py和test.py文件中 检查本地有没有这个文件检查相关文件路径能否找到文件 并返回文件名Search/download file (if necessary) and return path"""file = str(file)  # convert to str()# 如果传进来的是文件或者是’‘, 直接返回文件名strif Path(file).is_file() or file == '':  # existsreturn file# 如果传进来的以 'http:/' 或者 'https:/' 开头的url地址, 就下载elif file.startswith(('http:/', 'https:/')):  # downloadurl = str(Path(file)).replace(':/', '://')  # Pathlib turns :// -> :/# urllib.parse: 解析url .unquote: 对url进行解码 file: 要下载的文件名# '%2F' to '/', split https://url.com/file.txt?authfile = Path(urllib.parse.unquote(file)).name.split('?')[0]print(f'Downloading {
      url} to {
      file}...')# 使用torch.hub.download_url_to_file从url地址上中下载文件名为file的文件torch.hub.download_url_to_file(url, file)# 检查是否下载成功assert Path(file).exists() and Path(file).stat().st_size > 0, f'File download failed: {
      url}'  # check# 返回下载的文件名return fileelse:# 否则, 传进来的就是当前项目下的一个全局路径 查找匹配的文件名 返回第一个# glob.glob: 匹配当前项目下的所有项目 返回所有符合条件的文件filesfiles = glob.glob('./**/' + file, recursive=True)  # find fileassert len(files), f'File not found: {
      file}'  # assert file was foundassert len(files) == 1, f"Multiple files match '{
      file}', specify exact path: {
      files}"  # assert unique# 返回第一个匹配到的文件名return files[0]  # return file

在train.py中使用(检查本地data、cfg、hyp等文件是否存在):
在这里插入图片描述
在test.py中使用(检查本地data文件是否存在):
在这里插入图片描述

14、check_dataset

这个函数是检查本地是否有指定的数据集,没用就从torch库中下载并解压数据集。

check_dataset函数代码:

def check_dataset(data, autodownload=True):"""用在train.py和detect.py中 检查本地有没有数据集检查数据集 如果本地没有则从torch库中下载并解压数据集:params data: 是一个解析过的data_dict len=7例如: ['path'='../datasets/coco128', 'train','val', 'test', 'nc', 'names', 'download']:params autodownload: 如果本地没有数据集是否需要直接从torch库中下载数据集 默认True"""# path: WindowPath '..\datasets\coco128'path = Path(data.get('path', ''))  # optional 'path' field# 如果path不为空 就更新(扩展)train、val和test的路径# train: data['train'] -> path / data['train']# 'images/train2017' -> '..\\datasets\\coco128\\images\\train2017'# val: data['val'] -> path / data['val']# 'images/train2017' -> '..\\datasets\\coco128\\images\\train2017'if path:for k in 'train', 'val', 'test':  #if data.get(k):  # prepend pathdata[k] = str(path / data[k]) if isinstance(data[k], str) else [str(path / x) for x in data[k]]# train: 训练路径 '..\\datasets\\coco128\\images\\train2017'# val: 验证路径 '..\\datasets\\coco128\\images\\train2017'# test: 测试路径 None# s: 下载地址 'https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip'train, val, test, s = [data.get(x) for x in ('train', 'val', 'test', 'download')]if val:# path.resolve() 该方法将一些的 路径/路径段 解析为绝对路径# val: [WindowsPath('E:/yolo_v5/datasets/coco128/images/train2017')]val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val path# 如果val不存在 说明本地不存在数据集if not all(x.exists() for x in val):print('\nWARNING: Dataset not found, nonexistent paths: %s' % [str(x) for x in val if not x.exists()])# 如果下载地址s和下载标记(flag)autodownload不为空, 就直接下载if s and autodownload:  # download script# 如果下载地址s是http开头就从url中下载数据集if s.startswith('http') and s.endswith('.zip'):# f: 得到下载文件的文件名 filenamef = Path(s).nameprint(f'Downloading {
      s} ...')# 开始下载 利用torch.hub.download_url_to_file函数从s路径中下载文件名为f的文件torch.hub.download_url_to_file(s, f)root = path.parent if 'path' in data else '..'  # unzip directory i.e. '../'Path(root).mkdir(parents=True, exist_ok=True)  # create root# 执行解压命名 将文件f解压到root地址 解压后文件名为fr = os.system(f'unzip -q {
      f} -d {
      root} && rm {
      f}')  # unzip# 如果下载地址s是bash开头就使用bash指令下载数据集elif s.startswith('bash '):  # bash scriptprint(f'Running {
      s} ...')# 使用bash命令下载r = os.system(s)# 否则下载地址就是一个python脚本 执行python脚本下载数据集else:  # python scriptr = exec(s, {
    'yaml': data})  # return Noneprint('Dataset autodownload %s\n' % ('success' if r in (0, None) else 'failure'))  # print resultelse:# 下载地址为空 或者不需要下载 标记(flag)autodownloadraise Exception('Dataset not found.')

使用在train.py中:
在这里插入图片描述

15、download

这个函数是将url中的文件下载下来,再解压。但是这个文件并没有在程序中被调用,一般要下载东西都是调用torch.hub.download_url_to_file系统函数和google_utils.py中的attempt_download函数进行下载文件。所以,这个函数随便看看就好。

def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1):"""在voc.yaml中下载数据集Multi-threaded file download and unzip function:params url: 下载文件的url地址:params dir: 下载下来文件保存的目录:params unzip: 下载后文件是否需要解压:params delete: 解压后原文件(未解压)是否需要删除:params curl: 是否使用cmd curl语句下载文件 False就使用torch.hub下载:params threads: 下载一个文件需要的线程数"""def download_one(url, dir):"""Download 1 file:params url: 文件下载地址 Path(url).name=文件名:params dir: 文件保存的目录"""f = dir / Path(url).name  # filename# 这个目录下不存在这个文件 就直接下载if not f.exists():print(f'Downloading {
      url} to {
      f}...')if curl:  # 使用cmd命令curl下载os.system(f"curl -L '{
      url}' -o '{
      f}' --retry 9 -C -")  # curl download, retry and resume on failelse:     # 使用torch.hub下载torch.hub.download_url_to_file(url, f, progress=True)  # torch download# 如果需要解压 且下载的文件后缀是 '.zip' 或 '.gz'if unzip and f.suffix in ('.zip', '.gz'):print(f'Unzipping {
      f}...')if f.suffix == '.zip':s = f'unzip -qo {
      f} -d {
      dir}'  # unzip -quiet -overwriteelif f.suffix == '.gz':s = f'tar xfz {
      f} --directory {
      f.parent}'  # unzip# 解压后是否需要删除未解压的文件if delete:  # delete zip file after unzips += f' && rm {
      f}'os.system(s)  # 调用cmd执行s命令dir = Path(dir)dir.mkdir(parents=True, exist_ok=True)  # make directoryif threads > 1:  # 使用线程池# 定义了一个线程池, 最多创建threads个线程pool = ThreadPool(threads)# 进程池中的该方法会将 iterable 参数传入的可迭代对象分成 chunksize 份传递给不同的进程来处理。pool.imap(lambda x: download_one(*x), zip(url, repeat(dir)))  # multi-threadedpool.close()pool.join()else:for u in tuple(url) if isinstance(url, str) else url:download_one(u, dir)

16、clean_str

这个函数是将字符串中一些奇怪的符号 “|@#!?·$?%&()=??^*;:,¨?><+” 换成下划线 ‘_’。

clean_str函数代码:

def clean_str(s):"""在datasets.py中的LoadStreams类中被调用字符串s里在pattern中字符替换为下划线_ 注意pattern中[]不能省Cleans a string by replacing special characters with underscore _"""# re: 用来匹配字符串(动态、模糊)的模块 正则表达式模块# pattern: 表示正则中的模式字符串 repl: 就是replacement的字符串 string: 要被处理, 要被替换的那个string字符串# 所以这句话执行的是将字符串s里在pattern中的字符串替换为 "_"return re.sub(pattern="[|@#!?·$?%&()=??^*;:,¨?><+]", repl="_", string=s)

只用在datasets.py中的LoadStreams类中:

在这里插入图片描述

17、one_cycle

这个函数是一种特殊的学习率衰减策略。来自这篇论文: one_cycle. 感兴趣的朋友可以读一读。

def one_cycle(y1=0.0, y2=1.0, steps=100):"""用在train.py的学习率衰减策略模块one_cycle lr lr先增加, 再减少, 再以更小的斜率减少论文: https://arxiv.org/pdf/1803.09820.pdf"""# lambda function for sinusoidal ramp from y1 to y2return lambda x: ((1 - math.cos(x * math.pi / steps)) / 2) * (y2 - y1) + y1

在train.py的学习率衰减策略模块中使用:
在这里插入图片描述
如果是linear_lr的lr变化效果:
在这里插入图片描述

如果是one_cycle的lr变化效果:
在这里插入图片描述
一般使用one_cycle的效果会比较好。

18、labels_to_class_weights、labels_to_image_weights

这两个函数是联合使用的。最终的目的是为了在数据集中采样的时候,不使用随机采样,而是使用更加科学的按图片权重进行采样。第一个函数labels_to_class_weights是为了得到数据集中所有类别的权重(频率大的权重小)。第二个函数labels_to_image_weights是利用labels_to_class_weights函数得到的类别权重得到每张图片对应的一个权重。然后利用每张图片的权重在当前batch进行采样,这样的采样方式会更加科学点。

两个函数都只在train.py中使用,且是同时使用的如图:

在这里插入图片描述
在这里插入图片描述

18.1、labels_to_class_weights

这个函数是从训练(gt)标签获得每个类的权重 ,标签频率高的类权重低。

labels_to_class_weights函数代码:

def labels_to_class_weights(labels, nc=80):"""用在train.py中 得到每个类别的权重 标签频率高的类权重低从训练(gt)标签获得每个类的权重 标签频率高的类权重低Get class weights (inverse frequency) from training labels:params labels: gt框的所有真实标签labels:params nc: 数据集的类别数:return torch.from_numpy(weights): 每一个类别根据labels得到的占比(次数越多权重越小) tensor"""if labels[0] is None:  # no labels loadedreturn torch.Tensor()labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO# classes: 所有标签对应的类别labels labels[:, 0]: 类别 .astype(np.int): 取整classes = labels[:, 0].astype(np.int)  # labels = [labels_num, class+xywh]# weight: 返回每个类别出现的次数 [1, nc]weights = np.bincount(classes, minlength=nc)  # occurrences per class# Prepend gridpoint count (for uCE training)# gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum() # gridpoints per image# weights = np.hstack([gpi * len(labels) - weights.sum() * 9, weights * 9]) ** 0.5 # prepend gridpoints to start# 将出现次数为0的类别权重全部取1 replace empty bins with 1weights[weights == 0] = 1# 其他所有的类别的权重全部取次数的倒数 number of targets per classweights = 1 / weights# normalize 求出每一类别的占比weights /= weights.sum()return torch.from_numpy(weights)  # numpy -> tensor

18.2、labels_to_image_weights

这个函数是利用每张图片真实gt框的真实标签labels和上一步labels_to_class_weights得到的每个类别的权重得到数据集中每张图片对应的权重。

labels_to_image_weights函数代码:

def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):"""用在train.py中 利用上面得到的每个类别的权重得到每一张图片的权重 再对图片进行按权重进行采样通过每张图片真实gt框的真实标签labels和上一步labels_to_class_weights得到的每个类别的权重进行采样Produces image weights based on class_weights and image contents:params labels: 每张图片真实gt框的真实标签:params nc: 数据集的类别数 默认80:params class_weights: [80] 上一步labels_to_class_weights得到的每个类别的权重"""# class_counts: 每个类别出现的次数 [num_labels, nc] 每一行是当前这张图片每个类别出现的次数 num_labels=图片数量=label数量class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])# [80] -> [1, 80]# 整个数据集的每个类别权重[1, 80] * 每张图片的每个类别出现的次数[num_labels, 80] = 得到每一张图片每个类对应的权重[128, 80]# 另外注意: 这里不是矩阵相乘, 是元素相乘 [1, 80] 和每一行图片的每个类别出现的次数 [1, 80] 分别按元素相乘# 再sum(1): 按行相加 得到最终image_weights: 得到每一张图片对应的采样权重[128]image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)# index = random.choices(range(n), weights=image_weights, k=1) # weight image samplereturn image_weights

19、coco80_to_coco91_class

这个函数是将80个类的coco索引换成91类的coco索引。

coco80_to_coco91_class函数代码:

def coco80_to_coco91_class():"""用在test.py中 从80类映射到91类的coco索引 取得对应的class id将80个类的coco索引换成91类的coco索引:return x: 为80类的每一类在91类中的位置"""# converts 80-index (val2014) to 91-index (paper)# https://tech.amikelive.com/node-718/what-object-categories-labels-are-in-coco-dataset/# a = np.loadtxt('data/coco.names', dtype='str', delimiter='\n')# b = np.loadtxt('data/coco_paper.names', dtype='str', delimiter='\n')# x1 = [list(a[i] == b).index(True) + 1 for i in range(80)] # darknet to coco# x2 = [list(b[i] == a).index(True) if any(b[i] == a) else None for i in range(91)] # coco to darknetx = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34,35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]return x

在test.py中定义:
在这里插入图片描述
在test.py中调用(从80类映射到91类的coco索引 取得对应的class id):
在这里插入图片描述

注意: 从这里开始的几个函数都会涉及一些坐标轴xy的操作,记住 x的正坐标是向右,y的正坐标是向下。

20、clip_coords

这个函数的作用是:将boxes的坐标(x1y1x2y2 左上角右下角)限定在图像的尺寸(img_shape hw)内,防止出界。这个函数会用在下面的xyxy2xywhn、save_one_boxd等函数中,很重要,必须掌握。

clip_coords函数代码:

def clip_coords(boxes, img_shape):"""用在下面的xyxy2xywhn、save_one_boxd等函数中将boxes的坐标(x1y1x2y2 左上角右下角)限定在图像的尺寸(img_shape hw)内Clip bounding x1y1x2y2 bounding boxes to image shape (height, width)"""if isinstance(boxes, torch.Tensor):# .clamp_(min, max): 将取整限定在(min, max)之间, 超出这个范围自动划到边界上boxes[:, 0].clamp_(0, img_shape[1])  # x1boxes[:, 1].clamp_(0, img_shape[0])  # y1boxes[:, 2].clamp_(0, img_shape[1])  # x2boxes[:, 3].clamp_(0, img_shape[0])  # y2else:  # np.arrayboxes[:, 0].clip(0, img_shape[1], out=boxes[:, 0])  # x1boxes[:, 1].clip(0, img_shape[0], out=boxes[:, 1])  # y1boxes[:, 2].clip(0, img_shape[1], out=boxes[:, 2])  # x2boxes[:, 3].clip(0, img_shape[0], out=boxes[:, 3])  # y2

21、scale_coords

这个函数是将坐标coords(x1y1x2y2)从img1_shape尺寸缩放到img0_shape尺寸。x的正坐标是向右,y的正坐标是向下。这个函数也是很重要的。

scale_coords函数代码:

def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):"""用在detect.py和test.py中 将预测坐标从feature map映射回原图将坐标coords(x1y1x2y2)从img1_shape缩放到img0_shape尺寸Rescale coords (xyxy) from img1_shape to img0_shape:params img1_shape: coords相对于的shape大小:params coords: 要进行缩放的box坐标信息 x1y1x2y2 左上角 + 右下角:params img0_shape: 要将coords缩放到相对的目标shape大小:params ratio_pad: 缩放比例gain和pad值 None就先计算gain和pad值再pad+scale 不为空就直接pad+scale"""# ratio_pad为空就先算放缩比例gain和pad值 calculate from img0_shapeif ratio_pad is None:# gain = old / new 取高宽缩放比例中较小的,之后还可以再pad 如果直接取大的, 裁剪就可能减去目标gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])# wh padding wh中有一个为0 主要是pad另一个pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2else:gain = ratio_pad[0][0]  # 指定比例pad = ratio_pad[1]  # 指定pad值# 因为pad = img1_shape - img0_shape 所以要把尺寸从img1 -> img0 就同样也需要减去pad# 如果img1_shape>img0_shape pad>0 coords从大尺寸缩放到小尺寸 减去pad 符合# 如果img1_shape<img0_shape pad<0 coords从小尺寸缩放到大尺寸 减去pad 符合coords[:, [0, 2]] -= pad[0]  # x paddingcoords[:, [1, 3]] -= pad[1]  # y padding# 缩放scalecoords[:, :4] /= gain# 防止放缩后的坐标过界 边界处直接剪切clip_coords(coords, img0_shape)return coords

用在detect.py中将预测坐标映射回原图:
在这里插入图片描述
用在test.py中也是将预测坐标映射回原图:
在这里插入图片描述

22、xyxy2xywh、xywh2xyxy

这两个函数是两个相反的过程。xyxy2xywh是将预测信息xyxy格式转化为xywh的格式,而xywh2xyxy是将预测信息xywh格式转化为xyxy的格式。这两个函数的代码很重要,一定要掌握。代码还是那句话:x的正坐标是向右,y的正坐标是向下。

22.1、xyxy2xywh

这个函数是将预测信息xyxy格式转化为xywh的格式。

xyxy2xywh函数代码:

def xyxy2xywh(x):""""用在detect.py和test.py中 操作最后, 将预测信息从xyxy格式转为xywh格式 再保存Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where x1y1=top-left, x2y2=bottom-right:params x: [n, x1y1x2y2] (x1, y1): 左上角 (x2, y2): 右下角:return y: [n, xywh] (x, y): 中心点 wh: 宽高"""y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)y[:, 0] = (x[:, 0] + x[:, 2]) / 2  # x centery[:, 1] = (x[:, 1] + x[:, 3]) / 2  # y centery[:, 2] = x[:, 2] - x[:, 0]  # widthy[:, 3] = x[:, 3] - x[:, 1]  # heightreturn y

在detect.py中使用:
在这里插入图片描述

22.2、xywh2xyxy

这个函数是将预测信息xywh格式转化为xyxy的格式。

xywh2xyxy函数代码:

def xywh2xyxy(x):"""用在test.py中 操作之前 转为xyxy才可以进行操作注意: x的正方向为右面 y的正方向为下面Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where x1y1=top-left, x2y2=bottom-right:params x: [n, xywh] (x, y)::return y: [n, x1y1x2y2] (x1, y1): 左上角 (x2, y2): 右下角"""y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left xy[:, 1] = x[:, 1] - x[:, 3] / 2  # top left yy[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right xy[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right yreturn y

在test.py中(转换后才能操作):
在这里插入图片描述

23、xywhn2xyxy、xyxy2xywhn、xyn2xy

这三个函数主要用于datasets.py文件中。主要是对图像进行一些变换操作。xywhn2xyxy是将xywh(normalized) -> x1y1x2y2。xyxy2xywhn是将x1y1x2y2 -> xywh(normalized)。xyn2xy是将xy(normalized) -> xy。这三个函数也是比较重要的,大家必须掌握。

23.1、xywhn2xyxy

这个函数是xywh(normalized) -> x1y1x2y2。

xywhn2xyxy函数代码:

def xywhn2xyxy(x, w=640, h=640, padw=0, padh=0):"""用在datasets.py的 LoadImagesAndLabels类的__getitem__函数、load_mosaic、load_mosaic9等函数中 将xywh(normalized) -> x1y1x2y2 (x, y): 中间点 wh: 宽高 (x1, y1): 左上点 (x2, y2): 右下点Convert nx4 boxes from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right"""y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + padw  # top left xy[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + padh  # top left yy[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + padw  # bottom right xy[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + padh  # bottom right yreturn y

在__getitem__函数中被调用:
在这里插入图片描述

23.2、xyxy2xywhn

这个函数是将x1y1x2y2 -> xywh(normalized)。

xyxy2xywhn函数代码:

def xyxy2xywhn(x, w=640, h=640, clip=False):"""用在datasets.py的 LoadImagesAndLabels类的__getitem__函数中将 x1y1x2y2 -> xywh(normalized) (x1, y1): 左上点 (x2, y2): 右下点 (x, y): 中间点 wh: 宽高Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right"""if clip:# 是否需要将x的坐标(x1y1x2y2)限定在尺寸(h, w)内clip_coords(x, (h, w))  # warning: inplace clipy = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x centery[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y centery[:, 2] = (x[:, 2] - x[:, 0]) / w  # widthy[:, 3] = (x[:, 3] - x[:, 1]) / h  # heightreturn y

用在datasets.py的 LoadImagesAndLabels类的__getitem__函数中:
在这里插入图片描述

23.3、xyn2xy

这个函数是将xy(normalized) -> xy。

xyn2xy函数代码:

def xyn2xy(x, w=640, h=640, padw=0, padh=0):"""用在datasets.py的load_mosaic和load_mosaic9函数中xy(normalized) -> xyConvert normalized segments into pixel segments, shape (n,2)"""y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)y[:, 0] = w * x[:, 0] + padw  # top left xy[:, 1] = h * x[:, 1] + padh  # top left yreturn y

在datasets.py的load_mosaic和load_mosaic9函数中使用:
在这里插入图片描述

24、non_max_suppression

NMS(非极大值抑制),这个函数相信大家都已经很熟悉了,这是目标检测最基本的操作之一了。可以说这个函数是这篇博客当中最重要的代码也不为过,所以大家一定要掌握这个函数(流程原理+代码)。

还写过一篇nms更详细:nms

non_max_suppression函数代码:

def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None,agnostic=False, multi_label=True, labels=(), max_det=300, merge=False):"""Runs Non-Maximum Suppression (NMS) on inference resultsParams:prediction: [batch, num_anchors(3个yolo预测层), (x+y+w+h+1+num_classes)] = [1, 18900, 25] 3个anchor的预测结果总和conf_thres: 先进行一轮筛选,将分数过低的预测框(<conf_thres)删除(分数置0)iou_thres: iou阈值, 如果其余预测框与target的iou>iou_thres, 就将那个预测框置0classes: 是否nms后只保留特定的类别 默认为Noneagnostic: 进行nms是否也去除不同类别之间的框 默认Falsemulti_label: 是否是多标签 nc>1 一般是Truelabels:max_det: 每张图片的最大目标个数 默认1000merge: use merge-NMS 多个bounding box给它们一个权重进行融合 默认FalseReturns:[num_obj, x1y1x2y2+object_conf+cls] = [5, 6]"""# Checks 检查传入的conf_thres和iou_thres两个阈值是否符合范围assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {
      conf_thres}, valid values are between 0.0 and 1.0'assert 0 <= iou_thres <= 1, f'Invalid IoU {
      iou_thres}, valid values are between 0.0 and 1.0'# Settings 设置一些变量nc = prediction.shape[2] - 5  # number of classesmin_wh, max_wh = 2, 4096  # (pixels) 预测物体宽度和高度的大小范围 [min_wh, max_wh]max_nms = 30000  # 每个图像最多检测物体的个数 maximum number of boxes into torchvision.ops.nms()time_limit = 10.0  # nms执行时间阈值 超过这个时间就退出了 seconds to quit afterredundant = True  # 是否需要冗余的detections require redundant detectionsmulti_label &= nc > 1  # multiple labels per box (adds 0.5ms/img)# batch_size个output 存放最终筛选后的预测框结果output = [torch.zeros((0, 6), device=prediction.device)] * prediction.shape[0]# 定义第二层过滤条件xc = prediction[..., 4] > conf_thres  # candidatest = time.time()  # 记录当前时刻时间for xi, x in enumerate(prediction):  # image index, image inference# Apply constraints# 第一层过滤 虑除超小anchor标和超大anchor x=[18900, 25]x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0  # width-height# 第二层过滤 根据conf_thres虑除背景目标(obj_conf<conf_thres 0.1的目标 置信度极低的目标) x=[59, 25]x = x[xc[xi]]  # confidence# {list: bs} 第一张图片的target[17, 5] 第二张[1, 5] 第三张[7, 5] 第四张[6, 5]# Cat apriori labels if autolabelling 自动标注label时调用 一般不用# 自动标记在非常高的置信阈值(即 0.90 置信度)下效果最佳,而 mAP 计算依赖于非常低的置信阈值(即 0.001)来正确评估 PR 曲线下的区域。# 这个自动标注我觉得应该是一个类似RNN里面的Teacher Forcing的训练机制 就是在训练的时候跟着老师(ground truth)走# 但是这样又会造成一个问题: 一直靠老师带的孩子是走不远的 这样的模型因为依赖标签数据,在训练过程中,模型会有较好的效果# 但是在测试的时候因为不能得到ground truth的支持, 所以如果目前生成的序列在训练过程中有很大不同, 模型就会变得脆弱。# 所以个人认为(个人观点): 应该在下面使用的时候有选择的开启这个trick 比如设置一个概率p随机开启 或者在训练的前n个epoch使用 后面再关闭if labels and len(labels[xi]):l = labels[xi]v = torch.zeros((len(l), nc + 5), device=x.device)v[:, :4] = l[:, 1:5]  # boxv[:, 4] = 1.0  # confv[range(len(l)), l[:, 0].long() + 5] = 1.0  # clsx = torch.cat((x, v), 0)# 经过前两层过滤后如果该feature map没有目标框了,就结束这轮直接进行下一张图if not x.shape[0]:continue# 计算conf_scorex[:, 5:] *= x[:, 4:5]  # conf = obj_conf * cls_conf# Box (center x, center y, width, height) to (x1, y1, x2, y2) 左上角 右下角 [59, 4]box = xywh2xyxy(x[:, :4])# Detections matrix nx6 (xyxy, conf, cls)if multi_label:# 第三轮过滤:针对每个类别score(obj_conf * cls_conf) > conf_thres [59, 6] -> [51, 6]# 这里一个框是有可能有多个物体的,所以要筛选# nonzero: 获得矩阵中的非0(True)数据的下标 a.t(): 将a矩阵拆开# i: 下标 [43] j: 类别index [43] 过滤了两个score太低的i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T# pred = [43, xyxy+score+class] [43, 6]# unsqueeze(1): [43] => [43, 1] add batch dimension# box[i]: [43,4] xyxy# pred[i, j + 5].unsqueeze(1): [43,1] score 对每个i,取第(j+5)个位置的值(第j个class的值cla_conf)# j.float().unsqueeze(1): [43,1] classx = torch.cat((box[i], x[i, j + 5, None], j[:, None].float()), 1)else:  # best class onlyconf, j = x[:, 5:].max(1, keepdim=True)    # 一个类别直接取分数最大类的即可x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > conf_thres]# Filter by class 是否只保留特定的类别 默认None 不执行这里if classes is not None:x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]# 检测数据是否为有限数 Apply finite constraint 这轮可有可无,一般没什么用 所以这里给他注释了# if not torch.isfinite(x).all():# x = x[torch.isfinite(x).all(1)]# Check shapen = x.shape[0]  # number of boxesif not n:  # 如果经过第三轮过滤该feature map没有目标框了,就结束这轮直接进行下一张图continueelif n > max_nms:  # 如果经过第三轮过滤该feature map还要很多框(>max_nms) 就需要排序x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence# 第4轮过滤 Batched NMS [51, 6] -> [5, 6]c = x[:, 5:6] * (0 if agnostic else max_wh)  # classes# 做个切片 得到boxes和scores 不同类别的box位置信息加上一个很大的数但又不同的数c# 这样作非极大抑制的时候不同类别的框就不会掺和到一块了 这是一个作nms挺巧妙的技巧boxes, scores = x[:, :4] + c, x[:, 4]  # boxes (offset by class), scores# 返回nms过滤后的bounding box(boxes)的索引(降序排列)# i=tensor([18, 19, 32, 25, 27]) nms后只剩下5个预测框了i = torchvision.ops.nms(boxes, scores, iou_thres)  # NMSif i.shape[0] > max_det:  # limit detectionsi = i[:max_det]if merge and (1 < n < 3E3):  # Merge NMS (boxes merged using weighted mean)# update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)iou = box_iou(boxes[i], boxes) > iou_thres  # iou matrixweights = iou * scores[None]  # box weights 正比于 iou * scores# bounding box合并 其实就是把权重和框相乘再除以权重之和x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True)  # merged boxesif redundant:i = i[iou.sum(1) > 1]  # require redundancyoutput[xi] = x[i]   # 最终输出 [5, 6]# 看下时间超没超时 超时没做完的就不做了if (time.time() - t) > time_limit:print(f'WARNING: NMS time limit {
      time_limit}s exceeded')break  # time limit exceededreturn output

这个函数一般会用再detect.py或者test.py的模型前向推理结束之后:
在这里插入图片描述
如果还不是很懂这个函数的代码,可以参考下我写的另一篇将yolov3博客,里面也很详细的解释了NMS函数流程和代码:【YOLO-V3-SPP 源码解读】三、预测模块.

25、strip_optimizer

这个函数是在模型训练完后, strip_optimizer函数将optimizer、training_results、updates…从保存的模型文件ckpt中删除。

strip_optimizer函数代码:

def strip_optimizer(f='best.pt', s=''):"""用在train.py模型训练完后 将optimizer、training_results、updates...从保存的模型文件f中删除Strip optimizer from 'f' to finalize training, optionally save as 's':params f: 传入的原始保存的模型文件:params s: 删除optimizer等变量后的模型保存的地址 dir"""# x: 为加载训练的模型x = torch.load(f, map_location=torch.device('cpu'))# 如果模型是ema replace model with emaif x.get('ema'):x['model'] = x['ema']# 以下模型训练涉及到的若干个指定变量置空for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keysx[k] = Nonex['epoch'] = -1  # 模型epoch恢复初始值-1x['model'].half()  # to FP16for p in x['model'].parameters():p.requires_grad = False# 保存模型 x -> s/ftorch.save(x, s or f)mb = os.path.getsize(s or f) / 1E6  # filesizeprint(f"Optimizer stripped from {
      f},{
      (' saved as %s,' % s) if s else ''} {
      mb:.1f}MB")

用在train.py模型训练完后:
在这里插入图片描述

26、print_mutation

这个函数用来打印进化后的超参结果和results到evolve.txt和hyp_evolved.yaml中。

print_mutation函数代码:

def print_mutation(hyp, results, yaml_file='hyp_evolved.yaml', bucket=''):"""用在train.py的进化超参结束后 打印进化后的超参结果和results到evolve.txt和hyp_evolved.yaml中Print mutation results to evolve.txt (for use with train.py --evolve):params hyp: 进化后的超参 dict {28对 key:value}:params results: tuple(7) (mp, mr, map50, map50:95, box_loss, obj_loss, cls_loss):params yaml_file: 要保存的进化后的超参文件名 runs\train\evolve\hyp_evolved.yaml:params bucket: ''"""# 定义相关变量 并赋值 按指定格式输出a = '%10s' * len(hyp) % tuple(hyp.keys())  # str 得到所有超参的key hyperparam keysb = '%10.3g' * len(hyp) % tuple(hyp.values())  # str 得到所有超参的value hyperparam valuesc = '%10.4g' * len(results) % results  # c = results str (P, R, mAP@0.5, mAP@0.5:0.95, box_loss, obj_loss, cls_loss)print('\n%s\n%s\nEvolved fitness: %s\n' % (a, b, c))if bucket:url = 'gs://%s/evolve.txt' % bucketif gsutil_getsize(url) > (os.path.getsize('evolve.txt') if os.path.exists('evolve.txt') else 0):os.system('gsutil cp %s .' % url)  # download evolve.txt if larger than local# 将结果c(results)和b(得到所有超参的value)写入evolve.txt中with open('evolve.txt', 'a') as f:  # append resultf.write(c + b + '\n')x = np.unique(np.loadtxt('evolve.txt', ndmin=2), axis=0)  # load unique rowsx = x[np.argsort(-fitness(x))]  # sortnp.savetxt('evolve.txt', x, '%10.3g')  # save sort by fitness# Save yaml 保存yaml配置文件 为'hyp_evolved.yaml'for i, k in enumerate(hyp.keys()):  # 将hyp保存到数组hyp[]中hyp[k] = float(x[0, i + 7])with open(yaml_file, 'w') as f:  # 将hyp写入yaml_fileresults = tuple(x[0, :7])c = '%10.4g' * len(results) % results  # results (P, R, mAP@0.5, mAP@0.5:0.95, val_losses x 3)f.write('# Hyperparameter Evolution Results\n# Generations: %g\n# Metrics: ' % len(x) + c + '\n\n')yaml.safe_dump(hyp, f, sort_keys=False)if bucket:  # 如果需要存到谷歌云盘, 就上传 默认是不需要的os.system('gsutil cp evolve.txt %s gs://%s' % (yaml_file, bucket))  # upload

用在train.py的进化超参结束后 :
在这里插入图片描述

27、apply_classifier

这个函数定义了一个二级分类器来处理yolo的输出,可以将它用在detect.py中。这里写的这个函数只是一个普通的实现,你也可以根据自己的任务改写这个函数。不过这个函数我们几乎不会用它,因为它很容易出错。我们这里就不仔细介绍了,真的很难用到这个函数,随便看下就好。

函数代码:

def apply_classifier(x, model, img, im0):"""用在detect.py文件的nms后继续对feature map送入model2 进行二次分类定义了一个二级分类器来处理yolo的输出 当前实现本质上是一个参考起点,您可以使用它自行实现此项比如你有照片与汽车与车牌, 你第一次剪切车牌, 并将其发送到第二阶段分类器, 以检测其中的字符Apply a second stage classifier to yolo outputshttps://github.com/ultralytics/yolov5/issues/2700 这个函数使用起来很容易出错 不是很推荐使用https://github.com/ultralytics/yolov5/issues/1472:params x: yolo层的输出:params model: 分类模型:params img: 进行resize + pad之后的图片:params im0: 原尺寸的图片"""im0 = [im0] if isinstance(im0, np.ndarray) else im0for i, d in enumerate(x):  # per imageif d is not None and len(d):d = d.clone()# Reshape and pad cutoutsb = xyxy2xywh(d[:, :4])  # boxes xyxy -> xywhb[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # rectangle to squareb[:, 2:] = b[:, 2:] * 1.3 + 30  # padd[:, :4] = xywh2xyxy(b).long()  # xywh -> xyxy# Rescale boxes from img_size to im0 sizescale_coords(img.shape[2:], d[:, :4], im0[i].shape)# Classespred_cls1 = d[:, 5].long()  # 在之前的yolo模型预测的类别ims = []for j, a in enumerate(d):  # per itemcutout = im0[i][int(a[1]):int(a[3]), int(a[0]):int(a[2])]im = cv2.resize(cutout, (224, 224))  # BGR# cv2.imwrite('test%i.jpg' % j, cutout)im = im[:, :, ::-1].transpose(2, 0, 1)  # BGR to RGB, to 3x416x416im = np.ascontiguousarray(im, dtype=np.float32)  # uint8 to float32im /= 255.0  # 0 - 255 to 0.0 - 1.0ims.append(im)# 用model模型进行分类预测pred_cls2 = model(torch.Tensor(ims).to(d.device)).argmax(1)  # classifier prediction# 保留预测一致的结果x[i] = x[i][pred_cls1 == pred_cls2]  # retain matching class detectionsreturn x

28、increment_path

用于递增路径。比如我输入路径是run/train/exp,但是发现文件夹里面已经有这个文件了,那么就将文件路径扩展围为:runs/train/exp{sep}0, runs/exp{sep}1 etc。

increment_path函数代码:

def increment_path(path, exist_ok=False, sep='', mkdir=False):"""这是个用处特别广泛的函数 train.py、detect.py、test.py等都会用到递增路径 如 run/train/exp --> runs/train/exp{sep}0, runs/exp{sep}1 etc.:params path: window path run/train/exp:params exist_ok: False:params sep: exp文件名的后缀 默认'':params mkdir: 是否在这里创建dir False"""path = Path(path)  # string/win路径 -> win路径# 如果该文件夹已经存在 则将路径run/train/exp修改为 runs/train/exp1if path.exists() and not exist_ok:# path.suffix 得到路径path的后缀 ''suffix = path.suffix# .with_suffix 将路径添加一个后缀 ''path = path.with_suffix('')# 模糊搜索和path\sep相似的路径, 存在一个list列表中 如['runs\\train\\exp', 'runs\\train\\exp1']# f开头表示在字符串内支持大括号内的python表达式dirs = glob.glob(f"{
      path}{
      sep}*")# r的作用是去除转义字符 path.stem: 没有后缀的文件名 exp# re 模糊查询模块 re.search: 查找dir中有字符串'exp/数字'的d \d匹配数字# matches [None, <re.Match object; span=(11, 15), match='exp1'>] 可以看到返回span(匹配的位置) match(匹配的对象)matches = [re.search(rf"%s{
      sep}(\d+)" % path.stem, d) for d in dirs]# i = [1]i = [int(m.groups()[0]) for m in matches if m]  # indices# 生成需要生成文件的exp后面的数字 n = max(i) + 1 = 2n = max(i) + 1 if i else 1  # increment number# 返回path runs/train/exp2path = Path(f"{
      path}{
      sep}{
      n}{
      suffix}")  # update path# path.suffix文件后缀 path.parent 路径的上级目录 runs/train/exp2dir = path if path.suffix == '' else path.parent  # directoryif not dir.exists() and mkdir:  # mkdir 默认False 先不创建dirdir.mkdir(parents=True, exist_ok=True)  # make directoryreturn path  # 返回runs/train/exp2

train.py中使用:
在这里插入图片描述detect.py中使用:
在这里插入图片描述

29、save_one_box

这个函数是用来将预测到的目标从原图中扣出来 剪切好 并保存 会在runs/detect/expn下生成crops文件,将剪切的图片保存在里面。这个函数回调用xyxy2xywh、xywh2xyxy、clip_coords、increment_path等函数。

save_one_box函数代码:

def save_one_box(xyxy, im, file='image.jpg', gain=1.02, pad=10, square=False, BGR=False, save=True):"""用在detect.py文件中 由opt的save-crop参数控制执不执行将预测到的目标从原图中扣出来 剪切好 并保存 会在runs/detect/expn下生成crops文件,将剪切的图片保存在里面Save image crop as {file} with crop size multiple {gain} and {pad} pixels. Save and/or return crop:params xyxy: 预测到的目标框信息 list 4个tensor x1 y1 x2 y2 左上角 + 右下角:params im: 原图片 需要裁剪的框从这个原图上裁剪 nparray (1080, 810, 3):params file: runs\detect\exp\crops\dog\bus.jpg:params gain: 1.02 xyxy缩放因子:params pad: xyxy pad一点点边界框 裁剪出来会更好看:params square: 是否需要将xyxy放缩成正方形:params BGR: 保存的图片是BGR还是RGB:params save: 是否要保存剪切的目标框"""xyxy = torch.tensor(xyxy).view(-1, 4)  # list -> Tensor [1, 4] = [x1 y1 x2 y2]b = xyxy2xywh(xyxy)  # xyxy to xywh [1, 4] = [x y w h]if square:  # 一般不需要rectangle to squareb[:, 2:] = b[:, 2:].max(1)[0].unsqueeze(1)  # attempt rectangle to square# box wh * gain + pad box*gain再加点pad 裁剪出来框更好看b[:, 2:] = b[:, 2:] * gain + padxyxy = xywh2xyxy(b).long()  # xywh -> xyxy# 将boxes的坐标(x1y1x2y2 左上角右下角)限定在图像的尺寸(img_shape hw)内clip_coords(xyxy, im.shape)# crop: 剪切的目标框hwcrop = im[int(xyxy[0, 1]):int(xyxy[0, 3]), int(xyxy[0, 0]):int(xyxy[0, 2]), ::(1 if BGR else -1)]if save:# 保存剪切的目标框cv2.imwrite(str(increment_path(file, mkdir=True).with_suffix('.jpg')), crop)return crop

用于detect.py文件中:
在这里插入图片描述

30、resample_segments

这个函数是 对segment重新采样,比如说segment坐标只有100个,通过interp函数将其采样为n个(默认1000)。

resample_segments函数代码:

def resample_segments(segments, n=1000):"""用在datasets.py文件中的random_perspective函数中对segment重新采样,比如说segment坐标只有100个,通过interp函数将其采样为n个(默认1000):params segments: [N, x1x2...]:params n: 采样个数:return segments: [N, n/2, 2]"""for i, s in enumerate(segments):# 0~len(s)-1 取n(1000)个点x = np.linspace(0, len(s) - 1, n)# 0, 1, 2, ..., len(s)-1xp = np.arange(len(s))# 对所有的segments都进行重新采样 比如说segment坐标只有100个,通过interp函数将其采样为n个(默认1000)segments[i] = np.concatenate([np.interp(x, xp, s[:, i]) for i in range(2)]).reshape(2, -1).T  # segment xy# [N, n/2, 2]return segments

在datasets.py文件中的random_perspective函数中调用:

在这里插入图片描述

31、segment2box

这个函数是将一个多边形标签(不是矩形标签 到底是几边形未知)转化为一个矩形标签。

segment2box函数代码:

def segment2box(segment, width=640, height=640):"""用在datasets.py文件中的random_perspective函数中将一个多边形标签(不是矩形标签 到底是几边形未知)转化为一个矩形标签方法: 对多边形所有的点x1y1 x2y2... 获取其中的(x_min,y_min)和(x_max,y_max) 作为矩形label的左上角和右下角Convert 1 segment label to 1 box label, applying inside-image constraint:params segment: 一个多边形标签 [n, 2] 传入这个多边形n个顶点的坐标:params width: 这个多边形所在图片的宽度:params height: 这个多边形所在图片的高度:return 矩形标签 [1, x_min+y_min+x_max+y_max]"""# 分别获取当前多边形中所有多边形点的x和y坐标x, y = segment.T  # segment xy# inside: 筛选条件 xy坐标必须大于等于0 x坐标必须小于等于宽度 y坐标必须小于等于高度inside = (x >= 0) & (y >= 0) & (x <= width) & (y <= height)# 获取筛选后的所有多边形点的x和y坐标x, y, = x[inside], y[inside]# 取当前多边形中xy坐标的最大最小值,得到边框的坐标xyxyreturn np.array([x.min(), y.min(), x.max(), y.max()]) if any(x) else np.zeros((1, 4))

在datasets.py文件中的random_perspective函数中调用:
在这里插入图片描述

32、segments2boxes

这个函数是将多个多边形标签(不是矩形标签 到底是几边形未知)转化为多个矩形标签。

segments2boxes模块代码:

def segments2boxes(segments):"""用在datasets.py文件中的verify_image_label函数中将多个多边形标签(不是矩形标签 到底是几边形未知)转化为多个矩形标签Convert segment labels to box labels, i.e. (cls, xy1, xy2, ...) to (cls, xywh):params segments: [N, cls+x1y1+x2y2 ...]:return [N, cls+xywh]"""boxes = []for s in segments:# 分别获取当前多边形中所有多边形点的x和y坐标x, y = s.T# 取当前多边形中x和y坐标的最大最小值,得到边框的坐标xyxyboxes.append([x.min(), y.min(), x.max(), y.max()])# [N, cls+xywh]return xyxy2xywh(np.array(boxes))

在datasets.py文件中的verify_image_label函数中调用:

在这里插入图片描述

总结

这个文件的代码主要是一些通用的工具函数,会广泛的在整个项目的文件中使用,所以比较重要,希望大家都可以掌握。

比较重要的函数有:set_logging、init_seeds、get_latest_run、colorstr、check_git_status、check_requirements、make_divisible、check_file、check_dataset、one_cycle、labels_to_class_weights、labels_to_image_weights、strip_optimizer、print_mutation、save_one_box、increment_path。

非常重要的有:clip_coords、scale_coords、xyxy2xywh、xywh2xyxy、xywhn2xyxy、xyxy2xywhn、xyn2xy、non_max_suppression。

–2021.08.03 19:55

  相关解决方案