FSP:Flow of Solution Procedure (CVPR 2017) 原理与代码解析

news/2024/4/19 16:45:39/文章来源:https://blog.csdn.net/ooooocj/article/details/129250325

paper:A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learning

code:https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/FSP.py

背景

深度神经网络DNN逐层生成特征。更高层的特征更接近于任务的有用特征。如果我们把DNN的输入看作问题,把输出看作答案,我们就可以把DNN中间生成的特征看作是求解过程中的中间结果。根据这一想法,FitNets可以让学生网络简单地模拟教师网络的中间结果。然而在DNN中,有许多方法或途径来解决从输入生成输出的问题。因此,模拟教师网络生成的特征对学生网络来说是一个硬约束hard constraint。就人而言,老师解释问题的解决过程,学生学习解决问题的流程。当输入特定的问题时,学生网络不一定需要学习中间输出,但当遇到特定类型的问题时,学生网络可以学习这一类问题的通用解决方法。因此作者认为,对于知识蒸馏中的教师网络,演示问题的解决过程比演示中间结果具有更好的泛化性

本文的创新点

本文将神经网络中层与层之间的信息流动定义为需要蒸馏的知识,并通过计算两个特征层之间的内积来得到这种知识。当将这种层之间的流动作为知识传递给学生网络时,作者通过实验得到了三个结论:

  1. 从教师网络学习这种蒸馏知识的学生网络比原始网络的优化(收敛)速度快得多。

  1. 学习这种蒸馏知识的学生网络比原始网络的性能更好。

  1. 即使教师网络是在一个不同的任务或数据集上训练得到的,学生网络也可以从教师网络中学习到这种知识,并且比从头训练的效果更好。

下图是本文提出的知识蒸馏方法的概念图

本文的贡献如下:

  1. 提出了一种知识蒸馏的新方法。

  1. 这种知识对于快速优化非常有用。

  1. 利用所提出的蒸馏知识定义网络的初始权重可以提高小模型的性能。

  1. 即使学生网络接受了与教师网络不同的训练任务,所提出的蒸馏知识也能提高学生网络的表现。

方法介绍

作者设计了网络中两个相邻层之间的FSP(flow of solution procedure)矩阵来表示问题的求解过程,对于挑选的层1输出的feature map表示为 \(F^{1}\in \mathbb{R}^{h\times w\times m}\),其中 \(h,w,m\) 分别表示特征图的高、宽、通道数。层2表示为 \(F^{2}\in \mathbb{R}^{h\times w\times n}\),则FSP矩阵 \(G\in \mathbb{R}^{m\times n}\) 可通过下式求得

其中 \(x\) 表示输入图片,\(W\) 表示网络权重参数。

对于残差网络,网络在一些位置的spatial size发生变化,我们选择教师网络和学生网络对应位置具有相同spatial size的特征图来生成FSP matrix,下图是一个示例

计算教师网络和学生网络对应FSP矩阵的L2损失,完整是损失函数如下

其中 \(\lambda_{i}\) 表示每一对FSP矩阵损失的权重,文中设定所有层计算的FSP之间的损失权重相等。\(N\) 表示所有的采样点。

代码解析

forward函数的输入g_sg_t分别表示学生网络和教师网络中所有用来计算FSP矩阵的层,在compute_fsp中每一层都与相邻层计算fsp矩阵,注意这里的相邻并不是说在原始网络中这两层的相邻的。这里相邻层之间计算fsp矩阵需要保证spatial size相等,如果不相等通过自适应平均池化使之相等。

