【代码笔记】Pytorch学习 DataLoader模块详解

news/2024/5/7 2:03:44/文章来源:https://blog.csdn.net/Small___ming/article/details/125536365

Pytorch DataLoader模块详解

  • dataloader整体结构
  • DataLoader
    • init 初始化
      • 参数解释
      • 代码解析
        • IterableDataset 判断
        • 构建Sampler,单样本
        • 构建BatchSampler,组建batch
        • 构建collate_fn 对获取的batch进行处理
        • 其他的一些逻辑判断
    • _get_iterator
      • 代码解析
    • multiprocessing_context
    • multiprocessing_context
    • __setattr__
    • __iter__
      • 代码解释
    • _auto_collation
      • 代码解析
    • _index_sampler
    • __len__
    • check_worker_number_rationality
  • _SingleProcessDataLoaderIter
    • 代码解析
  • _BaseDataLoaderIter

dataloader整体结构

dataloader主要有6个class构成(可见下图)

  • _DatasetKind:
  • _InfiniteConstantSampler:
  • DataLoader:
  • _BaseDataLoaderIter:
  • _SingleProcessDataLoaderIter:
  • _MultiProcessingDataLoaderIter:
    在这里插入图片描述

DataLoader

我们首先看一下DataLoader的整体结构:

  • init:
  • _get_iterator:
  • multiprocessing_context:
  • multiprocessing_context:
  • setattr:
  • iter:
  • _auto_collation:
  • _index_sampler:
  • len:
  • check_worker_number_rationality:

init 初始化

参数解释

这里会把参数全部列出,这里列出的目的是让大家知道各个参数的意义。实际上很多是用不到的,我用加粗字体表示一些常用的参数。

  • self:代之Dataset这个类本身
  • dataset: Dataset[T_co]是默认值,是你要处理的数据集
  • batch_size: Optional[int] = 1, 可选,默认是1。每个batch可以加载batct_size个数据。
  • shuffle: bool = False, 每轮训练后是否将数据集打乱
  • sampler: Optional[Sampler] = None, 默认是None 自定义方法(某种顺序)从Dataset中取样本,指定这个参数就不能设置shuffle。因为shuffle是打乱数据集的顺序,而sample是以某种顺序取数据,所以二者互斥!sampler可能是获取一整个数据集的数据,是对一整个数据集进行操作,而不是一个batch_size。
  • batch_sampler: Optional[Sampler[Sequence]] = None, 返回一个batch的索引,与batch_size, shuffle, sampler, drop_last互斥
    传入了batch_sampler,相当于已经告诉了PyTorch如何从Dataset取多少数据,怎么取数据去组成一个mini batch,所以不需要以上参数。可以理解为batch_sampler是batch_size和sampler的结合,所以不需要batch_size, sampler, shuffle, drop_last(因为drop_last也是怎么取数据)。
  • num_workers: int = 0, 多进程加载数据,默认为0,即采用主进程加载数据
  • collate_fn: Optional[_collate_fn_t] = None, 聚集函数,用来对一个batch进行后处理,拿到一个batch的数据后进行什么处理,返回处理后的batch数据。默认源码中进行了若干逻辑判断,仅将数据组合起来返回,没有实质性工作。默认collate_fn的声明是:def default_collate(batch): 所以自定义collate_fn需要以batch为输入,以处理后的batch为输出。类似于transform,transform是对单个数据处理,而collate_fn是对单个batch做处理。
  • pin_memory: bool = False, 用于将tensor加载到GPU中进行运算
  • drop_last: bool = False, 是否保存最后一个mini batch,样本数量可能不支持被batch size整除,所以drop_last参数决定是否保留最后一个可能批量较小的batch
  • timeout: float = 0, 控制从进程中获取一个batch数据的时延
  • worker_init_fn: Optional[_worker_init_fn_t] = None, 初始化子进程
  • multiprocessing_context=None,
  • generator=None,
  • prefetch_factor: int = 2, 控制样本在每个进程里的预加载,默认为2
  • persistent_workers: bool = False 控制加载完一次Dataset是否保留进程,默认为False
def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,shuffle: bool = False, sampler: Optional[Sampler] = None,batch_sampler: Optional[Sampler[Sequence]] = None,num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,pin_memory: bool = False, drop_last: bool = False,timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,multiprocessing_context=None, generator=None,*, prefetch_factor: int = 2,persistent_workers: bool = False):

代码解析

