原始GAN-pytorch-生成MNIST数据集(代码)

news/2024/4/19 7:55:31/文章来源:https://blog.csdn.net/jerry_liufeng/article/details/129238417

文章目录

  • 原始GAN生成MNIST数据集
    • 1. Data loading and preparing
    • 2. Dataset and Model parameter
    • 3. Result save path
    • 4. Model define
    • 6. Training
    • 7. predict

原始GAN生成MNIST数据集

原理很简单,可以参考原理部分原始GAN-pytorch-生成MNIST数据集(原理)

import os
import time
import torch
from tqdm import tqdm
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import sys 
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

1. Data loading and preparing

测试使用loadlocal_mnist加载数据

from mlxtend.data import loadlocal_mnist
train_data_path = "../data/MNIST/train-images.idx3-ubyte"
train_label_path = "../data/MNIST/train-labels.idx1-ubyte"
test_data_path = "../data/MNIST/t10k-images.idx3-ubyte"
test_label_path = "../data/MNIST/t10k-labels.idx1-ubyte"train_data,train_label = loadlocal_mnist(images_path = train_data_path,labels_path = train_label_path
)
train_data.shape,train_label.shape
((60000, 784), (60000,))
import matplotlib.pyplot as pltimg,ax = plt.subplots(3,3,figsize=(9,9))
plt.subplots_adjust(hspace=0.4,wspace=0.4)
for i in range(3):for j in range(3):num = np.random.randint(0,train_label.shape[0])ax[i][j].imshow(train_data[num].reshape((28,28)),cmap="gray")ax[i][j].set_title(train_label[num],fontdict={"fontsize":20})
plt.show()

在这里插入图片描述

2. Dataset and Model parameter

构造pytorch数据集datasets和数据加载器dataloader

input_size = [1, 28, 28]
batch_size = 128
Epoch = 1000
GenEpoch = 1
in_channel = 64
from torch.utils.data import Dataset,DataLoader
import numpy as np 
from mlxtend.data import loadlocal_mnist
import torchvision.transforms as transformsclass MNIST_Dataset(Dataset):def __init__(self,train_data_path,train_label_path,transform=None):train_data,train_label = loadlocal_mnist(images_path = train_data_path,labels_path = train_label_path)self.train_data = train_dataself.train_label = train_label.reshape(-1)self.transform=transformdef __len__(self):return self.train_label.shape[0] def __getitem__(self,index):if torch.is_tensor(index):index = index.tolist()images = self.train_data[index,:].reshape((28,28))labels = self.train_label[index]if self.transform:images = self.transform(images)return images,labelstransform_dataset =transforms.Compose([transforms.ToTensor()]
)
MNIST_dataset = MNIST_Dataset(train_data_path=train_data_path,train_label_path=train_label_path,transform=transform_dataset)  
MNIST_dataloader = DataLoader(dataset=MNIST_dataset,batch_size=batch_size,shuffle=True,drop_last=False)
img,ax = plt.subplots(3,3,figsize=(9,9))
plt.subplots_adjust(hspace=0.4,wspace=0.4)
for i in range(3):for j in range(3):num = np.random.randint(0,train_label.shape[0])ax[i][j].imshow(MNIST_dataset[num][0].reshape((28,28)),cmap="gray")ax[i][j].set_title(MNIST_dataset[num][1],fontdict={"fontsize":20})
plt.show()

在这里插入图片描述

3. Result save path

time_now = time.strftime('%Y-%m-%d-%H_%M_%S', time.localtime(time.time()))
log_path = f'./log/{time_now}'
os.makedirs(log_path)
os.makedirs(f'{log_path}/image')
os.makedirs(f'{log_path}/image/image_all')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'using device: {device}')
using device: cuda

4. Model define

