YOLOV8逐步分解(2)_DetectionTrainer类初始化过程

news/2024/4/27 1:35:18/文章来源:https://blog.csdn.net/yueguang8/article/details/137125375

 接上篇文章yolov8逐步分解(1)--默认参数&超参配置文件加载继续讲解。

 1. 默认配置文件加载完成后,创建对象trainer时,需要从默认配置中获取类DetectionTrainer初始化所需的参数args,如下所示

def train(cfg=DEFAULT_CFG, use_python=False):"""Train and optimize YOLO model given training data and device."""model = cfg.model or 'yolov8n.pt'data = cfg.data or 'coco128.yaml'  # or yolo.ClassificationDataset("mnist")device = cfg.device if cfg.device is not None else ''args = dict(model=model, data=data, device=device)if use_python:from ultralytics import YOLOYOLO(model).train(**args)else:trainer = DetectionTrainer(overrides=args)  #初始化训练器trainer.train()

        通过debug可以看到,如下所示,args值为指定模型和数据集

 2. 使用上一步中获取的参数args,创建并初始化一个目标检测训练器trainer

trainer = DetectionTrainer(overrides=args)

3. DetectionTrainer类的初始化代码如下,下面我们将逐步讲解。

def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):"""Initializes the BaseTrainer class.Args:cfg (str, optional): Path to a configuration file. Defaults to DEFAULT_CFG.overrides (dict, optional): Configuration overrides. Defaults to None.对配置文件/训练数据文件参数进行加载,关键信息判断处理解析,保证文件存在,不存在则下载等合法性检测,及值的初始化化操作"""self.args = get_cfg(cfg, overrides)  #将overrides中的配置与cfg中的配置融合,返回SimpleNameSpace类型self.device = select_device(self.args.device, self.args.batch) #选择运行在CPU/GPU还是苹果推出的MPS库上self.check_resume() #判断是否基于之前的断点继续训练,如果是,则加载之前保存的数据参数self.validator = Noneself.model = Noneself.metrics = Noneself.plots = {}init_seeds(self.args.seed + 1 + RANK, deterministic=self.args.deterministic) #初始化随机数# Dirs 创建运行结果保存额目录及文件:创建本次训练的目录/ weights保存目录 /保存运行参数project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task #project: runs/detectname = self.args.name or f'{self.args.mode}'  #name: 'train'if hasattr(self.args, 'save_dir'):  #判断是否设置保存路径 ,如果没有则根据项目和任务名穿件保存目录self.save_dir = Path(self.args.save_dir)else:self.save_dir = Path(increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))self.wdir = self.save_dir / 'weights'  # weights dir #runs/detect/train72/weighhtsif RANK in (-1, 0):self.wdir.mkdir(parents=True, exist_ok=True)  # make dirself.args.save_dir = str(self.save_dir)yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args  #保存运行参数self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint pathsself.save_period = self.args.save_period   #保存周期#设置 epoch次数 和 batch的大小self.batch_size = self.args.batchself.epochs = self.args.epochsself.start_epoch = 0if RANK == -1:print_args(vars(self.args))# Deviceif self.device.type == 'cpu':self.args.workers = 0  # faster CPU training as time dominated by inference, not dataloading# Model and Dataset 初始化模型文件 和数据集self.model = self.args.model  #yolov8n.pttry:if self.args.task == 'classify':   #分类任务self.data = check_cls_dataset(self.args.data)elif self.args.data.endswith('.yaml') or self.args.task in ('detect', 'segment'):  #检测和分割任务self.data = check_det_dataset(self.args.data) #加载数据yaml文件,进行关键属性值检测,并进行路径转换,确保数据集文件存在,不存在则下载if 'yaml_file' in self.data:self.args.data = self.data['yaml_file']  # for validating 'yolo train data=url.zip' usageexcept Exception as e:raise RuntimeError(emojis(f"Dataset '{clean_url(self.args.data)}' error ❌ {e}")) from eself.trainset, self.testset = self.get_dataset(self.data) #初始化训练集测试集参数 获取路径self.ema = None# Optimization utils initself.lf = None   #损失函数self.scheduler = None  #学习率调整策略# Epoch level metrics 指标self.best_fitness = Noneself.fitness = Noneself.loss = None   #当前损失值self.tloss = None  #总损失值self.loss_names = ['Loss']self.csv = self.save_dir / 'results.csv'self.plot_idx = [0, 1, 2]# Callbacksself.callbacks = _callbacks or callbacks.get_default_callbacks()if RANK in (-1, 0):callbacks.add_integration_callbacks(self)

