【模型复现】resnet,使用net.add_module()的方法构建模型。小小的改进大大的影响,何大神思路很奇妙,基础很扎实

news/2024/5/6 14:30:14/文章来源:https://blog.csdn.net/weixin_43424450/article/details/130042146
  • 从经验来看,网络的深度对模型的性能至关重要,当增加网络层数后,网络可以进行更加复杂的特征模式的提取,所以当模型更深时理论上可以取得更好的结果。但是更深的网络其性能一定会更好吗?实验发现深度网络出现了退化问题(Degradation problem):网络深度增加时,网络准确度出现饱和,甚至出现下降。在 [1512.03385] Deep Residual Learning for Image Recognition (arxiv.org) 中表明56层的网络比20层网络效果还要差。这不会是过拟合问题,因为56层网络的训练误差同样高。我们知道深层网络存在着梯度消失或者爆炸的问题,这使得深度学习模型很难训练。但是现在已经存在一些技术手段如BatchNorm来缓解这个问题。因此,出现深度网络的退化问题是非常令人诧异的。【AI Talking】CVPR2016 最佳论文, ResNet 现场演讲 - 知乎 (zhihu.com)

  • 深度网络的退化问题【读点论文】Deep Residual Learning for Image Recognition 训练更深的网络_羞儿的博客-CSDN博客至少说明深度网络不容易训练。但是我们考虑这样一个事实:现在你有一个浅层网络,你想通过向上堆积新层来建立深层网络,一个极端情况是这些增加的层什么也不学习,仅仅复制浅层网络的特征,即这样新层是恒等映射(Identity mapping)。在这种情况下,深层网络应该至少和浅层网络性能一样,也不应该出现退化现象。好吧,你不得不承认肯定是目前的训练方法有问题,才使得深层网络很难去找到一个好的参数。后面何凯明提出了残差的结构融合到模型设计中。

  • 残差即观测值与估计值之间的差。假设我们要建立深层网络,当我们不断堆积新的层,但增加的层什么也不学习(极端情况),那么我们就仅仅复制浅层网络的特征,即新层是浅层的恒等映射(Identity mapping),这样深层网络的性能应该至少和浅层网络一样,那么退化问题就得到了解决。

  • 假设要求解的映射为 H(x),也就是观测值,假设上一层 resnet/上一个残差块输出的特征映射为 x(identity function跳跃连接),也就是估计值。那么我们就可以把问题转化为求解网络的残差映射函数 F(x) = H(x) - x。如果网络很深,出现了退化问题,那么我们就只需要让我们的残差映射F(x)等于 0,即要求解的映射 H(x)就等于上一层输出的特征映射 x,因为x是当前输出的最佳解,这样我们这一层/残差块的网络状态仍是最佳的一个状态。但是上面提到的是理想假设,实际情况下残差F(x)不会为0,x肯定是很难达到最优的,但是总会有那么一个时刻它能够无限接近最优解。采用ResNet的话,就只用小小的更新F(x)部分的权重值就行了,可以更好地学到新的特征。

  • ResNet网络是参考了VGG19网络,在其基础上进行了修改,并通过短路机制加入了残差单元,如下图所示。变化主要体现在ResNet直接使用stride=2的卷积做下采样,并且用global average pool层替换了全连接层。ResNet的一个重要设计原则是:当feature map大小降低一半时,feature map的数量增加一倍,这保持了网络层的复杂度。ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中虚线表示feature map数量发生了改变。下图展示的34-layer的ResNet,还可以构建更深的网络如下表所示。从表中可以看到,对于18-layer和34-layer的ResNet,其进行的两层间的残差学习,当网络更深时,其进行的是三层间的残差学习,三层卷积核分别是1x1,3x3和1x1,一个值得注意的是隐含层的feature map数量是比较小的,并且是输出feature map数量的1/4。

    • 在这里插入图片描述

    • 在这里插入图片描述

  • ResNet使用两种残差单元,如下图所示。左图对应的是浅层网络,而右图对应的是深层网络。对于短路连接,当输入和输出维度一致时,可以直接将输入加到输出上。但是当维度不一致时(对应的是维度增加一倍),这就不能直接相加。有两种策略:(1)采用zero-padding增加维度,此时一般要先做一个downsamp,可以采用strde=2的pooling,这样不会增加参数;(2)采用新的映射(projection shortcut),一般采用1x1的卷积,这样会增加参数,也会增加计算量。短路连接除了直接使用恒等映射,当然都可以采用projection shortcut。你必须要知道CNN模型:ResNet - 知乎 (zhihu.com)

    • 在这里插入图片描述

    • ResNet block有两种,一种两层结构,一种是三层的bottleneck结构,即将两个3x3的卷积层替换为1x1 + 3x3 + 1x1,它通过1x1 conv来巧妙地缩减或扩张feature map维度,从而使得我们的3x3 conv的filters数目不受上一层输入的影响,它的输出也不会影响到下一层。中间3x3的卷积层首先在一个降维1x1卷积层下减少了计算,然后在另一个1x1的卷积层下做了还原。既保持了模型精度又减少了网络参数和计算量,节省了计算时间。

  • 作者对比18-layer和34-layer的网络效果,如下图所示。可以看到普通的网络出现退化现象,但是ResNet很好的解决了退化问题。

    • 在这里插入图片描述
  • 何凯明作者在他的另外一个工作中又对不同的残差单元做了细致的分析与实验,这里直接抛出最优的残差结构,如下图所示。改进前后一个明显的变化是采用pre-activation,BN和ReLU都提前了。而且作者推荐短路连接采用恒等变换,这样保证短路连接不会有阻碍。[1603.05027] Identity Mappings in Deep Residual Networks (arxiv.org)

    • 在这里插入图片描述

    • 改进前后一个明显的变化是采用pre-activation,BN和ReLU都提前了。而且作者推荐短路连接采用恒等变换,这样保证短路连接不会有阻碍。