在DataLoader的__init__函数里,我们可以看到,它实现了:

  1. 判断是否是IterableDataset类型,如果是需要进一步判断参数是否正确
  2. 构建Sampler,单样本
  3. 构建BatchSampler,
  4. 组建batch 构建collate
  5. 其他的一些逻辑判断

IterableDataset 判断

  • IterableDataset应用于数据集非常大,将其完全加载进内存不现实(例如高达几个TB的数据),这时就需要IterableDataset构建可迭代的Dataset类,自定义的Dataset需要继承自torch.util.data.IterableDataset,重写__iter__方法,返回可迭代对象(通常是yield生成器)
  • 对于IterableDataset来说,就没有构建采样器Sampler的需求,因为样本是通过调用__iter__一个个读取出来的。执行封装的DataLoader传进去的batch_size次__iter__方法,就获取到一个mini batch
# 判断dataset是否是IterableDataset类型if isinstance(dataset, IterableDataset):self._dataset_kind = _DatasetKind.Iterable# 按照__iter__获取数据,所以不需要打乱if shuffle is not False:raise ValueError("DataLoader with IterableDataset: expected unspecified ""shuffle option, but got shuffle={}".format(shuffle))elif sampler is not None:# 按照__iter__获取数据,也不再需要sampler获取数据raise ValueError("DataLoader with IterableDataset: expected unspecified ""sampler option, but got sampler={}".format(sampler))elif batch_sampler is not None:# 按照__iter__获取数据,也不再需要batch_sampler获取数据索引raise ValueError("DataLoader with IterableDataset: expected unspecified " "batch_sampler option, but got batch_sampler={}".format(batch_sampler))else:self._dataset_kind = _DatasetKind.Map

构建Sampler,单样本

if sampler is None:  # give default samplersif self._dataset_kind == _DatasetKind.Iterable:# 如果是Iterable的Dataset,就采用迭代的方式获取samplersampler = _InfiniteConstantSampler()else:  # 否则判断是否使用shuffle,使用则随机产生sampler,不使用就按照顺序产生samplerif shuffle:sampler = RandomSampler(dataset, generator=generator)else:sampler = SequentialSampler(dataset)

构建BatchSampler,组建batch

  • 注意,上面说batch_sampler不能和batch_size、sampler、drop_last同时使用是指:如果已经定义了batch_sampler则与batch_size和sampler互斥!!!前提是已经定义了batch_sampler!!!但是如果没有定义batch_sampler,则可以通过batch_size,sampler,dorp_last来组建batch!!!
# 要取batch_size个sampler,但是还没有取,即batch_sampler==None
if batch_size is not None and batch_sampler is None:# 获取batch_size个sampler个索引batch_sampler = BatchSampler(sampler, batch_size, drop_last)

构建collate_fn 对获取的batch进行处理

if collate_fn is None:if self._auto_collation:# 默认的实际上什么也没干collate_fn = _utils.collate.default_collateelse:collate_fn = _utils.collate.default_convert

其他的一些逻辑判断

# sampler 不能和 shuffle 同时出现
# 因为shuffle是将数据打乱,而sampler是按照某一顺序获取数据if sampler is not None and shuffle:raise ValueError('sampler option is mutually exclusive with ''shuffle')if batch_sampler is not None:# batch_sampler不能和batch_size,shuffle,sampler,drop_last同时使用。# batch_sampler可以理解为batch_size和sampler的结合if batch_size != 1 or shuffle or sampler is not None or drop_last:raise ValueError('batch_sampler option is mutually exclusive ''with batch_size, shuffle, sampler, and '                          'drop_last')batch_size = Nonedrop_last = Falseelif batch_size is None:# batch_size为None,默认是1,如果drop_last为True就会舍弃最后一个,这样数据就会减少。(构成了一个batch但是仍然舍弃掉)if drop_last:raise ValueError('batch_size=None option disables auto-batching ''and is mutually exclusive with drop_last')self.collate_fn = collate_fnself.persistent_workers = persistent_workersself.__initialized = Trueself._IterableDataset_len_called = None  # See NOTE [ IterableDataset and __len__ ]self._iterator = Noneself.check_worker_number_rationality()torch.set_vital('Dataloader', 'enabled', 'True')  # type: ignore[attr-defined]

_get_iterator

代码解析

def _get_iterator(self) -> '_BaseDataLoaderIter':if self.num_workers == 0:# 单线程return _SingleProcessDataLoaderIter(self)else:# 多线程self.check_worker_number_rationality()return _MultiProcessingDataLoaderIter(self)

multiprocessing_context

multiprocessing_context

setattr

iter

