BigGAN

news/2024/4/23 15:03:04/文章来源:https://blog.csdn.net/qq_41048761/article/details/129242472

1、BIGGAN 解读

1.1、作者

Andrew Brock、Jeff Donahue、Karen Simonyan

1.2、摘要

尽管最近在生成图像建模方面取得了进展,但从 ImageNet 等复杂数据集中 成功生成高分辨率、多样化的样本仍然是一个难以实现的目标。为此,我们以迄 今为止最大的规模训练生成对抗网络,并研究该规模特有的不稳定性。我们发现, 对生成器应用正交正则化使其易于使用简单的“截断技巧”,通过减少生成器输 入的方差,可以精细控制样本保真度和品种之间的权衡。我们的修改导致模型在 类条件图像合成中设置了新的技术状态。当以 128×128 分辨率在 ImageNet 上训 练时,BigGANs 的 IS 分数为 166.5,FID 分数为 7.4,比之前最好的 IS 为 52.52 和 FID 为 18.65 有所改进。

1.3、模型

GResidualBlock块代码如下:

class GResidualBlock(nn.Module):''' Implements a residual block in BigGAN's generator '''def __init__(self,c_dim: int,in_channels: int,out_channels: int,):super().__init__()self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))self.bn1 = ClassConditionalBatchNorm2d(c_dim, in_channels)self.bn2 = ClassConditionalBatchNorm2d(c_dim, out_channels)self.activation = nn.ReLU()self.upsample_fn = nn.Upsample(scale_factor=2)     # upsample occurs in every gblockself.mixin = (in_channels != out_channels)if self.mixin:self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))def forward(self, x, y):# x,y输入给BatchNormh = self.bn1(x, y) # BatchNormh = self.activation(h)# ReLUh = self.upsample_fn(h) # Upsampleh = self.conv1(h)# 3x3Conv# x卷积后成h,y输入给BatchNormh = self.bn2(h, y) # BatchNormh = self.activation(h)# ReLUh = self.conv2(h)# 3x3Conv# x输入给Upsamplex = self.upsample_fn(x)# Upsampleif self.mixin:x = self.conv_mixin(x)# 1x1Conv# 1x1卷积后的x + 经过两次3x3卷积后的xreturn h + x # add

Non-Local Block的代码如下:

