GAN生成漫画脸

news/2024/4/26 5:21:11/文章来源:https://blog.csdn.net/suic009/article/details/128120410

最近对对抗生成网络GAN比较感兴趣,相关知识点文章还在编辑中,以下这个是一个练手的小项目~

 (在原模型上做了,为了减少计算量让其好训练一些。)

一、导入工具包

import tensorflow as tf
from tensorflow.keras import layersimport numpy as np
import os
import time
import glob
import matplotlib.pyplot as plt
from IPython.display import clear_output
from IPython import display

1.1 设置GPU

gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0]                                        #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")
gpus 
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

二、导入训练数据

链接: 点这里

fileList = glob.glob('./ani_face/*.jpg')
len(fileList)
41621

2.1 数据可视化 

# 随机显示几张图
for index,i in enumerate(fileList[:3]):display.display(display.Image(fileList[index]))

2.2 数据预处理

# 文件名列表
path_ds = tf.data.Dataset.from_tensor_slices(fileList)# 预处理,归一化,缩放
def load_and_preprocess_image(path):image = tf.io.read_file(path)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [64, 64])image /= 255.0  # normalize to [0,1] rangeimage = tf.reshape(image, [1, 64,64,3])return imageimage_ds = path_ds.map(load_and_preprocess_image)
image_ds
<MapDataset shapes: (1, 64, 64, 3), types: tf.float32>
# 查看一张图片
for x in image_ds:plt.axis("off")plt.imshow((x.numpy() * 255).astype("int32")[0])break

三、网络构建

3.1 D网络

discriminator = keras.Sequential([keras.Input(shape=(64, 64, 3)),layers.Conv2D(64, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Flatten(),layers.Dropout(0.2),layers.Dense(1, activation="sigmoid"),],name="discriminator",
)
discriminator.summary()
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 32, 32, 64)        3136      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 32, 32, 64)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 16, 16, 128)       131200    
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 16, 16, 128)       0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 8, 8, 128)         262272    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 8, 8, 128)         0         
_________________________________________________________________
flatten (Flatten)            (None, 8192)              0         
_________________________________________________________________
dropout (Dropout)            (None, 8192)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 8193      
=================================================================
Total params: 404,801
Trainable params: 404,801
Non-trainable params: 0

3.2 G网络