代码解释

 # 其中 -> '_BaseDataLoaderIter' 是函数注释,运行时跟没有加注解之前的效果也没有任何差距。# 主要作用是提醒程序猿这里应该是 '_BaseDataLoaderIter'的数据类型def __iter__(self) -> '_BaseDataLoaderIter':if self.persistent_workers and self.num_workers > 0:if self._iterator is None:self._iterator = self._get_iterator()else:self._iterator._reset(self)return self._iteratorelse:return self._get_iterator()

_auto_collation

代码解析

	@propertydef _auto_collation(self):# 根据batch_sampler判断是否设置_auto_collationreturn self.batch_sampler is not None

_index_sampler

len

check_worker_number_rationality

_SingleProcessDataLoaderIter

代码解析

def __init__(self, loader):super(_SingleProcessDataLoaderIter, self).__init__(loader)assert self._timeout == 0assert self._num_workers == 0self._dataset_fetcher = _DatasetKind.create_fetcher(self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)def _next_data(self):# 获取索引index = self._next_index()  # may raise StopIteration# 获取数据data = self._dataset_fetcher.fetch(index)  # may raise StopIterationif self._pin_memory:data = _utils.pin_memory.pin_memory(data)# 返回数据return data

_BaseDataLoaderIter

__next__方法会调用_next_data,_next_data获取一个batch的数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_283039.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Python】轻松掌握基础语法(一)

文章目录常量和表达式变量和类型变量的定义变量的使用变量的类型intfloatstrbool动态类型注释输入和输出输出输入运算符算数运算符关系运算符逻辑运算符赋值运算符其他常量和表达式 print(1 2 * 3)print是Python内置的一个函数,作用为输入打印到控制台形如1 2 * …

人工智能前沿——「全域全知全能」人类新宇宙ChatGPT

