图像生成:Pytorch实现一个简单的对抗生成网络模型

news/2024/5/5 10:34:42/文章来源:https://blog.csdn.net/FriendshipTang/article/details/137604437

图像生成:Pytorch实现一个简单的对抗生成网络模型

  • 前言
  • 相关介绍
  • 具体步骤
    • 准备并读取数据集
    • 定义生成器
    • 定义判别器
    • 定义损失函数
    • 定义优化器
    • 开始训练
    • 完整代码
  • 训练生成的图片

前言

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

相关介绍

对抗生成网络(Adversarial Generative Networks,简称GANs)是由伊恩·古德费洛(Ian Goodfellow)等人在2014年提出的深度学习框架,主要用于无监督学习中的数据生成任务。GAN的设计灵感来源于博弈论中的极小极大游戏,在机器学习领域开辟了一种全新的生成模型方法。

GAN包含两个主要组成部分:

  1. 生成器(Generator, G)

    • 生成器G是一个神经网络,它的功能是从随机噪声向量(通常来自高斯分布或其他先验分布)出发,通过一系列变换来生成新的数据样本,这些样本应该尽可能接近于训练数据的真实分布。
    • G的目标是尽可能地“欺骗”判别器,使其无法区分生成的数据和真实数据。
  2. 判别器(Discriminator, D)

    • 判别器D也是一个神经网络,它的作用是接收输入的样本,并将其分类为“真”(来自真实数据分布)或“假”(来自生成器产生的分布)。
    • D的目标是准确地区分真实数据和伪造数据,从而提高自身鉴别能力。

训练过程中,GAN采用了极小极大博弈策略,具体步骤如下:

  • 先固定判别器参数,更新生成器参数以使得生成的数据更有可能被判别器误认为真实数据;
  • 再固定生成器参数,更新判别器参数以更好地区分真实数据和生成器生成的伪数据。

这种交替优化的方式促使两个网络性能不断提升,直到达到纳什均衡,此时生成器能够生成非常逼真的新样本,而判别器再也无法有效区分真实数据和生成数据。

GAN在图像生成、图像到图像转换、视频生成、音频合成等多个领域都取得了显著成果,并且随着研究的深入,衍生出了许多改进版本和变体,比如条件GAN(Conditional GANs)、 Wasserstein GAN(WGAN)、CycleGAN等,进一步提高了生成效果并拓展了应用范围。
在这里插入图片描述

具体步骤

准备并读取数据集

以mnist数据集为例。

# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)

定义生成器

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return img

定义判别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity

定义损失函数

# Loss function
adversarial_loss = torch.nn.BCELoss()

定义优化器

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

开始训练

# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminator# 生成器损失,目的是生成的图像,逼近真是图像,是判别器认为是真g_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samples# 真实图像在判别器的目标函数损失,即判别器应该判为真real_loss = adversarial_loss(discriminator(real_imgs), valid)# 生成器生成出来的图像在判别器的目标函数损失,即判别器应该判为假fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)# 判别器损失,目的是判别生成的图像是假d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

完整代码