使用pytorch复现resnet模型

  • 导包,查看设备信息

  • import time
    import torch
    from torch import nn, optim
    import torch.nn.functional as F
    import sys
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(torch.__version__)
    print(device)
    
  • 1.13.1
    cpu
    
  • 模块构建

  • class Residual(nn.Module):  def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):super(Residual, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)else:self.conv3 = None  # 不做任何操作self.bn1 = nn.BatchNorm2d(out_channels)self.bn2 = nn.BatchNorm2d(out_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:  # 残差部分,有的需要经过1*1的卷积,有的直接连接过来不做任何操作X = self.conv3(X)return F.relu(Y + X)
    blk = Residual(3, 3)
    X = torch.rand((4, 3, 6, 6))
    print(blk(X).shape)
    blk = Residual(3, 6, use_1x1conv=True, stride=2)
    print(blk(X).shape)
    
  • torch.Size([4, 3, 6, 6])
    torch.Size([4, 6, 3, 3])
    
  • 模型搭建

  • net = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), # stemnn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
    class GlobalAvgPool2d(nn.Module):# 全局平均池化层可通过将池化窗口形状设置成输入的高和宽实现def __init__(self):super(GlobalAvgPool2d, self).__init__()def forward(self, x):return F.avg_pool2d(x, kernel_size=x.size()[2:])
    class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):if first_block:assert in_channels == out_channels # 第一个模块的通道数同输入通道数一致blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))else:blk.append(Residual(out_channels, out_channels))return nn.Sequential(*blk)
    net.add_module("resnet_block1", resnet_block(64, 64, 2, first_block=True)) # 构建网络的另一种方式
    net.add_module("resnet_block2", resnet_block(64, 128, 2))
    net.add_module("resnet_block3", resnet_block(128, 256, 2))
    net.add_module("resnet_block4", resnet_block(256, 512, 2))
    net.add_module("global_avg_pool", GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, 512, 1, 1)
    net.add_module("fc", nn.Sequential(FlattenLayer(), nn.Linear(512, 10))) 
    X = torch.rand((1, 1, 224, 224))
    for name, layer in net.named_children():X = layer(X)print(name, ' output shape:\t', X.shape)
    
  • 0  output shape:     torch.Size([1, 64, 112, 112])
    1  output shape:     torch.Size([1, 64, 112, 112])
    2  output shape:     torch.Size([1, 64, 112, 112])
    3  output shape:     torch.Size([1, 64, 56, 56])
    resnet_block1  output shape:     torch.Size([1, 64, 56, 56])
    resnet_block2  output shape:     torch.Size([1, 128, 28, 28])
    resnet_block3  output shape:     torch.Size([1, 256, 14, 14])
    resnet_block4  output shape:     torch.Size([1, 512, 7, 7])
    global_avg_pool  output shape:     torch.Size([1, 512, 1, 1])
    fc  output shape:     torch.Size([1, 10])
    
  • 数据加载及模型训练

  • import torchvision
    def evaluate_accuracy(data_iter, net, device=None):if device is None and isinstance(net, torch.nn.Module):# 如果没指定device就使用net的devicedevice = list(net.parameters())[0].device acc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式else: if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n
    def train_mnist(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
    def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
    batch_size = 256
    # 如出现“out of memory”的报错信息,可减小batch_size或resize
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=96)
    lr, num_epochs = 0.001, 5
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    train_mnist(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
    
  • training on  cpu
    epoch 1, loss 0.4161, train acc 0.847, test acc 0.845, time 1232.7 sec
    epoch 2, loss 0.2490, train acc 0.908, test acc 0.904, time 1202.3 sec
    epoch 3, loss 0.2074, train acc 0.924, test acc 0.905, time 1212.4 sec
    epoch 4, loss 0.1831, train acc 0.932, test acc 0.915, time 1142.8 sec
    epoch 5, loss 0.1570, train acc 0.942, test acc 0.876, time 1127.6 sec
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_283888.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python玄阶斗技--tkinter事件

在前一篇文章中,我们已经了解是tkinter的一些标签的使用,但一个GUI程序除了让别人看到,还要有一些交互操作,实现人机交互的方法我们称为事件,通过事件分为:鼠标事件,键盘事件和窗口事件。接下来…

Neo4j初学者使用记录(在更)

打开Neo4j cmdR 输入neo4j console 浏览器中输入框中网址:http://localhost:7474/即可打开 新建库 服务器版需要更改配置文件,若neo4j服务正在运行,则按Ctrlc,停止该服务。 配置完后,再重新开启服务,刷新…

如何利用ventoy制作Linux to go (把deepin放到U盘里)

准备工作 最新版本 – 深度科技社区 (deepin.org) deepin镜像官方下载即可 Releases ventoy/vtoyboot GitHub ventoy启动插件选择1.0.29版本 Downloads – Oracle VM VirtualBox VirtualBox虚拟机官网 ventoy下载 VentoyRelease (lanzoui.com) 选择下载1.0.29版本 vento…

第五十八章 线段树(一)

第五十八章 线段树(一)一、树状数组的缺陷二、线段树的作用三、线段树的基本构成1、节点定义2、线段树的结构四、线段树的重要函数1、构造线段树——bulid函数2、查询区间——query函数3、单点修改——modify函数五、例题一、树状数组的缺陷 在前面两个…

对于电商行业来讲,真正决定它的并不是规模,而是载体

纵然是在现在这样的情况之下,我们依然无法用「格局已定」来形容和阐述现在的电商市场格局。这一点,我们可以从以抖音、快手为代表的电商新势力的崛起当中,看出一丝端倪。对于电商行业来讲,真正决定它的并不是规模,而是…

Dart中的异步

一 事件循环 flutter 就是运行在一个root isolate 中 程序只要运行起来,就有一个事件循环一直在运行 ,直至程序退出。 EventLoop 先从mrcro 对列中取任务,取完任务再去 event 队列中取任务。队列任务是FIFO。 二 认识Future abstract clas…

[JavaEE]----Spring03

文章目录Spring_day031,AOP简介1.1 什么是AOP?1.2 AOP作用1.3 AOP核心概念2,AOP入门案例2.1 需求分析2.2 思路分析2.3 环境准备2.4 AOP实现步骤步骤1:添加依赖步骤2:定义接口与实现类步骤3:定义通知类和通知步骤4:定义切入点步骤5:制作切面步骤6:将通知…

C++内存管理(new和delete)

目录 1. new/delete操作内置类型 2. new和delete操作自定义类型 3. operator new与operator delete函数 4 .new和delete的实现原理 1 .内置类型 2 .自定义类型 new的原理 delete的原理 new T[N]的原理 delete[]的原理 5. 定位new表达式(placement-new) 6. malloc/f…

使用Process Explorer和Clumsy定位软件高CPU占用问题

目录 1、问题描述 2、使用Process Explorer初步找到CPU占用高的原因 3、使用Clumsy工具在公司内网环境复现了问题 4、根据Process Explorer中的函数调用堆栈,分析源码,最终找出了问题 5、总结 在排查项目客户的视频图像闪烁问题时,无意中…

Centos7安装部署Jenkins

Jenkins简介: Jenkins只是一个平台,真正运作的都是插件。这就是jenkins流行的原因,因为jenkins什么插件都有 Hudson是Jenkins的前身,是基于Java开发的一种持续集成工具,用于监控程序重复的工作,Hudson后来被…

JavaScript基础-02

常量(字面量):数字和字符串 常量也称之为“字面量”,是固定值,不可改变。看见什么,它就是什么。 常量有下面这几种: 数字常量(数值常量)字符串常量布尔常量自定义常量…

【MATLAB数学建模编程实战】Kmeans算法编程及算法的简单原理

欢迎关注,本专栏主要更新MATLAB仿真、界面、基础编程、画图、算法、矩阵处理等操作,拥有丰富的实例练习代码,欢迎订阅该专栏!(等该专栏建设成熟后将开始收费,快快上车吧~~) 【MATLAB数学建模编…

[LeetCode周赛复盘] 第 340 场周赛20230409

[LeetCode周赛复盘] 第 340 场周赛20230409 一、本周周赛总结二、 6361. 对角线上的质数1. 题目描述2. 思路分析3. 代码实现三、6360. 等值距离和1. 题目描述2. 思路分析3. 代码实现四、6359. 最小化数对的最大差值1. 题目描述2. 思路分析3. 代码实现五、 6353. 网格图中最少访…

ROS实践06 自定义消息类型

文章目录运行环境:思路:1.1 定义.msg文件1)功能包下新建 msg 目录,添加文件 Person.msg2)修改package.xml3)修改CMakeLists.txt2.1 自定义消息调用(C)1)编译后修改includePath2)发布方实现2.1修改CMakeLists.txt2.3运行…