🚀🚀🚀OpenAI聊天机器人ChatGPT——「全域全知全能」人类全宇宙大爆炸!!🔥🔥🔥 一、什么是ChatGPT?🍀🍀 ChatGPT是生成型预训练变换模型(Chat G…

1.半导体基础知识

1.半导体基础知识本征半导体什么是半导体?什么是本征半导体?本征半导体的结构本征半导体中的两种载流子为什么将自然界导电性能中等的半导体材料制成本征半导体杂质半导体N型半导体P型半导体PN结PN结中的扩散运动漂移运动和PN结的形成PN结的单向导电性PN…

uniapp开发小程序-swiper点击预览大图(商品详情页轮播图)

1.实现效果&#xff1a; 2.具体代码&#xff1a; <view class"swiper_box"><!--轮播图--><swiper class"ms_swiper" :autoplay"true" circular"true" change"swiperChange"><swiper-item v-for"…

【动态规划模板】神似的01和完全背包、多重背包和分组背包问题

神似的01背包与完全背包&#x1f349; 【经典题目】01背包采药 题目描述 &#x1f388; 辰辰是个天资聪颖的孩子&#xff0c;他的梦想是成为世界上最伟大的医师。为此&#xff0c;他想拜附近最有威望的医师为师。医师为了判断他的资质&#xff0c;给他出了一个难题。医师把他…

【Redis学习】SpringBoot集成Redis

总体概述 jedis-lettuce-RedisTemplate三者的联系 本地Java连接Redis常见问题 bind配置请注释掉 保护模式设置为no Linux系统的防火墙设置 redis服务器的IP地址和密码是否正确 忘记写访问redis的服务端口号和auth密码 集成jedis 简介 Jedis Client是Redis官网推荐的一个面…

【c语言】文件的读写

文件读写使用二进制读写比较方便&#xff0c;分别使用fread和fwrite函数进行。 一、函数定义 以二进制形式读取文件&#xff0c;从stream流中读取内容&#xff0c;读到ptr指向的空间中&#xff0c;读取size大小的count个内存单元。 返回值为读取到的字符个数。 以二进制形式读…

解决Abp设置DefaultLanguage默认语言不生效的问题

文章目录现象原因分析解决问题现象 默认地&#xff0c;Abp的语言提供程序将返回的CultureInfo为En&#xff0c;在一些默认实现的接口&#xff08;比如/api/TokenAuth/Authenticate&#xff09;返回的错误信息是英文 目标是改成简体中文显示&#xff0c;但是即便我们在AbpSett…

android framework-zygote进程

Zygote进程&#xff1a;可以看到zygote的父进程是init进程 一、Zygote整体时序图 涉及源码路径 android-10.0.0_r41\frameworks\base\cmds\app_process\Android.mk android-10.0.0_r41\frameworks\base\cmds\app_process\app_main.cpp android-10.0.0_r41\frameworks\base\core…

图像阈值化

图像阈值化 图像阈值化简介 ⚫ 图像阈值化是图像处理的重要基础部分, 应用很广泛, 可以根据灰度差异来分割图像不同部分 ⚫ 阈值化处理的图像一般为单通道图像(灰度图) ⚫ 阈值化参数的设置可以使用滑动条来debug ⚫ 阈值化处理易光照影响, 处理时应注意 ⚫ 本节主要介绍…

1000题!!阿里P8架构师手写“Java面试宝典”带你横扫全网

序言 很多同学学习Java并发一头扎进源码&#xff0c;最后头破血流&#xff0c;无功而返。横看成岭侧成峰&#xff0c;远近高低各不同。学习要始终从不同的视角来看待问题。学习并发亦是如此&#xff0c;需要通过理论远看轮廓&#xff0c;然后通过源码近看明细。 今天小编分享…

Java之堆和堆排序

目录 一.什么是堆 1.基本介绍 2.堆的实现方式 二.最大堆的实现 1.最大堆 2.思路分析 0.基础操作 1.添加上浮操作 2.删除下沉操作 3.将数组堆化操作 2.代码实现 三.堆排序 1.什么是堆排序 2.思路分析 3.代码实现 一.什么是堆 1.基本介绍 堆是一种数据结构&#…

新增 ABB COMLI 等 5 个工业协议驱动

3 月&#xff0c;Neuron 团队主要在为 2.4.0 版本的发布做准备&#xff0c;进行了官网文档的重构与完善&#xff0c;为常用驱动增加了相应的连接示例及常见问题。同时新增南向驱动 ABB COMLI&#xff0c;此驱动可通过串口连接 ABB 某些型号的 PLC。 新增驱动插件 南向驱动 IE…

kafka之一----概念结构

https://kafka.apache.org/ https://blog.csdn.net/liuyu973971883/article/details/109036572 https://blog.csdn.net/u013256816/article/details/80300225 1、概念 Kafka是由Apache软件基金会开发的一个开源流处理平台&#xff0c;由Scala和Java编写。Kafka是一种高吞吐量…

【Rust基础】语法知识

系列综述&#xff1a; &#x1f49e;目的&#xff1a;本系列是个人学习Rust语言整理的&#xff0c;整理期间苛求每个知识点&#xff0c;平衡理解简易度与深入程度。 &#x1f970;来源&#xff1a;材料主要源于b站的Rust中文社群线上学习室和菜鸟教程进行的&#xff0c;每个知识…

leaflet实现波动的marker效果(131)

第131个 点击查看专栏目录 本示例的目的是介绍如何在vue+leaflet中显示波动的marker效果。 直接复制下面的 vue+leaflet源代码,操作2分钟即可运行实现效果. 文章目录 示例效果配置方式示例源代码(共76行)安装插件相关API参考:专栏目标示例效果 配置方式 1)查看基础设置…

算法学习day52

算法学习day521.力扣 300.最长递增子序列1.1 题目描述1.2分析1.3 代码2.力扣674. 最长连续递增序列2.1 题目描述2.2 分析2.3 代码3.力扣718. 最长重复子数组3.1 题目描述3.2 分析3.3 代码3.参考资料1.力扣 300.最长递增子序列 1.1 题目描述 题目描述&#xff1a; 给一个整数…

CAM类激活映射 |神经网络可视化 | 热力图

文章目录前言&#xff1a;安装库&#xff1a;分类案例--ResNet50分割案例AttributeError: ‘tuple‘ object has no attribute ‘cpu‘RuntimeError: grad can be implicitly created only for scalar outputsTypeError: cant convert cuda:0 device type tensor to numpy. Use…

蓝桥杯嵌入式第十三届省赛题目解析

马上就要比赛了&#xff0c;我也是把自己写完调试好的题目分享出来给大家&#xff0c;同时也祝大家取得自己理想的成绩。 好了废话不多说&#xff0c;我们先看客观题再看程序设计题。 目录 客观题&#xff1a; 程序设计题&#xff1a; 题目解析&#xff1a; CubeMX配置 …

分期的秘密:名义利率和实际利率

分期付款&#xff0c;是一种常见的消费方式&#xff0c;但是这其中却有不少猫腻。名义上的年化利率和实际上的利率竟有可能相差两倍之多。今天&#xff0c;我就以招行的现金分期举例&#xff0c;简单剖析一下其中的玄机。 以上就是招行现金分期的月利率&#xff0c;我们做一点小…