import torch
from torch import nn class Discriminator(nn.Module):def __init__(self,input_size,inplace=True):super(Discriminator,self).__init__()c,h,w = input_sizeself.dis = nn.Sequential(nn.Linear(c*h*w,512),  # 输入特征数为784,输出为512nn.BatchNorm1d(512),nn.LeakyReLU(0.2),  # 进行非线性映射nn.Linear(512, 256),  # 进行一个线性映射nn.BatchNorm1d(256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid()  # 也是一个激活函数,二分类问题中,# sigmoid可以班实数映射到【0,1】,作为概率值,# 多分类用softmax函数)def forward(self,x):b,c,h,w = x.size()x = x.view(b,-1)x = self.dis(x)x = x.view(-1)return x class Generator(nn.Module):def __init__(self,in_channel):super(Generator,self).__init__() # 调用父类的构造方法self.gen = nn.Sequential(nn.Linear(in_channel, 128),nn.LeakyReLU(0.2),nn.Linear(128, 256),nn.BatchNorm1d(256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.BatchNorm1d(512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.BatchNorm1d(1024),nn.LeakyReLU(0.2),nn.Linear(1024, 784),nn.Tanh())def forward(self,x):res = self.gen(x)return res.view(x.size()[0],1,28,28)D = Discriminator(input_size=input_size)
G = Generator(in_channel=in_channel)
D.to(device)
G.to(device)
D,G
(Discriminator((dis): Sequential((0): Linear(in_features=784, out_features=512, bias=True)(1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): LeakyReLU(negative_slope=0.2)(3): Linear(in_features=512, out_features=256, bias=True)(4): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(5): LeakyReLU(negative_slope=0.2)(6): Linear(in_features=256, out_features=1, bias=True)(7): Sigmoid())),Generator((gen): Sequential((0): Linear(in_features=64, out_features=128, bias=True)(1): LeakyReLU(negative_slope=0.2)(2): Linear(in_features=128, out_features=256, bias=True)(3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(4): LeakyReLU(negative_slope=0.2)(5): Linear(in_features=256, out_features=512, bias=True)(6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(7): LeakyReLU(negative_slope=0.2)(8): Linear(in_features=512, out_features=1024, bias=True)(9): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(10): LeakyReLU(negative_slope=0.2)(11): Linear(in_features=1024, out_features=784, bias=True)(12): Tanh())))

6. Training

criterion = nn.BCELoss()
D_optimizer = torch.optim.Adam(D.parameters(),lr=0.0003)
G_optimizer = torch.optim.Adam(G.parameters(),lr=0.0003)
D.train()
G.train()
gen_loss_list = []
dis_loss_list = []for epoch in range(Epoch):with tqdm(total=MNIST_dataloader.__len__(),desc=f'Epoch {epoch+1}/{Epoch}')as pbar:gen_loss_avg = []dis_loss_avg = []index = 0for batch_idx,(img,_) in enumerate(MNIST_dataloader):img = img.to(device)# the output labelvalid = torch.ones(img.size()[0]).to(device)fake = torch.zeros(img.size()[0]).to(device)# Generator inputG_img = torch.randn([img.size()[0],in_channel],requires_grad=True).to(device)# ------------------Update Discriminator------------------# forwardG_pred_gen = G(G_img)G_pred_dis = D(G_pred_gen.detach())R_pred_dis = D(img)# the misfitG_loss = criterion(G_pred_dis,fake)R_loss = criterion(R_pred_dis,valid)dis_loss = (G_loss+R_loss)/2dis_loss_avg.append(dis_loss.item())# backwardD_optimizer.zero_grad()dis_loss.backward()D_optimizer.step()# ------------------Update Optimizer------------------# forwardG_pred_gen = G(G_img)G_pred_dis = D(G_pred_gen)# the misfitgen_loss = criterion(G_pred_dis,valid)gen_loss_avg.append(gen_loss.item())# backwardG_optimizer.zero_grad()gen_loss.backward()G_optimizer.step()# save figureif index % 200 == 0 or index + 1 == MNIST_dataset.__len__():save_image(G_pred_gen, f'{log_path}/image/image_all/epoch-{epoch}-index-{index}.png')index += 1# ------------------进度条更新------------------pbar.set_postfix(**{'gen-loss': sum(gen_loss_avg) / len(gen_loss_avg),'dis-loss': sum(dis_loss_avg) / len(dis_loss_avg)})pbar.update(1)save_image(G_pred_gen, f'{log_path}/image/epoch-{epoch}.png')filename = 'epoch%d-genLoss%.2f-disLoss%.2f' % (epoch, sum(gen_loss_avg) / len(gen_loss_avg), sum(dis_loss_avg) / len(dis_loss_avg))torch.save(G.state_dict(), f'{log_path}/{filename}-gen.pth')torch.save(D.state_dict(), f'{log_path}/{filename}-dis.pth')# 记录损失gen_loss_list.append(sum(gen_loss_avg) / len(gen_loss_avg))dis_loss_list.append(sum(dis_loss_avg) / len(dis_loss_avg))# 绘制损失图像并保存plt.figure(0)plt.plot(range(epoch + 1), gen_loss_list, 'r--', label='gen loss')plt.plot(range(epoch + 1), dis_loss_list, 'r--', label='dis loss')plt.legend()plt.xlabel('epoch')plt.ylabel('loss')plt.savefig(f'{log_path}/loss.png', dpi=300)plt.close(0)
Epoch 1/1000: 100%|██████████| 469/469 [00:11<00:00, 41.56it/s, dis-loss=0.456, gen-loss=1.17] 
Epoch 2/1000: 100%|██████████| 469/469 [00:11<00:00, 42.34it/s, dis-loss=0.17, gen-loss=2.29] 
Epoch 3/1000: 100%|██████████| 469/469 [00:10<00:00, 43.29it/s, dis-loss=0.0804, gen-loss=3.11]
Epoch 4/1000: 100%|██████████| 469/469 [00:11<00:00, 40.74it/s, dis-loss=0.0751, gen-loss=3.55]
Epoch 5/1000: 100%|██████████| 469/469 [00:12<00:00, 39.01it/s, dis-loss=0.105, gen-loss=3.4]  
Epoch 6/1000: 100%|██████████| 469/469 [00:11<00:00, 39.95it/s, dis-loss=0.112, gen-loss=3.38]
Epoch 7/1000: 100%|██████████| 469/469 [00:11<00:00, 40.16it/s, dis-loss=0.116, gen-loss=3.42]
Epoch 8/1000: 100%|██████████| 469/469 [00:11<00:00, 42.51it/s, dis-loss=0.124, gen-loss=3.41]
Epoch 9/1000: 100%|██████████| 469/469 [00:11<00:00, 40.95it/s, dis-loss=0.136, gen-loss=3.41]
Epoch 10/1000: 100%|██████████| 469/469 [00:11<00:00, 39.59it/s, dis-loss=0.165, gen-loss=3.13]
Epoch 11/1000: 100%|██████████| 469/469 [00:11<00:00, 40.28it/s, dis-loss=0.176, gen-loss=3.01]
Epoch 12/1000: 100%|██████████| 469/469 [00:12<00:00, 37.60it/s, dis-loss=0.19, gen-loss=2.94] 
Epoch 13/1000: 100%|██████████| 469/469 [00:11<00:00, 39.17it/s, dis-loss=0.183, gen-loss=2.95]
Epoch 14/1000: 100%|██████████| 469/469 [00:12<00:00, 38.51it/s, dis-loss=0.182, gen-loss=3.01]
Epoch 15/1000: 100%|██████████| 469/469 [00:10<00:00, 44.58it/s, dis-loss=0.186, gen-loss=2.95]
Epoch 16/1000: 100%|██████████| 469/469 [00:10<00:00, 44.08it/s, dis-loss=0.198, gen-loss=2.89]
Epoch 17/1000: 100%|██████████| 469/469 [00:10<00:00, 45.11it/s, dis-loss=0.187, gen-loss=2.99]
Epoch 18/1000: 100%|██████████| 469/469 [00:10<00:00, 44.98it/s, dis-loss=0.183, gen-loss=3.03]
Epoch 19/1000: 100%|██████████| 469/469 [00:10<00:00, 46.68it/s, dis-loss=0.187, gen-loss=2.98]
Epoch 20/1000: 100%|██████████| 469/469 [00:10<00:00, 46.12it/s, dis-loss=0.192, gen-loss=3]   
Epoch 21/1000: 100%|██████████| 469/469 [00:10<00:00, 46.80it/s, dis-loss=0.193, gen-loss=3.01]
Epoch 22/1000: 100%|██████████| 469/469 [00:10<00:00, 45.86it/s, dis-loss=0.186, gen-loss=3.04]
Epoch 23/1000: 100%|██████████| 469/469 [00:10<00:00, 46.00it/s, dis-loss=0.17, gen-loss=3.2]  
Epoch 24/1000: 100%|██████████| 469/469 [00:10<00:00, 46.41it/s, dis-loss=0.173, gen-loss=3.19]
Epoch 25/1000: 100%|██████████| 469/469 [00:10<00:00, 45.15it/s, dis-loss=0.19, gen-loss=3.1]  
Epoch 26/1000: 100%|██████████| 469/469 [00:10<00:00, 44.26it/s, dis-loss=0.178, gen-loss=3.16]
Epoch 27/1000: 100%|██████████| 469/469 [00:10<00:00, 45.14it/s, dis-loss=0.187, gen-loss=3.17]
Epoch 28/1000:   1%|▏         | 6/469 [00:00<00:12, 38.20it/s, dis-loss=0.184, gen-loss=3.04]---------------------------------------------------------------------------

7. predict

input_size = [3, 32, 32]
in_channel = 64
gen_para_path = './log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-gen.pth'
dis_para_path = './log/2023-02-11-17_52_12/epoch999-genLoss1.21-disLoss0.40-dis.pth'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator_Transpose(in_channel=in_channel).to(device)
dis = DiscriminatorLinear(input_size=input_size).to(device)
gen.load_state_dict(torch.load(gen_para_path, map_location=device))
gen.eval()
# 随机生成一组数据
G_img = torch.randn([1, in_channel, 1, 1], requires_grad=False).to(device)
# 放入网路
G_pred = gen(G_img)
G_dis = dis(G_pred)
print('generator-dis:', G_dis)
# 图像显示
G_pred = G_pred[0, ...]
G_pred = G_pred.detach().cpu().numpy()
G_pred = np.array(G_pred * 255)
G_pred = np.transpose(G_pred, [1, 2, 0])
G_pred = Image.fromarray(np.uint8(G_pred))
G_pred.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_74735.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

记一次线上es慢查询导致的服务不可用

现象 某日线上业务同学反馈订单列表查询页面一直loding&#xff0c;然后提示请求超时&#xff0c;几分钟之后恢复正常 接到报障之后&#xff0c;马上根据接口URL&#xff0c;定位到了请求链路&#xff0c;发现是es查询超时&#xff0c;这里我们的业务订单表数据是由几百万的&a…

如何基于MLServer构建Python机器学习服务

文章目录前言一、数据集二、训练 Scikit-learn 模型三、基于MLSever构建Scikit-learn服务四、测试模型五、训练 XGBoost 模型六、服务多个模型七、测试多个模型的准确性总结参考前言 在过去我们训练模型&#xff0c;往往通过编写flask代码或者容器化我们的模型并在docker中运行…

Python学习笔记202302

1、numpy.empty 作用&#xff1a;根据给定的维度和数值类型返回一个新的数组&#xff0c;其元素不进行初始化。 用法&#xff1a;numpy.empty(shape, dtypefloat, order‘C’) 2、logging.debug 作用&#xff1a;Python 的日志记录工具&#xff0c;这个模块为应用与库实现了灵…

C# Sqlite数据库加密

sqlite官方的数据库加密是收费的&#xff0c;而且比较贵。 幸亏微软提供了一种免费的方法。 1 sqlite加密demo 这里我做了一个小的demo演示如下&#xff1a; 在界面中拖入数据库名、密码、以及保存的路径 比如我选择保存路径桌面的sqlite目录&#xff0c;数据库名guigutool…

Verilog 学习第五节(串口接收部分)

小梅哥串口部分学习part2 串口通信接收原理串口通信接收程序设计与调试巧用位操作优化串口接收逻辑设计串口接收模块的项目应用案例串口通信接收原理 在采样的时候没有必要一直判断一个clk内全部都是高/低电平&#xff0c;如果采用直接对中间点进行判断的话&#xff0c;很有可能…

Linux 红帽9.0 本地源 与 网络源 搭建

本次我们使用的是 redhat 9.0 版本&#xff0c;是redhat 的最新版本&#xff0c;我们一起来对其进行 本地仓库 和 网络仓库的搭建部署~&#xff01;&#xff01;关于 本地仓库&#xff08; 本地源 &#xff09;&#xff0c;和 网络仓库 &#xff08; 网络源 &#xff09;&#…

ESP32蓝牙配网

注意********menuconfig 配置&#xff08;必须打开蓝牙我这是C2所以使用NimBLE &#xff09;可以直接从demo的配置文件拷贝 Component config ---> Bluetooth ---> NimBLE - BLE only Component config ---> Bluetooth ---> NimBLE Options ---> Enable blufi…

计算结构体大小

计算结构体大小 目录计算结构体大小一. 结构体内存对齐1. 简介2. 嵌套结构体二. offsetof三. 内存对齐的意义四. 修改默认对齐数一. 结构体内存对齐 以字节&#xff08;bety&#xff09;为单位 1. 简介 对于结构体成员在内存里的存储&#xff0c;存在结构体的对齐规则&#…

Vue下载安装步骤的详细教程(亲测有效) 1

目录 一、【准备工作】nodejs下载安装(npm环境) 1 下载安装nodejs 2 查看环境变量是否添加成功 3、验证是否安装成功 4、修改模块下载位置 &#xff08;1&#xff09;查看npm默认存放位置 &#xff08;2&#xff09;在 nodejs 安装目录下&#xff0c;创建 “node_global…

Java查漏补缺(14)数据结构剖析、一维数组、链表、栈、队列、树与二叉树、List接口分析、Map接口分析、Set接口分析、HashMap的相关问题

Java查漏补缺&#xff08;14&#xff09;数据结构剖析、一维数组、链表、栈、队列、树与二叉树、List接口分析、Map接口分析、Set接口分析、HashMap的相关问题本章专题与脉络1. 数据结构剖析1.1 研究对象一&#xff1a;数据间逻辑关系1.2 研究对象二&#xff1a;数据的存储结构…

Laravel框架04:视图与CSRF攻击

Laravel框架04&#xff1a;视图与CSRF攻击一、视图概述二、变量分配与展示三、模板中直接使用函数四、循环与分支语法标签五、模板继承、包含1. 继承2. 包含六、外部静态文件引入七、CSRF攻击概述八、从CSRF验证中排除例外路由一、视图概述 视图存放在 resources/views 目录下…

MyBatis学习笔记(七) —— 特殊SQL的执行

7、特殊SQL的执行 7.1、模糊查询 模糊查询的三种方式&#xff1a; 方式1&#xff1a;select * from t_user where username like ‘%${mohu}%’ 方式2&#xff1a;select * from t_user where username like concat(‘%’,#{mohu},‘%’) 方式3&#xff1a;select * from t_u…

收集分享一些AI工具第三期(网站篇)

感谢大家对于内容的喜欢&#xff0c;目前已经来到了AI工具分享的最后一期了&#xff0c;目前为止大部分好用的AI工具都已经介绍给大家了&#xff0c;希望大家可以喜欢。 image-to-sound-fx (https://huggingface.co/spaces/fffiloni/image-to-sound-fx) 图片转换为相对应的声音…

2.27 junit5常用语法

一.了解junitjunit是一个开源的java单元测试框架,java方向使用最广泛的单元测试框架.所需要的依赖<dependencies><!-- https://mvnrepository.com/artifact/org.seleniumhq.selenium/selenium-java --><dependency><groupId>org.seleniumhq.selenium&l…

笔记本触摸板没反应怎么办?处理方法看这些

触摸板在笔记本电脑中是非常重要的一部分&#xff0c;很多用户都会选择使用触摸板代替鼠标。然而&#xff0c;有时你可能会发现&#xff0c;你的笔记本电脑触摸板没反应&#xff0c;无法正常使用。这对于日常使用来说是非常困扰的&#xff0c;但不用担心&#xff0c;我们将在这…

react源码解析10.commit阶段

在render阶段的末尾会调用commitRoot(root);进入commit阶段&#xff0c;这里的root指的就是fiberRoot&#xff0c;然后会遍历render阶段生成的effectList&#xff0c;effectList上的Fiber节点保存着对应的props变化。之后会遍历effectList进行对应的dom操作和生命周期、hooks回…

【数据结构】知识点总结(C语言)

线性表、栈和队列、串、数组和广义表、树和二叉树、图、查找、排序线性表线性表&#xff08;顺序表示&#xff09;线性表是具有相同特性元素的一个有限序列&#xff0c;数据元素之间是线性关系&#xff0c;起始元素称为线性起点&#xff0c;终端元素称为线性终点。线性表的顺序…

sed 功能详解

介绍sedsed是一种流编辑器&#xff0c;它一次处理一行内容&#xff0c;把当前处理的行存储在临时缓冲区中&#xff08;buffer&#xff09;,称为"模式空间"&#xff0c;接着sed命令处理缓冲区中的内容&#xff0c;处理完成后&#xff0c;把缓冲区的内容送往屏幕&#…

RCEE: Event Extraction as Machine Reading Comprehension 论文解读

RCEE: Event Extraction as Machine Reading Comprehension 论文&#xff1a;Event Extraction as Machine Reading Comprehension (aclanthology.org) 代码&#xff1a;jianliu-ml/EEasMRC (github.com) 期刊/会议&#xff1a;EMNLP 2020 摘要 事件提取(Event extraction,…

哪个品牌蓝牙耳机性价比高?性价比高的平价蓝牙耳机推荐

现如今&#xff0c;随着蓝牙技术的进步&#xff0c;蓝牙耳机在人们日常生活中的便捷性更胜从前。越来越多的蓝牙耳机品牌被大众看见、认可。那么&#xff0c;哪个品牌的蓝牙耳机性价比高&#xff1f;接下来&#xff0c;我给大家推荐几款性价比高的平价蓝牙耳机&#xff0c;一起…