import argparse
import os
import numpy as np
import mathimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch.nn.functional as F
import torchos.makedirs("images", exist_ok=True)parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)img_shape = (opt.channels, opt.img_size, opt.img_size)cuda = True if torch.cuda.is_available() else Falseclass Generator(nn.Module):def __init__(self):super(Generator, self).__init__()def block(in_feat, out_feat, normalize=True):layers = [nn.Linear(in_feat, out_feat)]if normalize:layers.append(nn.BatchNorm1d(out_feat, 0.8))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*block(opt.latent_dim, 128, normalize=False),*block(128, 256),*block(256, 512),*block(512, 1024),nn.Linear(1024, int(np.prod(img_shape))),nn.Tanh())def forward(self, z):img = self.model(z)img = img.view(img.size(0), *img_shape)return imgclass Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(int(np.prod(img_shape)), 512),nn.LeakyReLU(0.2, inplace=True),nn.Linear(512, 256),nn.LeakyReLU(0.2, inplace=True),nn.Linear(256, 1),nn.Sigmoid(),)def forward(self, img):img_flat = img.view(img.size(0), -1)validity = self.model(img_flat)return validity# Loss function
adversarial_loss = torch.nn.BCELoss()# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()# Configure data loader
os.makedirs("./data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("./data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor# ----------
#  Training
# ----------for epoch in range(opt.n_epochs):for i, (imgs, _) in enumerate(dataloader):# Adversarial ground truthsvalid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)# Configure inputreal_imgs = Variable(imgs.type(Tensor))# -----------------#  Train Generator# -----------------optimizer_G.zero_grad()# Sample noise as generator inputz = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))# Generate a batch of imagesgen_imgs = generator(z)# Loss measures generator's ability to fool the discriminatorg_loss = adversarial_loss(discriminator(gen_imgs), valid)g_loss.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Measure discriminator's ability to classify real from generated samplesreal_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)d_loss = (real_loss + fake_loss) / 2d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

在这里插入图片描述

训练生成的图片

在这里插入图片描述

  • 由于本人水平有限,难免出现错漏,敬请批评改正。
  • 更多精彩内容,可点击进入人工智能知识点专栏、Python日常小操作专栏、OpenCV-Python小应用专栏、YOLO系列专栏、自然语言处理专栏或我的个人主页查看
  • 基于DETR的人脸伪装检测
  • YOLOv7训练自己的数据集(口罩检测)
  • YOLOv8训练自己的数据集(足球检测)
  • YOLOv5:TensorRT加速YOLOv5模型推理
  • YOLOv5:IoU、GIoU、DIoU、CIoU、EIoU
  • 玩转Jetson Nano(五):TensorRT加速YOLOv5目标检测
  • YOLOv5:添加SE、CBAM、CoordAtt、ECA注意力机制
  • YOLOv5:yolov5s.yaml配置文件解读、增加小目标检测层
  • Python将COCO格式实例分割数据集转换为YOLO格式实例分割数据集
  • YOLOv5:使用7.0版本训练自己的实例分割模型(车辆、行人、路标、车道线等实例分割)
  • 使用Kaggle GPU资源免费体验Stable Diffusion开源项目

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_1045875.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【vim 学习系列文章 20 -- a:mode 的值有哪些?】

请阅读【嵌入式开发学习必备专栏 之 Vim】 文章目录 a:mode 的值有哪些?举例Vim 底部状态栏设置 a:mode 的值有哪些? 在 Vim 脚本语言中,a:mode 常常用于函数内部,以获取该函数被调用时 Vim 正处于的模式。它主常用于那些可以从不同模式下被调用的函数…

websocket分析和前后端如何接入websocket

1、前言 websocket一般用途为消息提醒,股票行情数据推送等等,有很多用途。 我们这里简单举例理解websocket和如何前后端接入websocket; 使用网络抓包分析软件。主要是截取网络封包,并尽可能显示出最为详细的网络封包资料。 2、…

怎么快速围绕“人、货、场”做零售数据分析?

做零售数据分析多了,不难发现零售数据分析的关键就是“人、货、场”,那么怎么又快又灵活地分析这三个关键点?不妨参考下奥威BI零售数据分析方案。 奥威BI零售数据分析方案是一套吸取大量项目经验,结合零售企业数据分析共性需求打…

MyBatis 等类似的 XML 映射文件中,当传入的参数为空字符串时,<if> 标签可能会导致 SQL 语句中的条件判断出现意外结果。

问题 传入的参数为空字符串,但还是根据参数查询了。 原因 在 XML 中使用 标签进行条件判断时,需要明确理解其行为。在 MyBatis 等类似的 XML 映射文件中, 标签通常用于动态拼接 SQL 语句的条件部分。当传入的参数 riskLevel 为空字符串时…

acwing总结-线性质数筛

质数筛 题目链接:质数筛线性筛法 ac代码&#xff1a; #include<iostream> #include<algorithm> //https://www.bilibili.com/video/BV1LR4y1Z7pm/?spm_id_from333.337.search-card.all.click&vd_source436ccbb3a8f50110aa75654f38e35672 //链接到b站视频 us…

如何在 YouTube、Medium、Twitter 和 Linkedin 上使用 ChatGPT 赚钱

人工智能SEO&#xff1a;未来内容优化的革命 介绍 在当今的数字时代&#xff0c;利用 ChatGPT 等人工智能工具已经成为在线内容创建和货币化领域的游戏规则改变者。 本指南探讨了如何在 YouTube、Medium、Twitter 和 Linkedin 等各种平台上有效使用 ChatGPT&#xff0c;不仅可以…

el-table动态合并列

需要合并列&#xff0c;并且不确定需要合并多少列的&#xff0c;可以参照如下代码 首先需要再el-table上传入span-method方法 arraySpan({ row, column, rowIndex, columnIndex }){if (row.groupName 汇总 && columnIndex 0) {return [0,0]} else if (row.groupName…

机器学习-08-关联规则和协同过滤

总结 本系列是机器学习课程的系列课程&#xff0c;主要介绍机器学习中关联规则和协同过滤。 参考 机器学习&#xff08;三&#xff09;&#xff1a;Apriori算法&#xff08;算法精讲&#xff09; Apriori 算法 理论 重点 MovieLens:一个常用的电影推荐系统领域的数据集 2…

01、ArcGIS For JavaScript 4.29对3DTiles数据的支持

综述 Cesium从1.99版本开始支持I3S服务的加载&#xff0c;到目前位置&#xff0c;已经支持I3S的倾斜模型、3D Object模型以及属性查询的支持。Cesium1.115又对I3S标准的Building数据实现了加载支持。而ArcGIS之前一直没有跨越对3DTiles数据的支持&#xff0c;所以在一些开发过…

数字社交的新典范:解析Facebook的成功密码

在当今数字化时代&#xff0c;社交媒体已经成为人们日常生活的重要组成部分&#xff0c;而Facebook作为最知名的社交媒体平台之一&#xff0c;其成功之处备受瞩目。本文将深入解析Facebook的成功密码&#xff0c;探讨其在数字社交领域的新典范。 1. 用户体验的优化 Facebook注…

Docker容器嵌入式开发:在Ubuntu上配置Postman和flatpak

在 Ubuntu 上配置 Postman 可以通过 Snap 命令完成&#xff0c;以下是所有命令的总结&#xff1a; sudo snap install postmansudo snap install flatpak在 Ubuntu 上配置 Postman 和 Flatpak 非常简单。以下是一些简单的步骤&#xff1a; 配置 Flatpak 安装 Flatpak&#x…

【力扣】104. 二叉树的最大深度、111. 二叉树的最小深度

104. 二叉树的最大深度 题目描述 给定一个二叉树 root &#xff0c;返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 示例 1&#xff1a; 输入&#xff1a;root [3,9,20,null,null,15,7] 输出&#xff1a;3 示例 2&#xff1a; 输…

【linux】yum 和 vim

yum 和 vim 1. Linux 软件包管理器 yum1.1 什么是软件包1.2 查看软件包1.3 如何安装软件1.4 如何卸载软件1.5 关于 rzsz 2. Linux编辑器-vim使用2.1 vim的基本概念2.2 vim的基本操作2.3 vim命令模式命令集2.4 vim底行模式命令集2.5 vim操作总结补充&#xff1a;vim下批量化注释…

IOS虚拟键盘弹出后,弹窗的按钮点击不起作用,不会触发click事件

背景 讨论区项目的回复框&#xff0c;使用的是Popup和TextArea做&#xff0c;布局如下图&#xff0c;希望键盘弹出时候&#xff0c;回复框可以紧贴键盘&#xff0c;这点实现起来比较简单&#xff0c;监听resize事件&#xff0c;动态修改popup的这题内容的top值即可&#xff0c…

二叉数应用——最优二叉树(Huffman树)、贪心算法—— Huffman编码

1、外部带权外部路径长度、Huffman树 从图中可以看出&#xff0c;深度越浅的叶子结点权重越大&#xff0c;深度越深的叶子结点权重越小的话&#xff0c;得出的带权外部路径长度越小。 Huffman树就是使得外部带权路径最小的二叉树 2、如何构造Huffman树 &#xff08;1&#xf…

SpringBoot中的Redis的简单使用

在Spring Boot项目中使用Redis作为缓存、会话存储或分布式锁等组件&#xff0c;可以简化开发流程并充分利用Redis的高性能特性。以下是使用Spring Boot整合Redis的详细步骤&#xff1a; 1. 环境准备 确保开发环境中已安装&#xff1a; Java&#xff1a;用于编写和运行Spring…

HTML5学习记录

简介 超文本标记语言&#xff08;HyperText Markup Language&#xff0c;简称HTML&#xff09;&#xff0c;是一种用于创建网页的标准标记语言。 编辑器 下载传送门https://code.visualstudio.com/ 下载编辑器插件 标题 标题通过 <h1> - <h6> 标签进行定义。 …

xss.pwnfunction-Ok,Boomer

调用0k会直接调用tostring方法获得href里的内容而setTimeout第一个参数恰巧可以接收字符串 但是href必须是协议&#xff1a;主机名 这里tel是dompurify框架的白名单 <a idok hreftel:alert(1337)>

xss跨站脚本攻击笔记

1 XSS跨站脚本攻击 1.1 xss跨站脚本攻击介绍 跨站脚本攻击英文全称为(Cross site Script)缩写为CSS&#xff0c;但是为了和层叠样式表(CascadingStyle Sheet)CSS区分开来&#xff0c;所以在安全领域跨站脚本攻击叫做XSS 1.2 xss跨战脚本攻击分类 第一种类型:反射型XSS 反射…

Node.js cnpm的安装

百度搜索 cnpm,进入npmmirror 镜像站https://npmmirror.com/ cmd窗口输入 npm install -g cnpm --registryhttps://registry.npmmirror.com