【模型复现】Network in Network,将1*1卷积引入网络设计,运用全局平均池化替代全连接层。模块化设计网络

news/2024/5/3 1:07:09/文章来源:https://blog.csdn.net/weixin_43424450/article/details/130021416
  • 《Network In Network》是一篇比较老的文章了(2014年ICLR的一篇paper),是当时比较厉害的一篇论文,同时在现在看来也是一篇非常经典并且影响深远的论文,后续很多创新都有这篇文章的影子。[1312.4400] Network In Network (arxiv.org)这篇文章采用较少参数就取得了Alexnet的效果,Alexnet参数大小为230M,而Network In Network仅为29M。

  • 卷积网络通常由卷积和池化交替堆叠,最后接全连接完成模型构建,卷积通过线性滤波器对应特征图位置相乘并求和,然后进行非线性激活得到特征图。线性模型足以抽象线性可分的隐含特征,但是实际上这些特征通常是高度非线性的,常规的卷积网络则可以通过采用一组超完备滤波器(尽可能多)提取统一潜在特征各种变体(宁可错杀一千不可放过一个),但是同一潜在特征使用太多的滤波器会给下一层带来额外的负担,需要考虑来自前一层的所有变化的组合,来自更高层的滤波器会映射到原始输入的更大区域,它通过结合下层的较低级概念生成较高级的特征,因此作者认为网络局部模块做出更好的特征抽象会更好,顺势引入Network in Network则能达到这个目标,在每个卷积层内引入一个微型网络,来计算和抽象每个局部块的特征

  • 论文Network in Network的网络结构中有由两处新的结构(当时),MLP Convolution Layers和Global Average Pooling。所谓MLPConv其实就是在常规卷积(感受野大于1的)后接若干1x1卷积,每个特征图视为一个神经元,特征图通过1x1卷积就类似多个神经元线性组合,这样就像是MLP(多层感知机)了,这是文章最大的创新点,也就是Network in Network(网络中内嵌微型网络)。径向基(Radial basis network)和 从多层感知机(multilayer perceptron)是两种通用的函数逼近器,作者选择了多层感知机,因为多层感知器与卷积神经网络的结构一样,都是通过反向传播训练。其次多层感知器本身就是一个深度模型,符合特征再利用的原则。NIN(Network in Network)学习笔记_nin函数_SyGoing的博客-CSDN博客

  • 普通卷积层(感受野大于1)及文中提到的GLM(generalized linear model)相当于单层网络,抽象能力有限。为了提高特征的抽象表达能力,作者用MLPConv代替了GLM。 n为网络层数,第一层为线性卷积层(卷积核尺寸大于1),后面的为1x1卷积。

    • 在这里插入图片描述

    • (a)fi,j,k=max(wkTxi,j,0)(b)fi,j,k11=max(wk11Txi,j+bk1,0)...fi,j,knn=max(wknnTfi,jn−1+bkn,0)(a)f_{i,j,k}=max(w^T_kx_{i,j},0)\\ (b)f_{i,j,k_1}^1=max({w_{k_1}^1}^Tx_{i,j}+b_{k_1},0)\\ ...\\ f_{i,j,k_n}^n=max({w_{k_n}^n}^Tf_{i,j}^{n-1}+b_{k_n},0)\\ (a)fi,j,k=max(wkTxi,j,0)(b)fi,j,k11=max(wk11Txi,j+bk1,0)...fi,j,knn=max(wknnTfi,jn1+bkn,0)

    • 1x1卷积作为NIN函数逼近器基本单元,除了增强了网络局部模块的抽象表达能力外,在现在看来还可以实现跨通道特征融合和通道升维降维。

  • 当时作者应该是第一个使用1x1卷积的,具有划时代的意义,之后的Googlenet借鉴了1*1卷积,还专门致谢过这篇论文,现在很多优秀的网络结构都离不开1x1卷积,ResNet、ResNext、SqueezeNet、MobileNetv1-3、ShuffleNetv1-2等等。

  • 传统卷积神经网络在网络的浅层进行卷积运算。对于分类任务,最后一个卷积层得到的特征图被向量化(flatten)然后送入全连接层,接一个softmax逻辑回归层。这种结构将卷积结构与传统神经网络分类器连接起来,卷积层作为特征提取器,得到的特征用传统神经网络进行分类。全连接层参数量是非常庞大的,模型通常会容易过拟合,针对这个问题,Hinton提出Dropout方法来提高泛化能力,但是全连接的计算量依旧很大。

  • 基于此,论文提出用全局平均池化代替全连接层,具体做法是对最后一层的特征图进行平均池化,得到的结果向量直接输入softmax层。这样做好处之一是使得特征图与分类任务直接关联,另一个优点是全局平均池化不需要优化额外的模型参数,因此模型大小和计算量较全连接大大减少,并且可以避免过拟合。

  • 知道NIN的基本单元,整体网络结构为Input+MLPConv+GAP+softmax,网络结构示意图如下:

    • 在这里插入图片描述
  • Network in Network对常规卷积网络的特征提取抽象表示进行改进,提出MLPconv,其实就是在常规卷积后接1x1卷积(首次使用1x1卷积),首次采用全局平均池化降低网络复杂度,避免过拟合,在之后的很多经典论文中都有用到,具有开创性意义;深度学习发展迅猛,论文很多,但是经典的还是少数,所以很值得学习,以前的ResNet,MobileNetv1-3,ShuffleNetv1-2等等。