3.1  self.args = get_cfg(cfg, overrides) 该行主要实现功能为:

        将默认配置参数从Simplenamespace转为字典后与overrides中的参数合并更新,进行一些参数的合法性检测后,再转换为Simplenamespace格式输出。

        overrides该参数主要是用于更新默认加载的配置文件中model和data的值,默认配置中上述值均为None,如下图所示:

更新后的配置如下图所示:

3.2 self.device = select_device(self.args.device, self.args.batch) 功能为:

        选择算法运行在CPU还是GPU上,参数batch用于检测设置的batch数值是否是GPU个数的整数倍,若不是整数倍则报错。

3.3  self.check_resume() :判断是否基于之前的断点继续训练,如果是,则加载之前保存的数据参数,本次默认配置参数该值为False.

3.4 接下来创建运行时的文件保存目录,包括本次训练的权重文件保存目录,并保存训练使用的参数以及checkPoint路径等。

# Dirs 创建运行结果保存目录及文件:创建本次训练的目录/ weights保存目录 /保存运行参数project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task #project: runs/detectname = self.args.name or f'{self.args.mode}'  #name: 'train'if hasattr(self.args, 'save_dir'):  #判断是否设置保存路径 ,如果没有则根据项目和任务名创建保存目录self.save_dir = Path(self.args.save_dir)else:self.save_dir = Path(increment_path(Path(project) / name, exist_ok=self.args.exist_ok if RANK in (-1, 0) else True))self.wdir = self.save_dir / 'weights'  # weights dir #runs/detect/train72/weighhtsif RANK in (-1, 0):self.wdir.mkdir(parents=True, exist_ok=True)  # make dirself.args.save_dir = str(self.save_dir)yaml_save(self.save_dir / 'args.yaml', vars(self.args))  # save run args  #保存运行参数self.last, self.best = self.wdir / 'last.pt', self.wdir / 'best.pt'  # checkpoint pathsself.save_period = self.args.save_period   #保存周期

3.5 初始化batch/epoch等参数,这个一目了然,不在解释

3.6  初始化数据集(coco128.yaml),步骤如下: 

        3.6.1 检测传入的数据集参数’dataset’是否是yaml结尾文件

        3.6.2 若是路径并且是压缩格式,则下载数据集配置文件

        3.6.3  加载coco128.yaml,通过函数yaml_load()加载