from __future__ import print_functionimport numpy as np
import torch.nn as nn
import torch.nn.functional as Fclass FSP(nn.Module):"""A Gift from Knowledge Distillation:Fast Optimization, Network Minimization and Transfer Learning"""def __init__(self, s_shapes, t_shapes):super(FSP, self).__init__()assert len(s_shapes) == len(t_shapes), 'unequal length of feat list's_c = [s[1] for s in s_shapes]t_c = [t[1] for t in t_shapes]if np.any(np.asarray(s_c) != np.asarray(t_c)):raise ValueError('num of channels not equal (error in FSP)')def forward(self, g_s, g_t):# [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]# [(64,32,32,32),(64,64,32,32),(64,128,16,16),(64,256,8,8)]s_fsp = self.compute_fsp(g_s)t_fsp = self.compute_fsp(g_t)loss_group = [self.compute_loss(s, t) for s, t in zip(s_fsp, t_fsp)]return loss_group@staticmethoddef compute_loss(s, t):return (s - t).pow(2).mean()@staticmethoddef compute_fsp(g):fsp_list = []for i in range(len(g) - 1):bot, top = g[i], g[i + 1]  # (64,32,32,32),(64,64,32,32)b_H, t_H = bot.shape[2], top.shape[2]if b_H > t_H:bot = F.adaptive_avg_pool2d(bot, (t_H, t_H))elif b_H < t_H:top = F.adaptive_avg_pool2d(top, (b_H, b_H))else:passbot = bot.unsqueeze(1)  # (64,1,32,32,32)top = top.unsqueeze(2)  # (64,64,1,32,32)bot = bot.view(bot.shape[0], bot.shape[1], bot.shape[2], -1)  # (64,1,32,1024)top = top.view(top.shape[0], top.shape[1], top.shape[2], -1)  # (64,64,1,1024)fsp = (bot * top).mean(-1)  # (64,64,32,1024)->(64,64,32)fsp_list.append(fsp)return fsp_list

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_75072.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

决策树在sklearn中的实现

目录 一.模块sklearn.tree 二.建模基本流程 三.DecisionTreeClassifier重要参数 1.criterion 2.random_state & splitter 3.剪枝参数max_depth 4.剪枝参数min_samples_leaf & min_samples_split 5.max_features & min_impurity_decrease 6.class_weight …

Python IDE:对于 Python 初学者来说,最好的 IDE 是什么?

Python 是科技界最简单、使用最广泛的编程语言之一。它是一种高级通用编程语言&#xff0c;强调代码可读性并使用面向对象的方法。Python可以用来完成很多任务&#xff0c;包括网站开发、软件开发、 自动化 和数据分析 专业开发人员使用Python开发各种流行的软件程序&#xff0…

深入理解Spring MVC上

Spring MVC 是一种基于 Spring 框架的 Web 框架&#xff0c;它提供了一种基于 Model-View-Controller&#xff08;MVC&#xff09;的设计模式&#xff0c;用于构建 Web 应用程序。在 Spring MVC 中&#xff0c;Controller 接受并处理 HTTP 请求&#xff0c;并将其转发给适当的 …

多表left join 慢sql问题

作为个人记录&#xff0c;后续再填坑a对p是1对多 ,p对llup 1对多SELECTa.id,p.id,t1.id FROMliv_series_product aINNER JOIN liv_product p ON p.id a.product_idLEFT JOIN ( SELECT llup.id, llup.product_id, llup.room_id FROM liv_live_user_product llup WHERE llup.ro…

Tomcat部署及多实例

Tomcat部署及多实例一、Tomcat简介1、Tomcat核心组件2、什么是JSP二、Tomcat数据流向1、Tomcat数据流向2、Tomcat-Nginx数据流向三、Tomcat服务部署安装1、安装jdk包2、解压Tomcat所需的安装包3、在/etc/profile添加环境变量4、启动服务并查看5、在浏览器网页验证6、创建用户&a…

为什么硬件性能监控很重要

当今的混合网络环境平衡了分布式网络和现代技术的实施。但它们并不缺少一个核心组件&#xff1a;服务器。保持网络正常运行时间归结为监控和管理导致网络停机的因素。极有可能导致性能异常的此类因素之一是硬件。使用硬件监控器监控网络硬件已成为一项关键需求。 硬件监视器是…

优化知识管理方法丨整理零碎信息,提高数据价值

信息流时代&#xff0c;知识成集合倍数增长&#xff0c;看似我们学习了很多知识&#xff0c;但知识零碎无系统&#xff0c;知识之间缺乏联系&#xff0c;没有深度&#xff0c;所以虽然你很努力&#xff0c;但你发现自己的能力增长特别缓慢&#xff0c;你需要整理知识将零散的知…

蓝桥杯:染色时间

蓝桥杯&#xff1a;染色时间https://www.lanqiao.cn/problems/2386/learning/?contest_id80 问题描述 输入格式 输出格式 样例输入输出 样例输入 样例输出 评测用例规模与约定 解题思路&#xff1a;优先队列 AC代码(Java)&#xff1a; 问题描述 小蓝有一个 n 行 m 列…

std::chrono笔记

文章目录1. radio原型作用示例2. duration原型&#xff1a;作用示例3. time_point原型作用示例4. clockssystem_clock示例steady_clock示例high_resolution_clock先说感觉&#xff0c;这个库真恶心&#xff0c;刚接触感觉跟shi一样&#xff0c;特别是那个命名空间&#xff0c;太…