latent_dim = 128generator = keras.Sequential([keras.Input(shape=(latent_dim,)),layers.Dense(8 * 8 * 128),layers.Reshape((8, 8, 128)),layers.Conv2DTranspose(128, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2DTranspose(256, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2DTranspose(512, kernel_size=4, strides=2, padding="same"),layers.LeakyReLU(alpha=0.2),layers.Conv2D(3, kernel_size=5, padding="same", activation="sigmoid"),],name="generator",
)
generator.summary()

3.3 重写 train_step

class GAN(keras.Model):def __init__(self, discriminator, generator, latent_dim):super(GAN, self).__init__()self.discriminator = discriminatorself.generator = generatorself.latent_dim = latent_dimdef compile(self, d_optimizer, g_optimizer, loss_fn):super(GAN, self).compile()self.d_optimizer = d_optimizerself.g_optimizer = g_optimizerself.loss_fn = loss_fnself.d_loss_metric = keras.metrics.Mean(name="d_loss")self.g_loss_metric = keras.metrics.Mean(name="g_loss")@propertydef metrics(self):return [self.d_loss_metric, self.g_loss_metric]def train_step(self, real_images):# 生成噪音batch_size = tf.shape(real_images)[0]random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# 生成的图片generated_images = self.generator(random_latent_vectors)# Combine them with real imagescombined_images = tf.concat([generated_images, real_images], axis=0)# Assemble labels discriminating real from fake imageslabels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)# Add random noise to the labels - important trick!labels += 0.05 * tf.random.uniform(tf.shape(labels))# 训练判别器,生成的当成0,真实的当成1 with tf.GradientTape() as tape:predictions = self.discriminator(combined_images)d_loss = self.loss_fn(labels, predictions)grads = tape.gradient(d_loss, self.discriminator.trainable_weights)self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))# Sample random points in the latent spacerandom_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))# Assemble labels that say "all real images"misleading_labels = tf.zeros((batch_size, 1))# Train the generator (note that we should *not* update the weights# of the discriminator)!with tf.GradientTape() as tape:predictions = self.discriminator(self.generator(random_latent_vectors))g_loss = self.loss_fn(misleading_labels, predictions)grads = tape.gradient(g_loss, self.generator.trainable_weights)self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))# Update metricsself.d_loss_metric.update_state(d_loss)self.g_loss_metric.update_state(g_loss)return {"d_loss": self.d_loss_metric.result(),"g_loss": self.g_loss_metric.result(),}

3.4 设置回调函数

class GANMonitor(keras.callbacks.Callback):def __init__(self, num_img=3, latent_dim=128):self.num_img = num_imgself.latent_dim = latent_dimdef on_epoch_end(self, epoch, logs=None):random_latent_vectors = tf.random.normal(shape=(self.num_img, self.latent_dim))generated_images = self.model.generator(random_latent_vectors)generated_images *= 255generated_images.numpy()for i in range(self.num_img):img = keras.preprocessing.image.array_to_img(generated_images[i])display.display(img)img.save("gen_ani/generated_img_%03d_%d.png" % (epoch, i))

四、训练模型

epochs = 100  # In practice, use ~100 epochsgan = GAN(discriminator=discriminator, generator=generator, latent_dim=latent_dim)
gan.compile(d_optimizer=keras.optimizers.Adam(learning_rate=0.0001),g_optimizer=keras.optimizers.Adam(learning_rate=0.0001),loss_fn=keras.losses.BinaryCrossentropy(),
)gan.fit(image_ds, epochs=epochs, callbacks=[GANMonitor(num_img=10, latent_dim=latent_dim)]
)

五、保存模型

#保存模型
gan.generator.save('./data/ani_G_model')

生成模型文件:点这里

六、生成漫画脸

G_model =  tf.keras.models.load_model('./data/ani_G_model/',compile=False)def randomGenerate():noise_seed = tf.random.normal([16, 128])predictions = G_model(noise_seed, training=False)fig = plt.figure(figsize=(8, 8))for i in range(predictions.shape[0]):plt.subplot(4, 4, i+1)img = (predictions[i].numpy() * 255 ).astype('int')plt.imshow(img )plt.axis('off')plt.show()
count = 0
while True:randomGenerate()clear_output(wait=True)time.sleep(0.1)if count > 100:breakcount+=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.luyixian.cn/news_show_39264.aspx

如若内容造成侵权/违法违规/事实不符,请联系dt猫网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

业务:财务会计业务知识

一、引言 会计是以货币为主要计量单位&#xff0c;对企业、事业、机关、团体及其他经济组织的经济活动进行记录、计算、控制、分析、报告&#xff0c;以提供财务和管理信息的工作。会计的职能主要是反映和控制经济活动过程&#xff0c;保证会计信息的合法、真实、准确和完整&a…

校园论坛网站设计设计与实现

项目描述 临近学期结束&#xff0c;还是毕业设计&#xff0c;你还在做java程序网络编程&#xff0c;期末作业&#xff0c;老师的作业要求觉得大了吗?不知道毕业设计该怎么办?网页功能的数量是否太多?没有合适的类型或系统?等等。这里根据疫情当下&#xff0c;你想解决的问…

网页JS自动化脚本(四)修改元素的尺寸颜色显隐状态

修改元素尺寸 在定位到了元素之后, 我们就可以对元素进行一些修改了,我们先来修改元素泊宽度以及高度 window.onloadfunction(){var theElementdocument.querySelector("img.undertips-link-lefticon");theElement.style.width"100px";theElement.style.…

命令模式

文章目录思考命令模式1.命令模式的本质2.何时选用命令模式3.优缺点4.实现耦合写法命令模式优化耦合写法命令模式实现撤销命令模式实现厨师做菜命令模式实现排队命令模式实现日志持久化思考命令模式 命令模式就是解耦强耦合代码&#xff0c;用户只关心功能的实现&#xff0c;开发…

win11该文件没有与之关联的应用怎么办

win11用户在使用电脑的时候遇到了“该文件没有与之关联的应用”的提示&#xff0c;这是怎么回事呢&#xff1f;应该怎么办呢&#xff1f;出现这个情况应该是注册表被误删了&#xff0c;大家需要新建一个文本文档&#xff0c;然后输入下文提供的指令&#xff0c;之后将其重命名为…

linux不显示当前路径的解决方法

1.输入vim ~/.bashrc进入用户的shell环境变量的配置文件(可以设置环境变量以及通过alias设置别名&#xff09; 2.按下“i”键进入编辑模式(底部显示INSERT&#xff09; 3.修改\w为$PWD&#xff1a; 修改为&#xff1a; 4.按“esc”键后输入":wq"保存并退出&#xff…

Databend 开源周报 #69

Databend 是一款强大的云数仓。专为弹性和高效设计&#xff0c;自由且开源。 即刻体验云服务&#xff1a;https://app.databend.com。 New Features multiple catalog 实现删除用户定义目录 (#8820) meta 新增用于删除 key 和使 key 过期的 cli 命令 (#8858) planner 支…

手把手教你做智能合约开源|多文件合约开源|引用文件开源

本文手把手教你使用 区块链浏览器 验证智能合约的三种方式。 验证单一 Solidity 文件 在开始验证之前&#xff0c;我们需要首先部署智能合约。进入 Remix IDE&#xff0c;创建一个合约新文件。复制粘贴下面的代码&#xff1a; // SPDX-License-Identifier: MITpragma solidit…

JAVA学习-java基础讲义02

java基础讲义02一 进制1.1 进制介绍1.2 二进制1.3 任意进制到十进制转换1.4 十进制到任意进制之间的转换1.5 快速转换法1.6 有符号数据表示法二 Java变量和数据类型1.1 变量概述1.2 数据类型1.3 变量定义三 Java数据类型转换3.1 数据类型转换概述3.2 数据类型转换之自动类型转换…

[附源码]Python计算机毕业设计SSM老年公寓管理系统(程序+LW)

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

爬虫工作流程、请求与响应原理、requests库讲解

爬虫工作流程、请求与响应原理、requests库讲解 爬虫分类主要分为两大板块 web爬虫&#xff08;浏览器爬虫&#xff09; APP爬虫&#xff08;手机端爬虫&#xff09; 在这两大板块中又可以把爬虫归类为聚焦爬虫和通用爬虫 聚焦爬虫&#xff1a;针对某一个接口&#xff08;ur…

使用Conda

0. Anaconda Prompt 命令提示符 0.1 验证conda是否被安装 conda --version0.2 conda管理环境 可以用命令复制和删除环境 参考. 1. Conda管理包 1.1 常用包管理功能 查找包查看包安装包 查找分为精确查找和模糊查找&#xff0c;如下图所示 卸载包更新包 1.2 conda管理环…

三肽Gly-Cys-Gly、88440-55-5

三肽Gly-Cys-Gly 编号&#xff1a;111774 CAS号&#xff1a;88440-55-5 三字母&#xff1a;H2N-Gly-Cys-Gly-COOH 描 述&#xff1a;羧肽酶 U 抑制剂&#xff08;凝血酶可激活的纤维蛋白溶解抑制剂&#xff0c;TAFI&#xff09;&#xff0c;Ki 0.14 μM。编号: 111774 中文名称…

安装 laravel 遇到的错误和解决方案

安装 laravel 遇到的错误和解决方案 纯粹是为了运行下 laravel&#xff0c;遇到了错误记录下&#xff0c;分享给需要的人。 下载 PHP Windows 版 &#xff0c;我选择的版本是 PHP 7.4 (7.4.33)。下载文件以后找个文件夹解压就可以了。Composer 安装&#xff0c;官网 。 勾选以…

C/C++家族族谱管理系统

C/C家族族谱管理系统 课题名称: 家族族谱管理 主要目标: 通过训练&#xff0c;强化学生对树结构、二叉树结构的表示及操作算法的掌握和灵活运用 3.具体要求: 要求设计实现具有下列功能的家谱管理系统: (1) 输入文件以存放最初家谱中各成员的信息&#xff0c;成员的信息中…

WSL Ubuntu20.04安装pycairo指南

环境说明 wsl Ubuntu20.04 走过的一些可能有用的弯路 由于pycairo要求python3.7&#xff0c;但是之前Ubuntu上有个3.6的python环境&#xff0c;所以就安装了python3.8&#xff1a; sudo apt install python3.8然后python3命令还是链接到python3.6&#xff0c;结果就yongln …

iOS15适配 UINavigationBar和UITabBar设置无效,变成黑色

今天更新了xcode13&#xff0c;运行项目发现iOS15以上的手机导航栏和状态栏之前设置的颜色等属性都不起作用了&#xff0c;都变成了黑色&#xff0c;滚动的时候才能变成正常的颜色&#xff0c;经确认得用UINavigationBarAppearance和UITabBarAppearance这两个属性对导航栏和状态…

系统封装制作

工具网址&#xff1a; 镜像下载&#xff1a; Windows 10 22H2 - MSDN - 山己几子木 (sjjzm.com)pe工具&#xff1a;【新提醒】优启通 v3.7.2022.0910&#xff08;2022.10.14 发布&#xff09;_IT天空原创软件_IT天空 (itsk.com)万能驱动&#xff1a;万能驱动 v7.22.0912.2&…

IOC 的底层原理和Bean管理XML方式、xml注入集合属性

目录 什么是IOC IOC底层管理 工厂模式 IOC 的过程 IOC 接口 IOC 操作Bean 原理 Bean 管理操作有两种方式 1. 基于xml 配置方式创建对象 2. 基于xml方式注入属性 第二种使用有参数构造注入 p 名称空间注入 ICO操作Bean管理&#xff08;xml 注入其他类型属性&#xff…

翻转单词序列、按之字形顺序打印二叉树、二叉搜索树的第k个节点

1、翻转单词序列 本题考点&#xff1a;子串划分&#xff0c;子串逆置 牛客链接 题目描述&#xff1a; 牛客最近来了一个新员工Fish&#xff0c;每天早晨总是会拿着一本英文杂志&#xff0c;写些句子在本子上。同事Cat对Fish写的内容颇感兴趣&#xff0c;有一天他向Fish借来翻…