def check_det_dataset(dataset, autodownload=True):"""Download, check and/or unzip dataset if not found locally."""data = check_file(dataset)  #dataset: coco128.yaml #判断文件是否合法,如果不存在在下载,或者从本地搜索# Download (optional)extract_dir = ''if isinstance(data, (str, Path)) and (zipfile.is_zipfile(data) or is_tarfile(data)): #判断数据集是否时zip or tar压缩格式 #new_dir = safe_download(data, dir=DATASETS_DIR, unzip=True, delete=False, curl=False)data = next((DATASETS_DIR / new_dir).rglob('*.yaml'))extract_dir, autodownload = data.parent, False# Read yaml (optional)if isinstance(data, (str, Path)):data = yaml_load(data, append_filename=True)  # dictionary #读取数据集yam文件 simplenamespace格式# Checks 必要参数检测for k in 'train', 'val':if k not in data: #如果数据中既不包含 train也不包含 val,则报错raise SyntaxError(emojis(f"{dataset} '{k}:' key missing ❌.\n'train' and 'val' are required in all data YAMLs."))if 'names' not in data and 'nc' not in data:raise SyntaxError(emojis(f"{dataset} key missing ❌.\n either 'names' or 'nc' are required in all data YAMLs."))if 'names' in data and 'nc' in data and len(data['names']) != data['nc']:raise SyntaxError(emojis(f"{dataset} 'names' length {len(data['names'])} and 'nc: {data['nc']}' must match."))if 'names' not in data: #如果没有names则,用数字代替data['names'] = [f'class_{i}' for i in range(data['nc'])]else:data['nc'] = len(data['names'])data['names'] = check_class_names(data['names']) #检测data['names']是否是dict,以及将key转换为数字# Resolve pathspath = Path(extract_dir or data.get('path') or Path(data.get('yaml_file', '')).parent)  # dataset rootif not path.is_absolute():path = (DATASETS_DIR / path).resolve() #转化为绝对路径data['path'] = path  # download scriptsfor k in 'train', 'val', 'test':  #全部转换为绝对路径if data.get(k):  # prepend pathif isinstance(data[k], str):x = (path / data[k]).resolve()if not x.exists() and data[k].startswith('../'):x = (path / data[k][3:]).resolve()data[k] = str(x)else:data[k] = [str((path / x).resolve()) for x in data[k]]# Parse yamltrain, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))if val:val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])]  # val pathif not all(x.exists() for x in val):  #不存在则下载name = clean_url(dataset)  # dataset name with URL auth strippedm = f"\nDataset '{name}' images not found ⚠️, missing paths %s" % [str(x) for x in val if not x.exists()]if s and autodownload:LOGGER.warning(m)else:m += f"\nNote dataset download directory is '{DATASETS_DIR}'. You can update this in '{SETTINGS_YAML}'"raise FileNotFoundError(m)t = time.time()if s.startswith('http') and s.endswith('.zip'):  # URLsafe_download(url=s, dir=DATASETS_DIR, delete=True)r = None  # successelif s.startswith('bash '):  # bash scriptLOGGER.info(f'Running {s} ...')r = os.system(s)else:  # python scriptr = exec(s, {'yaml': data})  # return Nonedt = f'({round(time.time() - t, 1)}s)'s = f"success ✅ {dt}, saved to {colorstr('bold', DATASETS_DIR)}" if r in (0, None) else f'failure {dt} ❌'LOGGER.info(f'Dataset download {s}\n')check_font('Arial.ttf' if is_ascii(data['names']) else 'Arial.Unicode.ttf')  # download fontsreturn data  # dictionary

        其中,data = yaml_load(data, append_filename=True)加载完成后,data内容如下:

注意:’nc’:80 是通过 data['nc'] = len(data['names']) 后添加的。

   3.6.4 将data中的路径全部转换为绝对路径

 for k in 'train', 'val', 'test':  #全部转换为绝对路径if data.get(k):  # prepend pathif isinstance(data[k], str):x = (path / data[k]).resolve()if not x.exists() and data[k].startswith('../'):x = (path / data[k][3:]).resolve()data[k] = str(x)else:data[k] = [str((path / x).resolve()) for x in data[k]]

        转换完成并更新data后,data的内容如下,其中train,val,test等键的值变为了绝对路径:

        3.6.5 获取训练集、测试集、验证集、以及下载路径

train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download'))

        3.6.6 最終返回data,数据类型为字典,完成对coco128.yaml文件的加载解析及校验工作。

3.7  获取训练集和验证集的路径

self.trainset, self.testset = self.get_dataset(self.data) #初始化训练集测试集参数 获取路径

        其中,获取路径方法函数实现过程如下:

def get_dataset(data):"""Get train, val path from data dict if it exists. Returns None if data format is not recognized."""return data['train'], data.get('val') or data.get('test')

3.8 其他学习率、损失函数等都设置为None

        self.ema = None# Optimization utils initself.lf = None   #损失函数self.scheduler = None  #学习率调整策略# Epoch level metrics 指标self.best_fitness = Noneself.fitness = Noneself.loss = None   #当前损失值self.tloss = None  #总损失值self.loss_names = ['Loss']self.csv = self.save_dir / 'results.csv'self.plot_idx = [0, 1, 2]

