【PyTorch][chapter 22][李宏毅深度学习]【无监督学习][ WGAN]【理论一】

news/2024/5/14 0:13:07/文章来源:https://blog.csdn.net/chengxf2/article/details/136423587

简介:

         2014年Ian Goodfellow提出以来,GAN就存在着训练困难、生成器和判别器的loss无法指示训练进程、生成样本缺乏多样性等问题。从那时起,很多论文都在尝试解决,但是效果不尽人意,比如最有名的一个改进DCGAN依靠的是对判别器和生成器的架构进行实验枚举,最终找到一组比较好的网络架构设置,但是实际上是治标不治本,没有彻底解决问题。

         2017年Martin Arjovsky提出了 Wasserstein GAN(下面简称WGAN)成功解决了GAN的系列问题:

         1  彻底解决GAN训练不稳定的问题,不再需要小心平衡生成器和判别器的训练程度
         2  基本解决了collapse mode的问题,确保了生成样本的多样性
        3   训练过程中终于有一个像交叉熵、准确率这样的数值来指示训练的进程,这个数值越小代                表GAN训练得越好,代表生成器产生的图像质量越高

    WGAN 里面比较难以理解的是EMD和WGAN 损失函数的关系,以及理论证明

这个会在理论二单独证明.

 目录

   1 GAN 

   2 WGAN 

   3 WGAN 理论


一    GAN 问题

        GAN  使用了BCE loss, 根据生成器的损失函数差异: 有两种方案

         2.1   log(1-D(x))

              总体损失函数:

              V(G,D)= log_{x\sim p_r}logD(x)+log_{x\sim p_g}log(1-D(x))

             生成器G的优化目标:

              V(G,D^{*})= 2JS(p_r||p_g)-2log2

              问题

              P_r,P_g 不重叠的时候, JS(p_r||p_g)=log2,为常数,梯度为0,无法优化生成器G.                       备注:

            P_rP_g的支撑集(support)是高维空间中的低维流形(manifold)时,P_rP_g重叠部分测度(measure)为0的概率为1。

              P_r是由低维的流形(随机噪声输入)映射到高维空间的

            

  •    支撑集(support)其实就是函数的非零部分子集,比如ReLU函数的支撑集就是(0,+∞),一个概率分布的支撑集就是所有概率密度非零部分的集合。
  • 流形(manifold)是高维空间中曲线、曲面概念的拓广,我们可以在低维上直观理解这个概念,比如我们说三维空间中的一个曲面是一个二维流形,因为它的本质维度(intrinsic dimension)只有2,一个点在这个二维流形上移动只有两个方向的自由度。同理,三维空间或者二维空间中的一条曲线都是一个一维流形。
  • 测度(measure)是高维空间中长度、面积、体积概念的拓广,可以理解为“超体积

         

        2.2    -logD(x)

          V(G,D)= log_{x\sim p_r}logD(x)-log_{x\sim p_g}log(D(x))

             生成器G的优化目标:

             V(G,D^{*})= KL(p_g||p_r)-2JS(p_r||p_g)

             这个等价最小化目标存在两个严重的问题。第一是它同时要最小化生成分布与真实分布的KL散度,却又要最大化两者的JS散度,一个要拉近,一个却要推远!在数值上则会导致梯度不稳定,这是后面那个JS散度项的毛病。 这就是mode collapse 问题

       


二 WGAN

   2.1 模型

    生成器G

              输入

                     随机噪声z

              输出:

                      x_f=G(z)

    鉴别器D:

               输入:

                    一种是采样自训练集中的

2.2 损失函数

     

       ,

       GAN 输出的是一个0,1之间的概率

       WGAN 输出的一个scalar

     对于鉴别器D:

                输入训练图片x_r,希望C(x_r)尽量大

                输入生成图片x_f,希望C(x_f)尽量小,

2.3 跟GAN区别

  • 1  判别器最后一层去掉sigmoid
  • 2  生成器和判别器的loss不取log
  • 训练技巧:
  • 1  每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
  • 2 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行

 2.4   训练方法

\


三  理论分析

  WGAN 的理论基础是Wasserstein  Distance,也叫推土机距离, Kantorovich-Rubinstein 算法.

本篇我们重点讲两个

  1: 什么是 Wasserstein Distance

  2:    Wasserstein  Distance 怎么变换到WGAN 的损失函数

3.1 Wasserstein  Distance

        Wasserstein Distance也称为推土机距离(Earth Mover’s distance, EMD),Wasserstein Distance的定义是评估由P分布转换成Q分布所需要的最小代价(移动的平均距离的最小值).

代价取决于两个因素: 移动的距离以及移动大小

      定义:

       EMD(P,Q)=inf_{r \in \prod}\sum_{x,y}||x-y||r(x,y)

                                =inf E_{(x,y)\sim r}||x-y||

       我们下面举个例子,帮助理解该函数:

      