pytorch复现NIN

  • 导包,查看配置信息

  • import time
    import torch
    import torchvision
    from torch import nn, optim
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(torch.__version__)
    print(device)
    
  • 1.13.1
    cpu
    
  • NIN模块及模型构建

  • def nin_block(in_channels, out_channels, kernel_size, stride, padding):blk = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1),nn.ReLU())return blk  # convnext模块和它好像
    class FlattenLayer(torch.nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
    net = nn.Sequential(nin_block(1, 96, kernel_size=11, stride=4, padding=0),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(96, 256, kernel_size=5, stride=1, padding=2),nn.MaxPool2d(kernel_size=3, stride=2),nin_block(256, 384, kernel_size=3, stride=1, padding=1),nn.MaxPool2d(kernel_size=3, stride=2), nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, stride=1, padding=1),# 全局平均池化层可通过将窗口形状设置成输入的高和宽实现nn.AvgPool2d(kernel_size=5),# 将四维的输出转成二维的输出,其形状为(批量大小, 10)FlattenLayer())
    X = torch.rand(1, 1, 224, 224)
    for name, blk in net.named_children(): X = blk(X)print(name, 'output shape: ', X.shape)
    
  • 0 output shape:  torch.Size([1, 96, 54, 54])
    1 output shape:  torch.Size([1, 96, 26, 26])
    2 output shape:  torch.Size([1, 256, 26, 26])
    3 output shape:  torch.Size([1, 256, 12, 12])
    4 output shape:  torch.Size([1, 384, 12, 12])
    5 output shape:  torch.Size([1, 384, 5, 5])
    6 output shape:  torch.Size([1, 384, 5, 5])
    7 output shape:  torch.Size([1, 10, 5, 5])
    8 output shape:  torch.Size([1, 10, 1, 1])
    9 output shape:  torch.Size([1, 10])
    
  • 获取数据和训练模型

  • import sys
    batch_size = 32
    # 如出现“out of memory”的报错信息,可减小batch_size或resize
    def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):"""Download the fashion mnist dataset and then load into memory."""trans = []if resize:trans.append(torchvision.transforms.Resize(size=resize))trans.append(torchvision.transforms.ToTensor())transform = torchvision.transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)if sys.platform.startswith('win'):num_workers = 0  # 0表示不用额外的进程来加速读取数据else:num_workers = 4train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)return train_iter, test_iter
    def evaluate_accuracy(data_iter, net, device=None):if device is None and isinstance(net, torch.nn.Module):# 如果没指定device就使用net的devicedevice = list(net.parameters())[0].device acc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式else: # if('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n
    def mytrain(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
    train_iter, test_iter = load_data_fashion_mnist(batch_size, resize=224)
    lr, num_epochs = 0.002, 5
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    mytrain(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)
    
  • training on  cpu
    epoch 1, loss 2.2931, train acc 0.106, test acc 0.100, time 2587.8 sec
    epoch 2, loss 2.3026, train acc 0.100, test acc 0.100, time 2413.8 sec
    epoch 3, loss 2.3026, train acc 0.100, test acc 0.100, time 2336.9 sec
    epoch 4, loss 2.3026, train acc 0.100, test acc 0.100, time 2333.4 sec
    epoch 5, loss 2.3026, train acc 0.100, test acc 0.100, time 2333.9 sec
    
  • 针对分类任务提出了一个新的深度网络NIN。这种新网络包括mlpconv层(使用MLP来进行卷积)以及全局平均池化层(取代FC层)。mlpconv层对局部块特征提取更好,全局平均池化可以作为正则化器来防止全局的过拟合。我们使用这种结构在几种数据集上取得了目前最好的效果。通过特征图的可视化,我们验证了最后一层mlpconv输出的特征图是类别的信度图,同时也提升了使用NIN进行目标检测的可能性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_283279.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯刷题冲刺 | 倒计时1天

作者:指针不指南吗 专栏:蓝桥杯倒计时冲刺 🐾蓝桥杯加油,大家一定可以🐾 文章目录我是菜菜,最近容易我犯的错误总结 一些tips 各位蓝桥杯加油加油 当输入输出数据不超过 1e6 时,scanf printf 和…

elasticsearch基础6——head插件安装和web页面查询操作使用、ik分词器

文章目录一、基本了解1.1 插件分类1.2 插件管理命令二、分析插件2.1 es中的分析插件2.1.1 官方核心分析插件2.1.2 社区提供分析插件2.2 API扩展插件三、Head 插件3.1 安装3.2 web页面使用3.2.1 概览页3.2.1.1 unassigned问题解决3.2.2 索引页3.2.3 数据浏览页3.2.4 基本查询页3…

微服务+springcloud+springcloud alibaba学习笔记(1/9)

1.微服务简介 什么是微服务呢? 就是将一个大的应用,拆分成多个小的模块,每个模块都有自己的功能和职责,每个模块可以 进行交互,这就是微服务 简而言之,微服务架构的风格,就是将单一程序开发成…

项目管理案例分析有哪些?

项目管控中遇到的问题有哪些?这些问题是如何解决的? 在项目管理领域,案例分析是一种常见的方法来学习和理解项目管理实践,下面就来介绍几个成功案例,希望能给大家带来一些参考。 1、第六空间:快速响应个性…

1669_MIT 6.828 xv6代码的获取以及编译启动

全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 6.828的学习的资料从开始基本信息的讲解,逐步往unix的一个特殊版本xv6过度了。这样,先得熟悉一下这个OS的基本代码以及环境。 在课程中其实…

最短路径算法及Python实现

最短路径问题 在图论中,最短路径问题是指在一个有向或无向的加权图中找到从一个起点到一个终点的最短路径。这个问题是计算机科学中的一个经典问题,也是许多实际问题的基础,例如路线规划、通信网络设计和交通流量优化等。在这个问题中&#…

Downloader工具配置参数并烧录到flash中

1 Downloader工具介绍 Downloader工具可以用来烧录固件到设备中,固件格式默认为*dcf。该工具还可以用来在线调试EQ或者进行系统设置。 2 配置参数 2.1 作用 当有一个dcf文件时,配合不同的配置文件*.setting,在不进行编译的情况下&#xff…

【毕业设计】ESP32通过MQTT协议连接服务器(二)

文章目录0 前期教程1 前言2 配置SSL证书3 配置用户名和密码4 配置客户端id(client_id)5 conf文件理解6 websocket配置7 其他资料0 前期教程 【毕业设计】ESP32通过MQTT协议连接服务器(一) 1 前言 上一篇教程简单讲述了怎么在虚拟…

【调试】ftrace(三)trace-cmd和kernelshark

之前使用ftrace的时候需要一系列的配置,使用起来有点繁琐,这里推荐一个ftrace的一个前端工具,它就是trace-cmd trace-cmd安装教程 安装trace-cmd及其依赖库 git clone https://git.kernel.org/pub/scm/libs/libtrace/libtraceevent.git/ c…

【Ruby学习笔记】19.Ruby 连接 Mysql - MySql2

Ruby 连接 Mysql - MySql2 前面一章节我们介绍了 Ruby DBI 的使用。这章节我们技术 Ruby 连接 Mysql 更高效的驱动 mysql2,目前也推荐使用这种方式连接 MySql。 安装 mysql2 驱动: gem install mysql2你需要使用 –with-mysql-config 配置 mysql_conf…

【DevOps】GitOps 初识(下) - 让DevOps变得更好

实践GitOps的五大难题 上一篇文章中,我们介绍了GitOps能为我们带来许多的好处,然而,任何新的探索都将不会是一帆风顺的。在开始之前,如果能了解实践GitOps通常会遇到的挑战,并对此作出合适的应对,可能会使…

数据结构和算法(一):复杂度、数组、链表、栈、队列

从广义上来讲:数据结构就是一组数据的存储结构 , 算法就是操作数据的方法 数据结构是为算法服务的,算法是要作用在特定的数据结构上的。 10个最常用的数据结构:数组、链表、栈、队列、散列表、二叉树、堆、跳表、图、Trie树 10…

StorageManagerService.java中的mVold.mount

android源码:android-11.0.0_r21(网址:Search (aospxref.com)) 一、问题 2243行mVold.mount执行的是哪个mount函数? 2239 private void mount(VolumeInfo vol) { 2240 try { 2241 // TOD…

【LeetCode】-- 108. 将有序数组转换为二叉搜索树

1. 题目 108. 将有序数组转换为二叉搜索树 - 力扣(LeetCode) 给你一个整数数组 nums ,其中元素已经按升序排列,请你将其转换为一棵高度平衡二叉搜索树。高度平衡二叉树是一棵满足「每个节点的左右两个子树的高度差的绝对值不超过 …

mysql在CentOS7.x环境安装

查看当前环境的yum源 ls -l /etc/yum.repos.d/ 可以看到当前环境是没有下载mysql对应的yum源的, 所以需要去官网下载对应的yum源. 找mysql的yum源并安装 http://repo.mysql.com/ 在选择对应yum源之前, 需要看一下自己系统的版本: 进入官网后, 鼠标右击进入查看页面源代码, 因为…

Leetcode.463 岛屿的周长

题目链接 Leetcode.463 岛屿的周长 easy 题目描述 给定一个 row x col的二维网格地图 grid,其中:grid[i][j] 1表示陆地, grid[i][j] 0表示水域。 网格中的格子 水平和垂直 方向相连(对角线方向不相连)。整个网格被…

如何从功能测试转型到自动化测试:我三年的学习经历

前言 在软件测试的领域里,自动化测试已经成为了不可或缺的一部分。 与传统的手工测试相比,自动化测试具有更高的效率和精确度,能够有效地减少测试时间和成本,同时提高测试质量。作为一个从事软件测试的人员,如果你想…

Oracle JDK 和 OpenJDK 有什么区别?

可能在看这个问题之前很多人和我一样并没有接触和使用过 OpenJDK 。那么 Oracle JDK 和 OpenJDK 之间是否存在重大差异?下面我通过收集到的一些资料,为你解答这个被很多人忽视的问题。 首先,2006 年 SUN 公司将 Java 开源,也就有…

智慧方政务云顶层设计与建设方案(ppt)

本资料来源公开网络,仅供个人学习,请勿商用,如有侵权请联系删除 对一网统管总体架构的理解物联网生态中的业务定位物联网产品与解决方案概览智联物联网管理平台总体方案智联物联网管理平台总体架构智联联连接平台(HLINK)应用架构智慧社区基于…

Linux--进程信号

前言 无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事情,而不是让烦恼和焦虑毁掉你不就不多的热情和定力。心可以碎,手不能停,该干什么干什么,在崩溃中继续努力前行&#xff0c…