【AI炼丹术】写深度学习代码的一些心得体会

news/2024/4/26 22:13:17/文章来源:https://blog.csdn.net/ARPOSPF/article/details/130346726

写深度学习代码的一些心得体会

  • 体会1
  • 体会2
  • 体会3
  • 总结
  • 内容来源

一般情况下,拿到一批数据之后,首先会根据任务先用领域内经典的Model作为baseline跑通,然后再在这个框架内加入自己设计的Model,微调代码以及修改一些超参数即可。总体流程参考如下:

  1. 先写dataset部分,包括数据的读取、预处理、增广等操作,将数据集准备好。
  2. 然后model部分baseline无需修改,proposed是自行设计,定义模型的结构和参数,建立模型架构。
  3. 最后是train部分,这里调用所有的类实现训练:包括定义模型,模型包裹;获取dataloader;定义loss,优化器,学习率,定义early stoping策略;保存模型权重,保存日志。

当然,文无定法。这个顺序并不是固定不变的,也可以根据具体情况作出相应的调整。例如,当你的数据集已经准备好了,可以直接开始定义模型,然后再定义训练过程;或者在进行模型训练之前,先进行数据集的分析和可视化等操作。

体会1

源自:作者三四但不犹豫
对于图像任务:

  1. 顺序上,先写dataset部分,检查基本的transform,再搭model,构建head和loss,就可以把一个基础的、可以跑的网络就能跑起来了(这点很重要);
  2. 可视化很重要,如果是本地开发机,善用cv.imshow直观、便捷地可视化处理的结果;
  3. 一个基础的train/inference流程跑通后,分别构建1 张、10 张的数据用于debug,确保任意改动后,可以overfit;
  4. 调试代码阶段避免随机性、避免数据增强,一定用tensorboard之类的工具观察 loss 下降是否合理;
  5. 一般数据集最好处理成coco的格式,我的任务跟传统任务不太一样,但也尽量仿照coco来设计,写dataset的时候可以参考开源实现;
  6. 善用开源框架,比如Open-MMLab,Detectron2之类的,好处是方便实验,在框架里写不容易出现难以察觉的bug,坏处是开源框架为了适配各种网络,代码复杂程度会高一点,建议从第一版入手了解框架,然后基于最新的一边阅读一边开发。

体会2

源自:捡到一束光
先给结论:以写了两三年pytorch代码的经验而言,比较好的顺序是先写model,再写dataset,最后写train。在讨论码组件的具体顺序前,先分析每一个组件背后的目的和逻辑。

  • model构成了整个深度学习训练与推断系统骨架,也确定了整个AI模型的输入和输出格式
    • 对于视觉任务,模型架构多为卷积神经网络或是最新的ViT模型
    • 对于NLP任务,模型架构多为Transformer以及Bert
    • 对于时间序列预测,模型架构多为RNNLSTM

不同的model对应了不同的数据输入格式,如ResNet一般是输入多通道二维矩阵,而ViT则需要输入带有位置信息的图像patchs。确定了用什么样的model后,数据的输入格式也就确定下来。根据确定的输入格式,我们才能构建对应的dataset。

  • dataset构建了整个AI模型的输入与输出格式。

    • 在写作dataset组件时,我们需要考虑数据的存储位置与存储方式,如数据是否是分布式存储的,模型是否要在多机多卡的情况下运行,读写速度是否存在瓶颈,如果机械硬盘带来了读写瓶颈则需要将数据预加载进内存等。
    • 在写dataset组件时,我们也要反向微调model组件。例如,确定了分布式训练的数据读写后,需要用nn.DataParallel或者nn.DistributedDataParallel等模块包裹model,使模型能够在多机多卡上运行。
    • 此外,dataset组件的写作也会影响训练策略,这也为构建train组件做了铺垫。比如根据显存大小,我们需要确定相应的BatchSize,而BatchSize则直接影响学习率的大小。再比如根据数据的分布情况,我们需要选择不同的采样策略进行Feature Balance,而这也会体现在训练策略中。
  • train构建了模型的训练策略以及评估方法,它是最重要也是最复杂的组件。先构建model与dataset可以添加限制,减少train组件的复杂度。

    • 在train组件中,我们需要根据训练环境(单机多卡,多机多卡或是联邦学习)确定模型更新的策略,以及确定训练总时长epochs,优化器的类型,学习率的大小与衰减策略,参数的初始化方法,模型损失函数
    • 此外,为了对抗过拟合,提升泛化性,还需要引入合适的正则化方法,如Dropout,BatchNorm,L2-Regularization,Data Augmentation等。
    • 有些提升泛化性能的方法可以直接在train组件中实现(如添加L2-Reg,Mixup),有些则需要添加进model中(如Dropout与BatchNorm),还有些需要添加进dataset中(如Data Augmentation)。。