【OpenCV-Python】cvui 之 trackbar

CVUI 之 trackbar cvui::trackbar() 渲染一个 trackbar, 可以左右拖动或点击对数字进行增加或减少的调整。 不使用离散间隔 使用离散间隔 Python import numpy as np import cv2 import cvuidef trackbar_test():WINDOW_NAME Trackbar-Test# 创建画布frame np.z…

【Python童年游戏】满满的回忆杀—那些年玩过的童年游戏你还记得吗?那个才是你的菜?看到第一个我就泪奔了(致我们逝去的青春)

导语 滴一一学生卡🙌 结伴上车的学生仔子们 用笑声打破车厢的沉默 大人眼里的晚高峰 是给放学后快乐😀时光的加时 下车的学生匆匆起身带起 一阵熟悉的栀子香于💓 是关于校园的记忆 开始零零散散地闪现 放学后集合的秘密基地/跟着城…

LVGL v8学习笔记 |12 - 移植LVGL 8.3到ESP32C3开发板(ST7789)

一、移植前的准备 1. 基础工程 ESP32-IDF开发笔记 | 03 - 使用SPI外设驱动ST7789 SPILCD2. lvgl源码 https://github.com/lvgl/lvgl下载最新发布的 8.3.6 版本:https://github.com/lvgl/lvgl/releases/ 二、移植lvgl 1. 复制lvgl源码到工程中 将下载的 lvgl-8.3.6 文件夹直…

JSON数据遍历之for-in

JSON数据遍历之for-in object 本身就是无对象的集合,因此在用 for-in 语句遍历对象的属性时,遍历出的属性顺序与对象定义时不同。 W3C标准 根据 ECMA-262(ECMAScript)第三版中描述,for-in 语句的属性遍历的顺序是由对…

KlayGE-001-简介

KlayGE 引擎学习-001 [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-lWlSlet9-1680688988724)(images/KlayGE_logo.png)] 一、KlayGE引擎介绍 软件简介 KlayGE中文译为:粘土游戏引擎,是一个开源、跨平台,基于插…

6-MATLAB APP Design-表格组件(uitable)

此博文通过MATLAB APP Design实现对学生成绩的处理,具体的功能包括读取表格数据、添加学生数据、计算总成绩、成绩排序+以及表格的保存。 一、APP 界面设计展示 1. 在画布中拖入面板、表格和四个按钮,布局如下。将面板的title写为“学生成绩计算器”并居中,将四个按钮的t…