vue2 diff算法

diff是什么 diff 算法是一种通过同层的树节点进行比较的高效算法 其有两个特点&#xff1a; ♥比较只会在同层级进行, 不会跨层级比较 ♥在diff比较的过程中&#xff0c;循环从两边向中间比较 diff 算法的在很多场景下都有应用&#xff0c;在 vue 中&#xff0c;作用于虚拟 dom…

预备2-CMD常用命令

CMD常用命令 先学简单常用的, 其余的要用在学 打开Cmd窗口 Win键R> 输入Cmd回车鼠标点击开始 > 附件>Cmd打开一个窗口,在地址栏输入cmd 操作目录 1.dir 查询当前目录有哪些文件 2.cd.. 上一级目录 3.cd e: 切换到E盘 4.d: 直接去d盘 5.cd /d e:abc 直接去E盘的abc目…

2023年房地产行业研究报告

第一章 行业发展概况 房地产业是指以土地和建筑物为经营对象&#xff0c;从事房地产开发、建设、经营、管理以及维修、装饰和服务的集多种经济活动为一体的综合性产业&#xff0c;是具有先导性、基础性、带动性和风险性的产业。主要包括&#xff1a;土地开发&#xff0c;房屋的…

解决AAC音频编码时间戳的计算问题

1.主题音频是流式数据&#xff0c;并不像视频一样有P帧和B帧的概念。就像砌墙一样&#xff0c;咔咔往上摞就行了。一般来说&#xff0c;AAC编码中生成文件这一步&#xff0c;如果使用的是OutputStream流写入文件的话&#xff0c;就完全不需要计算时间。但在音视频同步或者使用A…

debian 部署nginx https

我是flask 处理请求单进程&#xff0c; 差点意思 &#xff0c; 考虑先flask 在往下走 一&#xff1a;安装nginx 因为我是debian 系统&#xff0c;所以我的建议是直接 sudo apt-get install nginx 你也可以选择在官网下载&#xff0c; 但是我搭建ssl 的时候安装openssl非常的麻…

【无标题】(2019)NOC编程猫创新编程复赛小学组真题含参考

&#xff08;2019&#xff09;NOC编程猫创新编程复赛小学组最后6道大题。前10道是选择填空题 略。 这道题是绘图题&#xff0c;没什么难度&#xff0c;大家绘制这2个正十边形要注意&#xff1a;一是不要超出舞台&#xff1b;二是这2个正十边形不要相交。 这里就不给出具体程序了…

数睿通2.0数据服务功能模块发布

文章目录引言API 目录API 权限API 日志结语引言 数睿通 2.0 之前基本完成了数据集成和数据开发两大模块&#xff0c;也因此得到了一些朋友的帮助和支持&#xff0c;在此由衷的表示感谢&#xff0c;你们的支持便是我们更新的最大动力&#xff01; 目前&#xff0c;数据服务模块…

色环电阻的阻值如何识别

这种是色环电阻&#xff0c;其外表有一圈圈不同颜色的色环&#xff0c;现在在一些电器和电源电路中还有使用。下面的两种色环电阻它颜色还不一样&#xff0c;一个蓝色&#xff0c;一个土黄色&#xff0c;其实这个蓝色的属于金属膜色环电阻&#xff0c;外表涂的是一层金属膜&…

狂神说:面向对象(三)——多态

多态// 对象能执行什么方法&#xff0c;主要看对象左边的类型&#xff0c;和右边的没有关系多态&#xff1a;同一方法可以根据发送对象的不同而采用不同的行为方式父类&#xff1a;public class Person {public void run(){System.out.println("Person > run");}}…

【并发编程学习篇】深入理解CountDownLatch

一、CountDownLatch介绍 CountDownLatch&#xff08;闭锁&#xff09;是一个同步协助类&#xff0c;允许一个或多个线程等待&#xff0c;直到其他线程完成操作集。CountDownLatch使用给定的计数值&#xff08;count&#xff09;初始化。await方法会阻塞直到当前的计数值被coun…

只需四步,手把手教你打造专属数字人

伴随ChatGPT的问世&#xff0c;在技术与商业运作上都日渐发展成熟的数字人产业正持续升温。去年9月&#xff0c;北京市发布了国内首个数字人产业专项支持政策&#xff0c;提出将依托国家文化专网将数字人纳入文化数据服务平台。以数字人、ChatGPT为代表的互联网3.0创新应用产业…