此外,train还需要记录训练过程的一些重要信息,并将这些信息可视化出来,比如在每个epoch上记录训练集的平均损失以及测试集精度,并将这些信息写入tensorboard,然后在网页端实时监控。在构建train组件中,我们需要随时根据模型表现进行参数微调,并根据结果改进model和dataset两个组件。
tensorboard

体会3

源自:芙兰朵露
作为data driven的学科,不同的AI model适合不同的数据类型,选择用哪个模型是基于你的数据长什么样来决定的。初学者知道用CNN处理图片,用RNN处理时间序列/语言,但这些都是最基础的工作,真正体现水平的是根据数据的性质来选择合适的细分模型。比如稀疏图像需要用Sparse CNN,语言Transformer效果比较好,但对某些特殊的时间序列RNN也有奇效。

接下来还有很多技术细节,比如需不需要数据增强?需不需要标签平滑?需不需要残差链接?需不需要多loss,如果需要如何平衡?需不需要解释模型?我甚至没有提到超参数,因为超参数是锦上添花而不是雪中送炭。只要没有明确的信息瓶颈,超参数对模型的影响是很小的。

上面提到的这些问题不需要全想明白,但心里要大致有个谱,至少也要知道这些问题是可能影响你的训练结果的,这其实需要相当的阅读和积累。这样之后出了问题才知道去哪里debug。

然后就可以开始写了。这些问题想明白之后,其实先写哪个part已经不重要了,因为你的心中已经有了一个picture,先把这个picture给sketch下来,然后开始跑,第一遍效果肯定不好,但你要根据输出的结果大致判断哪个部分出了问题,然后针对性地去改进。这一步真的没什么好办法,很多时候其实是直觉,做多了自然就知道了。训练模型-发现问题-修改模型-再训练,就像炼丹一样,经过无数遍的抟炼,才能得到最后的金丹。

其实洋洋洒洒说了这么多,本质不过是几个字:解决问题的能力making things to work几乎是机器学习中最重要的能力了,而这种能力就是在日常的积累和训练中反复磨练出来的,成功的路上没有捷径

总结

单纯就个人习惯而言,先写model,确保model的结果没有错误,调试正确。然后写dataset,并调试输出正确。之后写损失函数,并调试正确。最后写train训练代码,推理代码。

内容来源

  1. 写深度学习代码是先写model还是dataset还是train呢,有个一般化的顺序吗?
  2. A Recipe for Training Neural Networks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_103372.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

汇编语言(第3版) - 学习笔记 - 实验8 分析一个奇怪的程序

实验8 分析一个奇怪的程序 题目解析顺序执行查看反汇编测试一下 题目 分析下面的程序,在运行前思考:这个程序可以正确返回吗? 运行后再思考:为什么是这种结果? 通过这个程序加深对相关内容的理解。 assume cs:codesg codesg segmentmov ax, 4c00h int 21h …

JavaWeb-Tomcat

目录 1.什么是Tomcat 2.Tomcat 概述 3.Tomcat基本使用 1.什么是Tomcat Tomcat官网:Apache Tomcat - Welcome! 【摘自百度百科】 Tomcat是Apache 软件基金会(Apache Software Foundation)的Jakarta 项目中的一个核心项目,由Apac…

MySQL: 数据类型之整数型、浮点数、时间日期

目录 前言: 数据类型: 整数型: 浮点数与定点数: 浮点数: 定点数: 日期与时间: DATATIME: DATE: TIMESTAMP: ​编辑 YEAR: TIME: 前言: 前面的几篇写了如何创…

2023年主流的选择仍是Feign, http客户端Feign还能再战

👳我亲爱的各位大佬们好😘😘😘 ♨️本篇文章记录的为 微服务组件之http客户端Feign 相关内容,适合在学Java的小白,帮助新手快速上手,也适合复习中,面试中的大佬🙉🙉🙉。 …

音视频开发面试题大盘点:掌握这些基础知识,你就能轻松应对面试

前言 音视频开发作为一种高技术含量的领域,随着人们对数字媒体的需求不断增加,其前景非常广阔。预计在2023年,音视频开发领域仍将继续保持快速发展的态势,尤其是在移动互联网、物联网、虚拟现实、增强现实等领域。 根据BOSS招聘…

Jenkins Kubernetes