3.9 设置用于结果展示获取的一些回调函数

        # Callbacksself.callbacks = _callbacks or callbacks.get_default_callbacks()if RANK in (-1, 0):callbacks.add_integration_callbacks(self)

        至此,trainer的初始化过程解析完成。

        总结,本章详细介绍了yolov8训练器trainer的初始化过程,讲解参数的加载替换过程,着重讲解了coco128数据集的加载解析及校验,最后介绍了损失函数学习率的初始化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1028208.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Docker-Container

Docker ①什么是容器②为什么需要容器③容器的生命周期容器 OOM容器异常退出容器暂停 ④容器命令清单总览docker createdocker rundocker psdocker logsdocker attachdocker execdocker startdocker stopdocker restartdocker killdocker topdocker statsdocker container insp…

Elasticsearch从入门到精通-07ES底层原理学习

Elasticsearch从入门到精通-07ES底层原理和高级功能 👏作者简介:大家好,我是程序员行走的鱼 📖 本篇主要介绍和大家一块学习一下ES底层原理包括集群原理、路由原理、分配控制、分配原理、文档分析原理、文档并发安全原理以及一些高…

第十四届蓝桥杯JavaA组省赛真题 - 特殊日期

解题思路&#xff1a; 暴力秒了 public class Main {public static void main(String[] args) {int cnt 0;for (int i 1900; i < 9999; i) {for (int j 1; j < 12; j) {for (int k 1; k < days(i, j); k) {if (sum(i) sum(j) sum(k)) cnt;}}}System.out.print…

算法笔记~—位运算

目录 常见位运算&#xff1a; 1、基础位运算 2、对于一个数n。确定、修改这个数n二进制x位。 3、提取&#xff08;确定&#xff09;一个数n最右侧的1&#xff08;bit&#xff09;与干掉最右侧的1&#xff08;bit&#xff09; 4、异或运算律 5、位运算的优先级&#xff1a…

Qt扫盲-QAssisant 集成其他qch帮助文档

QAssisant 集成其他qch帮助文档 一、概述二、Cmake qch例子1. 下载 Cmake.qch2. 添加qch1. 直接放置于Qt 帮助的目录下2. 在 QAssisant中添加 一、概述 QAssisant是一个很好的帮助文档&#xff0c;他提供了供我们在外部添加新的 qch帮助文档的功能接口&#xff0c;一般有两中添…

百度智能云千帆,产业创新新引擎

本文整理自 3 月 21 日百度副总裁谢广军的主题演讲《百度智能云千帆&#xff0c;产业创新新引擎》。 各位领导、来宾、媒体朋友们&#xff0c;大家上午好。很高兴今天在石景山首钢园&#xff0c;和大家一起沟通和探讨大模型的发展趋势&#xff0c;以及百度最近一段时间的思考和…

为什么Python不适合写游戏?

知乎上有热门个问题&#xff1a;Python 能写游戏吗&#xff1f;有没有什么开源项目&#xff1f; Python可以开发游戏&#xff0c;但不是好的选择 Python作为脚本语言&#xff0c;一般很少用来开发游戏&#xff0c;但也有不少大型游戏有Python的身影&#xff0c;比如&#xff1…

sheng的学习笔记-AI-人脸识别

目录:sheng的学习笔记-AI目录-CSDN博客 需要学习卷机神经网络等知识&#xff0c;见ai目录 目录 基础知识&#xff1a; 人脸验证&#xff08;face verification&#xff09; 人脸识别&#xff08;face recognition&#xff09; One-Shot学习&#xff08;One-shot learning&…

PTA-练习8

目录 实验5-3 使用函数求Fibonacci数 实验5-4 输出每个月的天数 实验5-9 使用函数求余弦函数的近似值 实验5-11 空心的数字金字塔 实验6-6 使用函数验证哥德巴赫猜想 实验6-7 使用函数输出一个整数的逆序数 实验6-8 使用函数输出指定范围内的完数 实验8-1-7 数组循环右…

Transformer的前世今生 day11(Transformer的流程)