3.2 例子

     如上图:

      生成图像的概率密度函数是P(x)

       真实图像的概率密度函数是Q(y)

     传输方案一(Transport plan 1 标记 \pi_1

       我们把P(X=6)=0.5的概率 : 

                  移动0.1 至Q(Y=7)     移动0.4 至 Q(Y=9)

      我们把P(X=10)=0.5的概率 : 

                   移动0.1 至Q(Y=7)  ,   移动0.4 至 Q(Y=9)

    我们可以用矩阵表示

    其EMD 距离为

 EMD(P,Q)=|6-7|*0.1+|6-9|*0.4+|10-7|*0.1+|10-9|*0.4

                          =0.1+1.2+0.3+0.4

                          =2.0

      传输方案二(Transport plan 2 标记 \pi_2

       我们把P(X=6)=0.5的概率 : 

                  移动0.2 至Q(Y=7)     移动0.3 至 Q(Y=9)

      我们把P(X=10)=0.5的概率 : 

                 移动0.5 至 Q(Y=9)

    我们可以用矩阵表示

EMD(P,Q)=|6-7|*0.2+|6-9|*0.3+|10-7|*0.0+|10-9|*0.5

                         =0.2+0.9+0.5

                        =1.6

因为方案2 EMD_{\pi_2}<EMD_{\pi_1}, 所以\pi_2更好,实际上我们真实图像的分布复杂度远远比

该分布复杂,对应的传输方案也有无数种,深度学习的目标就是借助神经网络,

从该方案中找到最优的方案

参考


1  WAGAN 论文: https://arxiv.org/pdf/1701.07875.pdf

2  WGAN基本原理及Pytorch实现WGAN-CSDN博客

3   earth-movers-distance_哔哩哔哩_bilibili

4  https://www.youtube.com/watch?v=3JP-xuBJsyc&t=2644s

5     令人拍案叫绝的Wasserstein GAN - 知乎

6  Introduction to the Wasserstein distance_哔哩哔哩_bilibili

7  https://www.youtube.com/watch?v=xs9uibPODGk

8   Wasserstein GAN and the Kantorovich-Rubinstein Duality - Vincent Herrmann

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_999000.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【嵌入式高级C语言】9:万能型链表懒人手册

文章目录 序言单向不循环链表拼图框架搭建 - Necessary功能拼图块1 创建链表头信息结构体 - Necessary2 链表头部插入 - Optional3 链表的遍历 - Optional4 链表的销毁 - Necessary5 链表头信息结构体销毁 - Necessary6 获取链表中节点的个数 - Optional7 链表尾部插入 - Optio…

RFID-科技的“隐秘耳语者”

RFID-科技的“隐秘耳语者” 想象一下&#xff0c;你身处一个光线昏暗的环境中&#xff0c;周围的一切都被厚厚的阴影笼罩。这时&#xff0c;你需要识别并获取一个物体的信息&#xff0c;你会选择怎么做&#xff1f;是点亮灯光&#xff0c;用肉眼仔细观察&#xff0c;还是打开扫…

Haproxy实验搭建

1.yum本地源安装Haproxy [rootlocalhost ~]# systemctl stop firewalld [rootlocalhost ~]# setenforce 0 [rootlocalhost ~]# yum install haproxy -y2.yum网络源安装Haproxy 关闭防火墙和selinux ###先把安装包拖进来[rootlocalhost ~]# yum install rh-haproxy18-haproxy-1…

Linux操作系统的vim常用命令和vim 键盘图

在vi编辑器的命令模式下&#xff0c;命令的组成格式是&#xff1a;nnc。其中&#xff0c;字符c是命令&#xff0c;nn是整数值&#xff0c;它表示该命令将重复执行nn次&#xff0c;如果不给出重复次数的nn值&#xff0c;则命令将只执行一次。例如&#xff0c;在命令模式下按j键表…

[java基础揉碎]继承

为什么需要继承: > 继承就可以解决代码复用的问题 继承的基本介绍: 继承的使用细节: 1.子类继承了所有的属性和方法&#xff0c;但是私有属性和方法不能在子类直接访问&#xff0c;要通过公共的方法去访问 解决, 提供公共的方法返回: 2.子类必须调用父类的构造器,完成父…

Linux系统——LVS、Nginx、HAproxy区别

目录 一、LVS 1.负载均衡机制 1.1负载均衡——NAT模式 1.2负载均衡——DR模式 1.3负载均衡——隧道模式 1.4负载均衡——总结 2.LVS调度算法 3.LVS优点 4.LVS缺点 二、Nginx 1.传统基于进程或线程的模型 2.Nginx架构设计 3.Nginx负载均衡 4.Nginx调度算法 5.Ngi…

Jmeter 测试使用基本组件结构

JMeter简介 Apache组织开发的开源免费压测工具纯Java程序&#xff0c;跨平台性强源程序可以从网上下载高扩展性可对服务器、网络或对象模拟巨大的负载&#xff0c;进行压力测试可以用于接口测试支持分布式、多节点部署 JMeter安装 下载位置 官网https://jmeter.apache.org/ …

Java中SpringBoot四大核心组件是什么

一、Spring Boot Starter 1.1 Starter的应用示例 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</artifactId> </dependency> <dependency><groupId>org.mybatis.sprin…

深度学习与人类的智能交互:迈向自然与高效的人机新纪元

引言 随着科技的飞速发展&#xff0c;深度学习作为人工智能领域的一颗璀璨明珠&#xff0c;正日益展现出其在模拟人类认知和感知过程中的强大能力。本文旨在探讨深度学习如何日益逼近人类智能的边界&#xff0c;并通过模拟人类的感知系统&#xff0c;使机器能更深入地理解和解…

高级语言讲义2016计专(仅高级语言部分)

1.斐波那契序列的第n项可以表示成以下形式&#xff0c;编写一个非递归函数&#xff0c;返回该数列的第n项的数值 #include <stdio.h>int func(int n) {if(n1||n2)return 1;int p1,q1,num;for(int i3; i<n; i) {numpq;qp;pnum;}return num; } 2.在MXN的二维数组A中&am…

瑞_23种设计模式_模板方法模式

文章目录 1 模板方法模式&#xff08;Template Pattern&#xff09; ★ 钩子函数1.1 介绍1.2 概述1.3 模板方法模式的结构1.4 模板方法模式的优缺点1.5 模板方法模式的使用场景 2 案例一2.1 需求2.2 代码实现 3 案例二3.1 需求3.2 代码实现 4 JDK源码解析&#xff08;InputStre…

【测试工具系列】压测用Jmeter还是LoadRunner?还是其他?

说起JMeter&#xff0c;估计很多测试人员都耳熟能详。它小巧、开源&#xff0c;还能支持多种协议的接口和性能测试&#xff0c;所以在测试圈儿里很受欢迎&#xff0c;也是测试人员常用的工具&#xff0c;但是在企业级性能场景下可能会有性能瓶颈&#xff0c;更适合测试自己使用…

Grafana二次开发环境搭建

1 Grafana环境搭建 1.1 搭建后端服务 下载windows安装版文件grafana.com 1&#xff09;选择版本号&#xff1a;此处我选的8.3.3版本 2&#xff09;安装完成后&#xff0c;请记住安装目录 &#xff0c;我的是在 D:\software\Gragana833 安装完成后会自动运行, 3&#xff09;此…

2024年软考重大改革

中国计算机技术职业资格网 考试日期 考试级别 考试资格名称 5月25日至28日 高级 系统分析师 系统架构设计师 信息系统项目管理师 中级 软件设计师 网络工程师 软件评测师 电子商务设计师 嵌入式系统设计师 数据库系统工程师 信息系统管理工程师 初级 程序员 …

MySQL之体系结构和基础管理

前言 本文以linux系统的MySQL为例详细介绍MySQL的体系结构&#xff0c;因为在实际生产环境中MySQL的运行环境都是linux系统。同时介绍MySQL的基础管理&#xff0c;包括用户管理和权限管理等。 MySQL体系结构 MySQL客户端/服务器工作模型 MySQL是C/S架构&#xff0c;工作模型…

物联网,智慧城市的数字化转型引擎

随着科技的飞速发展&#xff0c;物联网&#xff08;IoT&#xff09;已成为推动智慧城市建设的关键力量。物联网技术通过连接各种设备和系统&#xff0c;实现数据的实时采集、传输和处理&#xff0c;为城市的智能化管理提供了强大的支持。在数字化转型的浪潮中&#xff0c;物联网…

VUE_nuxt启动只能通过localhost访问,ip访问不到:问题解决

修改项目根目录下的 package.json "config": {"nuxt": {"host": "0.0.0.0","port": "3000"} } 这样项目启动后就可以通过ip进行访问了

AutoDev 自定义 Agent:快速接入内部 AI Agent,构建 IDE 即 AI 辅助研发中心

在开源 AI IDE 插件 AutoDev 的 #51 issue 中&#xff0c;我们设计了 AutoDev 的 AI Agent 能力&#xff0c;半年后我们终于交付了这个功能。 在 AutoDev 1.7.0 中&#xff0c;你将可以接入内部的 AI Agent&#xff0c;并将其无缝与现有的 AI 辅助能力结合在一起。 本文将使用 …

使用awk和正则表达式过滤文本或字符串 - 详细指南和示例

当我们在 Linux 中运行某些命令来读取或编辑字符串或文件中的文本时&#xff0c;我们经常尝试将输出过滤到感兴趣的特定部分。这就是使用正则表达式派上用场的地方。 什么是正则表达式&#xff1f; 正则表达式可以定义为表示多个字符序列的字符串。关于正则表达式最重要的事情之…

【C++ STL详解】——string类

目录 前言 一、string类对象的常见构造 二、string类对象的访问及遍历 1.下标【】&#xff08;底层operator【】函数&#xff09; ​编辑 2.迭代器 3.范围for 4.at 5.back和front 三、string类对象的容量操作 1.size 和 length 2.capacity 3.empty 4.clear 5.res…