Kubernetes集成Harbor Harbor 私服配置 在Kubernetes的master和所有worker节点上加上harbor配置,修改daemon.json,支持Docker仓库,并重启Docker。 sudo vim /etc/docker/daemon.json {"registry-mirrors": ["https://jrabv…

微信小程序 开发中的问题(simba_wx)

目录 一、[将 proto 文件转成 json 文件](https://blog.csdn.net/wzxzRoad/article/details/129300513)二、[使用 test.json 文件](https://blog.csdn.net/wzxzRoad/article/details/129300513)三、[微信小程序插件网址](https://ext.dcloud.net.cn/)四、[vant-weapp网址](http…

从0搭建Vue3组件库(八):使用 release-it 实现自动管理发布组件库

使用 release-it 实现自动管理发布组件库 上一篇文章已经打包好我们的组件库了,而本篇文章将介绍如何发布一个组件库。当然本篇文章介绍的肯定不单单只是发布那么简单。 组件库发布 我们要发布的包名为打包后的 easyest,因此在 easyest 下执行pnpm init生成package.json {&…

本地缓存解决方案Caffeine | Spring Cloud 38

一、Caffeine简介 Caffeine是一款高性能、最优缓存库。Caffeine是受Google guava启发的本地缓存(青出于蓝而胜于蓝),在Cafeine的改进设计中借鉴了 Guava 缓存和 ConcurrentLinkedHashMap,Guava缓存可以参考上篇:本地缓…

【Springcloud Alibaba微服务分布式架构 | Spring Cloud】之学习笔记(九)Nacos+Sentinel+Seata

NacosSentinelSeata 9/9 1、SpringCloud Alibaba简介1.1 主要功能1.2 具体组件 2、SpringCloud Alibaba Nacos服务注册和配置中心2.1 Nacos介绍2.2 Nacos下载安装2.3 使用Nacos作为注册中心2.3.1 在父工程的pom文件中引入springcloudalibaba依赖2.3.2 创建cloudalibaba-provide…

适合学生党的蓝牙耳机品牌有哪些?性价比高的无线耳机推荐

相较于有线耳机,蓝牙耳机的受欢迎程度可谓是越来越高,当然,这也离不开部分手机取消耳机孔的设计。最近看到很多网友问,适合学生党的蓝牙耳机品牌有哪些?针对这个问题,我来给大家推荐几款性价比高的无线耳机…

static_cast、dynamic_cast和reinterpret_cast区别和联系

其实网上相关的资料不少,但是能够说清楚明白这个问题的也不多。 于是,我尝试着问了一下AI,感觉回答还可以,但是需要更多的资料验证。 让我们先看看AI是怎么回答这个问题的。 static_cast、dynamic_cast和reinterpret_cast都是C中…

视频音频提取器推荐:快速提取视频中的音频!

视频中的音频可以用于很多用途,比如制作配乐、音频剪辑等。但是,许多人并不知道如何将视频中的音频提取出来。如果您也是这样的情况,那么本文为您介绍一个简单易用的视频音频提取器:。 它是一个免费的在线工具,可以帮…

如何在Web上实现激光点云数据在线浏览和展示?

无人机激光雷达测量是一项综合性较强的应用系统,具有数据精度高、层次细节丰富、全天候作业等优势,能够精确测量三维现实世界,为各个行业提供了丰富有效的数据信息。但无人机激光雷达测量产生的点云数据需要占用大量的存储空间,甚…

DataGridView 真·列头不高亮 真·列头合并

高亮BUG VB.Net,在 .NET Framework 4.8 的 WinForm 下(即不是 WPF 的绘图模式、也不是 Core 或 Mono 的开发框架),使用 DataGridView 行模式,还是有个列头表现为高亮显示: 查找各种解决方式: 设置 ColumnHeadersDefaultCellSty…

YOLOv1代码复现2:数据加载器构建

YOLOv1代码复现2:数据加载器构建 前言 ​ 在经历了Faster-RCNN代码解读的摧残后,下决心要搞点简单的,于是便有了本系列的博客。如果你苦于没有博客详细告诉你如何自己去实现YOLOv1,那么可以看看本系列的博客,也许可以帮…

【Java实战篇】Day13.在线教育网课平台--生成支付二维码与完成支付

文章目录 一、需求:生成支付二维码1、需求分析2、表设计3、接口定义4、接口实现5、完善controller 二、需求:查询支付结果1、需求分析2、表设计与模型类3、接口定义4、接口实现步骤一:查询支付结果步骤二:保存支付结果&#xff08…

VUE3如何定义less全局变量

默认已经安装好了less,这里不过多讲。 (1)首先我们需要下载一个插件依赖: npm i style-resources-loader --save-dev (2)VUE3里配置vue.config.js文件内容 代码: const path require("p…

HashMap如何解决哈希冲突

HashMap如何解决哈希冲突 Hash算法和Hash表Hash冲突解决哈希冲突的方法开放地址法链式寻址法再hash法建立公共溢出区 Hash算法和Hash表 Hash算法就是把任意长度的输入通过散列算法编程固定长度的输出。这个输出结果就是一个散列值。 Hash表又称为“散列表”,它是通…

SpringBoot中一个注解优雅实现重试Retry框架

目录: 1、简介2、实现步骤 1、简介 重试,在项目需求中是非常常见的,例如遇到网络波动等,要求某个接口或者是方法可以最多/最少调用几次;实现重试机制,非得用Retry这个重试框架吗?那肯定不是,相信…