# Self-Attention module == Non-Local block
class AttentionBlock(nn.Module):''' Implements a self-attention block from SA-GAN '''def __init__(self, channels: int):super().__init__()self.channels = channelsself.theta = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))self.phi = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))self.g = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 2, kernel_size=1, padding=0, bias=False))self.o = nn.utils.spectral_norm(nn.Conv2d(channels // 2, channels, kernel_size=1, padding=0, bias=False))self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)def forward(self, x):spatial_size = x.shape[2] * x.shape[3]# apply convolutions to get query (theta), key (phi), and value (g) transformstheta = self.theta(x)phi = F.max_pool2d(self.phi(x), kernel_size=2)g = F.max_pool2d(self.g(x), kernel_size=2)# reshape spatial size for self-attentiontheta = theta.view(-1, self.channels // 8, spatial_size)phi = phi.view(-1, self.channels // 8, spatial_size // 4)g = g.view(-1, self.channels // 2, spatial_size // 4)# compute dot product attention with query (theta) and key (phi) matricesbeta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), dim=-1)# compute scaled dot product attention with value (g) and attention (beta) matriceso = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.channels // 2, x.shape[2], x.shape[3]))# apply gain and residualreturn self.gamma * o + x

BigGAN的Generation结构如图所示:

根据上图代码如下:

class Generator(nn.Module):''' Implements the BigGAN generator '''def __init__(self,base_channels: int = 96,bottom_width: int = 4,# yml里面是2z_dim: int = 120,shared_dim: int = 128,n_classes: int = 1000,):super().__init__()n_chunks = 6    # 5 (generator blocks) + 1 (generator input)self.z_chunk_size = z_dim // n_chunks # 120//6 == 20self.z_dim = z_dimself.shared_dim = shared_dimself.bottom_width = bottom_widthself.n_classes = n_classes# no spectral normalization on embeddings, which authors observe to cripple the generatorself.shared_emb = nn.Embedding(n_classes, shared_dim)# Linear层 Linear(20,16*96*2**2)self.proj_z = nn.Linear(self.z_chunk_size, 16 * base_channels * bottom_width ** 2)# 不能用一个大nn。连续的,因为我们在每个块上添加class+noiseself.g_blocks = nn.ModuleList([# ResBlock up 16ch → 16chGResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 16 * base_channels),# ResBlock up 16ch → 8chGResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 8 * base_channels),# ResBlock up 8ch → 4chGResidualBlock(shared_dim + self.z_chunk_size, 8 * base_channels, 4 * base_channels),# ResBlock up 4ch → 2chGResidualBlock(shared_dim + self.z_chunk_size, 4 * base_channels, 2 * base_channels),# Non-Local Block (64 × 64)AttentionBlock(2 * base_channels),# ResBlock up 2ch → chGResidualBlock(shared_dim + self.z_chunk_size, 2 * base_channels, base_channels),])self.proj_o = nn.Sequential(# BN, ReLU, 3 × 3 Conv ch → 3, Tanhnn.BatchNorm2d(base_channels),nn.ReLU(inplace=True),nn.utils.spectral_norm(nn.Conv2d(base_channels, 3, kernel_size=1, padding=0)),nn.Tanh(),)def forward(self, z, y):'''z: random noise with size self.z_dimy: one-hot class embeddings with size self.shared_dim'''y = self.shared_emb(y)# class# 块z并连接到共享类嵌入zs = torch.split(z, self.z_chunk_size, dim=1)z = zs[0]ys = [torch.cat([y, z], dim=1) for z in zs[1:]] # Split的结果+Class# project noise and reshape to feed through generator blocksh = self.proj_z(z)# Linear层h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)# feed through generator blocksidx = 0for g_block in self.g_blocks:if isinstance(g_block, AttentionBlock):h = g_block(h)else:h = g_block(h, ys[idx])idx += 1# project to 3 RGB channels with tanh to map values to [-1, 1]h = self.proj_o(h)return h

1.4、试验

1.4.1、不同 Batch size 对性能的影响

作者发现简单地将 Batch size 增大就可以实现性能上较好的提升,文章做 了实验验证。在 Batch size 增大到原来 8 倍的时候,生成性能上的 IS 提高 了 46%。文章推测这可能是每批次覆盖更多模式的结果,为生成和判别两个网 络提供更好的梯度。增大 Batch size 还会带来在更少的时间训练出更好性能的 模型,但增大 Batch size 也会使得模型在训练上稳定性下降,后续再分析如何 提高稳定性。

在实验上,单单提高 Batch size 还受到限制,文章在每层的通道数也做了 相应的增加,当通道增加 50%,大约两倍于两个模型中的参数数量。这会导致 IS 进一步提高 21%。文章认为这是由于模型的容量相对于数据集的复杂性而增加。

1.4.2、选择先验分布

z 通过实验对比了 N(0,1)、Bernoulli{0,1}、Censored Normal max(N(0,1), 0),根据参考训练速度、模型性能,文章最终选择了 z∼ N(0,I)。

1.4.3、选择阈值

所谓的“截断技巧”就是通过对从先验分布 z 采样,通过设置阈值的方式 来截断 z 的采样,其中超出范围的值被重新采样以落入该范围内。这个阈值可 以根据生成质量指标 IS 和 FID 决定。 通过实验可以知道通过对阈值的设定,随着阈值的下降生成的质量会越来越 好,但是由于阈值的下降、采样的范围变窄,就会造成生成上取向单一化,造成 生成的多样性不足的问题。往往 IS 可以反应图像的生成质量,FID 则会更假注 重生成的多样性。

1.4.4、尝试控制 G

在探索模型的稳定性上,文章在训练期间监测一系列权重、梯度和损失统计 数据,以寻找可能预示训练崩溃开始的指标。实验发现每个权重矩阵的前三个奇 异值 σ0,σ1,σ2 是最有用的,它们可以使用 Alrnoldi 迭代方法进行有效计 算。

对于奇异值 σ0,大多数 G 层具有良好的光谱规范,但有些层(通常是 G 中 的第一层而非卷积)则表现不佳,光谱规范在整个训练过程中增长,在崩溃时爆 炸。

一顿操作后,文章得出了调节 G 可以改善模型的稳定性,但是无法确保一 直稳定,从而文章转向对 D 的控制。

1.4.5、尝试控制 D

考虑 D 网络的光谱,试图寻找额外的约束来寻求稳定的训练。使用正交正 则化,DropOut 和 L2 的各种正则思想重复该实验,揭示了这些正则化策略的都 有类似行为:对 D 的惩罚足够高,可以实现训练稳定性但是性能成本很高,但 是在图像生成性能上也是下降的,而且降的有点多。

实验还发现 D 在训练期间的损失接近于零,但在崩溃时经历了急剧的向上 跳跃,这种行为的一种可能解释是 D 过度拟合训练集,记忆训练样本而不是学 习真实图像和生成图像之间的一些有意义的边界。

为了评估这一猜测,文章在 ImageNet 训练和验证集上评估判别器,并测量 样本分类为真实或生成的百分比。虽然在训练集下精度始终高于 98%,但验证 准确度在 50-55% 的范围内,这并不比随机猜测更好(无论正则化策略如何)。

这证实了 D 确实记住了训练集,也符合 D 的角色:不断提炼训练数据并为 G 提 供有用的学习信号。 可以通过约束 D 来强制执行稳定性,但这样做会导致性能上的巨大成本。 使用现有技术,通过放松这种调节并允许在训练的后期阶段发生崩溃(人为把握 训练实际),可以实现更好的最终性能,此时模型被充分训练以获得良好的结果。

1.4.6、用分辨率评估模型

在 ImageNet 数据集下做评估,实验在 ImageNet ILSVRC 2012(大家都在 用的 ImageNet 的数据集)上 128×128,256×256 和 512×512 分辨率评估模 型。

1.4.7、验证 G 网络并非是记住训练集

为了进一步说明 G 网络并非是记住训练集,在固定 z 下通过调节条件标签 c 做插值生成,通过下图的实验结果可以发现,整个插值过程是流畅的,也能说 明 G 并非是记住训练集,而是真正做到了图像生成。

1.5、与 GAN 的对比

BigGAN 的主要改进有一下三部分:

(1)通过大规模 GAN 的应用,BigGAN 实现了生成上的巨大突破,参数量 扩大两到四倍,batchsize 扩大八倍;

(2)采用先验分布 z 的“截断技巧”,允许对样本多样性和保真度进行精 细控制;

(3)在大规模 GAN 的实现上不断克服模型训练问题,采用技巧减小训练的 不稳定,但完全的稳定性只能以极高的性能成本实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_74987.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

fastadmin:在新增页面,打开弹窗单选,参数回传

样式:核心代码:一、弹窗的控制器中:// 定义一个公共函数select(),如果这个请求是Ajax,则返回index()函数,否则返回view对象的fetch()函数。 public function select() {if ($this->request->isAjax(…

【软件测试】测试老鸟的迷途,进军高级自动化测试测试......

目录:导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜)前言 很多从业几年的选手…

【阿旭机器学习实战】【37】电影推荐系统---基于矩阵分解

【阿旭机器学习实战】系列文章主要介绍机器学习的各种算法模型及其实战案例,欢迎点赞,关注共同学习交流。 电影推荐系统 目录电影推荐系统1. 问题介绍1.1推荐系统矩阵分解方法介绍1.2 数据集:ml-100k2. 推荐系统实现2.1 定义矩阵分解函数2.2 …

消息中间件的概念

中间件(middleware)是基础软件的一大类,属于可复用的软件范畴。中间件在操作系统软件,网络和数据库之上,应用软件之下,总的作用是为处于自己上层的应用软件提供运行于开发的环境,帮助用户灵活、高效的开发和集成复杂的…

ICA简介:独立成分分析

1. 简介 您是否曾经遇到过这样一种情况:您试图分析一个复杂且高度相关的数据集,却对信息量感到不知所措?这就是独立成分分析 (ICA) 的用武之地。ICA 是数据分析领域的一项强大技术,可让您分离和识别多元数据集中的底层独立来源。 …

PPP简介,PPP分层体系架构,PPP链路建立过程及PPP的帧格式

PPP(Point-to-Point Protocol)是一种用于在两个网络节点之间传输数据的通信协议。它最初是为在拨号网络上进行拨号连接而开发的,现在已经被广泛应用于各种网络环境中,例如在宽带接入、虚拟专用网(VPN)等场景…

【JAVA】一个项目如何预先加载数据?

这里写目录标题需求实现AutowiredPostConstruct实例CommandLineRunner实例ApplicationListener实例参考需求 一般我们可能会有一些在应用启动时加载资源的需求,局部或者全局使用,让我们来看看都有哪些方式实现。 实现 Autowired 如果是某个类里需求某…

[1]MyBatis+Spring+SpringMVC+SSM整合

一、MyBatis 1、MyBatis简介 1.1、MyBatis历史 MyBatis最初是Apache的一个开源项目iBatis, 2010年6月这个项目由Apache Software Foundation迁移到了Google Code。随着开发团队转投Google Code旗下, iBatis3.x正式更名为MyBatis。代码于2013年11月迁移到Github。…

Vue中如何利用websocket实现实时通讯

首先我们可以先做一个简单的例子来学习一下简单的websocket模拟聊天对话的功能 原理很简单,有点像VUE中的EventBus,用emit和on传来传去 首先我们可以先去自己去用node搭建一个本地服务器 步骤如下 1.新建一个app.js,然后创建pagejson.js文…

【Linux】-- POSIX信号量

目录 POSIX信号量 sem_init - 初始化信号量 sem_destroy - 销毁信号量 sem_wait - 等待信号量(P操作) 基于环形队列的生产消费模型 数据结构 - 环形结构 实现原理 POSIX信号量 #问:什么是信号量? 1. 共享资源 -> 任何一…

【笔记】两台1200PLC进行S7 通信(1)

使用两台1200系列PLC进行S7通信(入门) 文章目录 目录 文章目录 前言 一、通信 1.概念 2.PLC通信 1.串口 2.网口 …

时间颗粒度选择(通过选择时间范围和颗粒度展示选项)

<template><div><el-time-selectplaceholder"起始时间"v-model"startTime":picker-options"startPickerOptions"change"changeStartTime"></el-time-select><el-time-selectplaceholder"结束时间&quo…

想招到实干派程序员?你需要这种面试法

技术招聘中最痛的点其实是不精准。技术面试官或CTO们常常会向我们吐槽&#xff1a; “我经常在想&#xff0c;能不能把我们项目中的代码打印出来&#xff0c;作为候选人的面试题的一部分&#xff1f;” “能不能把一个Bug带上环境&#xff0c;让候选人来试试怎么解决&#xf…

mysql中用逗号隔开的字段作查询用(find_in_set的使用)

mysql中用逗号隔开的字段作查询用(find_in_set的使用) 场景说明 在工作中&#xff0c;经常会遇到一对多的关系。想要在mysql中保存这种关系&#xff0c;一般有两种方式&#xff0c;一种是建立一张中间表&#xff0c;这样一条id就会存在多条记录。或者采用第二种方式&#xff…

【数据结构必会基础】关于树,你所必须知道的亿些概念

目录 1.什么是树 1.1浅显的理解树 1.2 数据结构中树的概念 2.树的各种结构概念 2.1 节点的度 2.2 根节点/叶节点/分支节点 2.3 父节点/子节点 2.4祖先节点/子孙节点 2.5兄弟节点 2.6树的度 2.7节点的层次 2.8森林 3. 如何用代码表示一棵树 3.1链式结构 3.1.1 树节…

Gitea Windows环境下服务搭建

前言&#xff1a;这篇文章没有去分析各大平台的优劣势&#xff0c;仅教学大家搭建一个属于自己的git代码管理器&#xff0c;主要作用在局域网内&#xff0c;办公电脑搭建一个简单的Gitea代码管理器。数据库使用SQLite3&#xff0c;环境是windows10。如果不是这个环境的话&#…

@Import注解的原理

此注解是springboot自动注入的关键注解&#xff0c;所以拿出来单独分析一下。 启动类的run方法跟进去最终找到refresh方法&#xff1b; 这里直接看这个org.springframework.context.support.AbstractApplicationContext#refresh方法即可&#xff0c;它下面有一个方法 invoke…

Node下载阿里OSS存储文件【不知目录结构】

前言&#xff1a;前端传模型ID&#xff0c;后台根据ID去阿里OSS存储下载对应文件&#xff08;不知文件内部层级结构&#xff0c;且OSS只能单个文件下载&#xff09;&#xff0c;打包成zip字节流形式返回给前端下载。 需求分析&#xff1a; 生成OSS文件关系树Node做文件下载存…

kafka(一) 的架构,各概念

Kafka架构 Kafak 总体架构图中包含多个概念&#xff1a; &#xff08;1&#xff09;ZooKeeper&#xff1a;Zookeeper负责保存broker集群元数据&#xff0c;并对控制器进行选举等操作。 &#xff08;2&#xff09;Producer&#xff1a; 生产者负责创建消息&#xff0c;将消息发…

【神经网络】LSTM为什么能缓解梯度消失

1.LSTM的结构 我们先来看一下LSTM的计算公式&#xff1a; 1.遗忘门&#xff1a; 2.输入门&#xff1a; 3.细胞状态 4.输出门 2.LSTM的梯度路径 根据LSTM的计算公式&#xff0c;可以得出LSTM的cell state与、、都存在计算关系&#xff0c;而、、的计算公式又全部都与有关&#x…