Transformer的流程 在机器翻译任务中&#xff0c;翻译第一个词&#xff0c;Transformer的流程为&#xff1a; 先将要翻译的句子&#xff0c;一个词一个词的转换为词向量送入编码器层&#xff0c;得到优化过的词向量以及K、V&#xff0c;将K、V送入解码器层&#xff0c;并跟解码…

halcon例程学习——ball.hdev

dev_update_window (off) dev_close_window () dev_open_window (0, 0, 728, 512, black, WindowID) read_image (Bond, die/die_03) dev_display (Bond) set_display_font (WindowID, 14, mono, true, false) *自带的 提示继续 disp_continue_message (WindowID, black, true)…

android studio忽略文件

右键文件&#xff0c;然后忽略&#xff0c;就不会出现在commit里面了 然后提交忽略文件即可

Vue3 + Vite + TS + Element-Plus + Pinia项目(5)对axios进行封装

1、在src文件夹下新建config文件夹后&#xff0c;新建baseURL.ts文件&#xff0c;用来配置http主链接 2、在src文件夹下新建http文件夹后&#xff0c;新建request.ts文件&#xff0c;内容如下 import axios from "axios" import { ElMessage } from element-plus im…

【C++的奇迹之旅】C++关键字命名空间使用的三种方式C++输入输出命名空间std的使用惯例

文章目录 &#x1f4dd;前言&#x1f320; C关键字(C98)&#x1f309; 命名空间&#x1f320;命名空间定义&#x1f309;命名空间使用 &#x1f320;命名空间的使用有三种方式&#xff1a;&#x1f309;加命名空间名称及作用域限定符&#x1f320;使用using将命名空间中某个成员…

【JVM】Java类加载器 和 双亲委派机制

1、java类加载器的分类 JDK8及之前 启动类加载器&#xff0c;BootStrap Class Loader,加载核心类,加载jre/lib目录下的类&#xff0c;C实现的拓展类加载器&#xff0c; Extension Class Loader&#xff0c;加载java拓展类库&#xff0c;jre/lib/ext目录下&#xff0c;比如javax…

蓝桥杯 java 凑算式 16年省赛Java组真题

题目 思路&#xff1a; 求有多少种解法 比如:68/3952/714就是一种解法&#xff0c;53/1972/486 是另一种解法 8/3952/714是可以除尽的 但是后面一个不行 所以我们也要通分 代码&#xff1a; public class 凑算式 {static int[] a {1, 2, 3, 4, 5, 6, 7, 8, 9};static int c…

SpringBoot Redis的使用

官方文档&#xff1a; 官方文档&#xff1a;Spring Data Redis :: Spring Data Redis 和jedis一样&#xff0c;SpringBoot Redis 也可以让我在Java代码中使用redis&#xff0c;同样也是通过引入maven依赖的形式。 加速访问github: 使用steam可以免费加速访问github Spring…

鸿蒙OS开发实例:【页面传值跳转】

介绍 本篇主要介绍如何在HarmonyOS中&#xff0c;在页面跳转之间如何传值 HarmonyOS 的页面指的是带有Entry装饰器的文件&#xff0c;其不能独自存在&#xff0c;必须依赖UIAbility这样的组件容器 如下是官方关于State模型开发模式下的应用包结构示意图&#xff0c;Page就是…

设计模式之单例模式精讲

UML图&#xff1a; 静态私有变量&#xff08;即常量&#xff09;保存单例对象&#xff0c;防止使用过程中重新赋值&#xff0c;破坏单例。私有化构造方法&#xff0c;防止外部创建新的对象&#xff0c;破坏单例。静态公共getInstance方法&#xff0c;作为唯一获取单例对象的入口…

ClickHouse 面试题及答案整理,最新面试题

ClickHouse的数据分布式存储机制是如何设计的&#xff1f; ClickHouse的数据分布式存储机制设计包括以下几个方面&#xff1a; 1、分片和复制&#xff1a; ClickHouse通过分片将数据水平划分为多个部分&#xff0c;每个部分存储在不同的节点上。每个分片可以有一个或多个副本…