【Pytorch】DCGAN实战(三):二次元动漫头像生成

文章目录

  • 1.实现效果
  • 2.环境配置
    • 2.1Python
    • 2.2Pytorch、CUDA
    • 2.3Python IDE
  • 3.具体实现
    • 3.1数据预处理(data.py)
      • (1)导入包
      • (2)定义数据类
    • 3.2模型Generator,Discriminator,权重初始化(model.py)
      • (1)导入包
      • (2)Generator
      • (3)Discriminator
      • (4)权重初始化
    • 3.3网络训练(net.py)
      • (1)导入包
      • (2)创建类
    • 3.4 主函数(main.py)
      • (1)导入文件
      • (2)定义超参数
      • (3)实例化
      • (4)进行训练
  • 4.训练过程
    • 4.1 Generator和Discriminator的Loss损失曲线图
    • 4.2 D(x)和D(G(z))曲线图
    • 4.3最终生成结果图
  • 5.完整代码
  • 6.引用参考
  • 7.问题反馈


1.实现效果

使用DCGAN训练faces数据集,最终实现生成二次元动漫头像。
最后虽然生成了动漫头像,但是一些细节还是和真实的图像差别较大,比如说眼睛大小,眼睛颜色等。
之后我会将MINIST数据集、Oxford17数据集、以及faces数据集在训练过程中不同轮次的输出结果做一个总结。
生成二次元动漫头像的程序依然是沿用data.py、model.py、net.py、main.py但具体的编程的细节呢有所改变。
之前MINIST以及Oxford17数据集的程序
这里:
【Pytorch】DCGAN实战(一):基于MINIST数据集的手写数字生成
【Pytorch】DCGAN实战(二):基于Oxord17的鲜花图像生成

2.环境配置

2.1Python

Python版本为3.7

2.2Pytorch、CUDA

在这里不详细介绍了,网上有很多的安装教程,小伙伴们自行查找吧!

2.3Python IDE

Pycharm

3.具体实现

整体分为4个文件:data.py、model.py、net.py、main.py

3.1数据预处理(data.py)

(1)导入包

from torch.utils.data import DataLoader
from torchvision import utils, datasets, transforms

(2)定义数据类

class ReadData():def __init__(self,data_path,image_size=64):self.root=data_pathself.image_size=image_sizeself.dataset=self.getdataset()def getdataset(self):#3.datasetdataset = datasets.ImageFolder(root=self.root,transform=transforms.Compose([transforms.Resize(self.image_size),transforms.CenterCrop(self.image_size),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),]))print(f'Total Size of Dataset: {len(dataset)}')return datasetdef getdataloader(self,batch_size=128):dataloader = DataLoader(self.dataset,batch_size=batch_size,shuffle=True,num_workers=0)return dataloader

3.2模型Generator,Discriminator,权重初始化(model.py)

(1)导入包

import torch.nn as nn

(2)Generator

class Generator(nn.Module):def __init__(self, nz,ngf,nc):super(Generator, self).__init__()self.nz = nzself.ngf = ngfself.nc=ncself.main = nn.Sequential(# input is Z, going into a convolutionnn.ConvTranspose2d(self.nz, self.ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(self.ngf * 8),nn.ReLU(True),# state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(self.ngf * 8, self.ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf * 4),nn.ReLU(True),# state size. (ngf*4) x 8 x 8nn.ConvTranspose2d(self.ngf * 4, self.ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf * 2),nn.ReLU(True),# state size. (ngf*2) x 16 x 16nn.ConvTranspose2d(self.ngf * 2, self.ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ngf),nn.ReLU(True),# state size. (ngf) x 32 x 32nn.ConvTranspose2d(self.ngf, self.nc, 4, 2, 1, bias=False),nn.Tanh()# state size. (nc) x 64 x 64)def forward(self, input):return self.main(input)

(3)Discriminator

class Discriminator(nn.Module):def __init__(self, ndf,nc):super(Discriminator, self).__init__()self.ndf=ndfself.nc=ncself.main = nn.Sequential(# input is (nc) x 64 x 64nn.Conv2d(self.nc, self.ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf) x 32 x 32nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 2),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*2) x 16 x 16nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 4),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*4) x 8 x 8nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(self.ndf * 8),nn.LeakyReLU(0.2, inplace=True),# state size. (ndf*8) x 4 x 4nn.Conv2d(self.ndf * 8, 1, 4, 1, 0, bias=False),# state size. (1) x 1 x 1nn.Sigmoid())def forward(self, input):return self.main(input)

(4)权重初始化

def weights_init(m):classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)

3.3网络训练(net.py)

(1)导入包

import torch
import torch.nn as nn
from torchvision import utils, datasets, transforms
import time
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import os

(2)创建类

class DCGAN():def __init__(self,lr,beta1,nz, batch_size,num_showimage,device, model_save_path,figure_save_path,generator, discriminator, data_loader,):self.real_label=1self.fake_label=0self.nz=nzself.batch_size=batch_sizeself.num_showimage=num_showimageself.device = deviceself.model_save_path=model_save_pathself.figure_save_path=figure_save_pathself.G = generator.to(device)self.D = discriminator.to(device)self.opt_G=torch.optim.Adam(self.G.parameters(), lr=lr, betas=(beta1, 0.999))self.opt_D = torch.optim.Adam(self.D.parameters(), lr=lr, betas=(beta1, 0.999))self.criterion = nn.BCELoss().to(device)self.dataloader=data_loaderself.fixed_noise = torch.randn(self.num_showimage, nz, 1, 1, device=device)self.img_list = []self.G_loss_list = []self.D_loss_list = []self.D_x_list = []self.D_z_list = []def train(self,num_epochs):loss_tep = 10G_loss=0D_loss=0print("Starting Training Loop...")# For each epochfor epoch in range(num_epochs):#**********计时*********************beg_time = time.time()# For each batch in the dataloaderfor i, data in enumerate(self.dataloader):############################# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))###########################x = data[0].to(self.device)b_size = x.size(0)lbx = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device)D_x = self.D(x).view(-1)LossD_x = self.criterion(D_x, lbx)D_x_item = D_x.mean().item()# print("log(D(x))")z = torch.randn(b_size, self.nz, 1, 1, device=self.device)gz = self.G(z)lbz1 = torch.full((b_size,), self.fake_label, dtype=torch.float, device=self.device)D_gz1 = self.D(gz.detach()).view(-1)LossD_gz1 = self.criterion(D_gz1, lbz1)D_gz1_item = D_gz1.mean().item()# print("log(1 - D(G(z)))")LossD = LossD_x + LossD_gz1# print("log(D(x)) + log(1 - D(G(z)))")self.opt_D.zero_grad()LossD.backward()self.opt_D.step()# print("update LossD")D_loss+=LossD############################# (2) Update G network: maximize log(D(G(z)))###########################lbz2 = torch.full((b_size,), self.real_label, dtype=torch.float, device=self.device) # fake labels are real for generator costD_gz2 = self.D(gz).view(-1)D_gz2_item = D_gz2.mean().item()LossG = self.criterion(D_gz2, lbz2)# print("log(D(G(z)))")self.opt_G.zero_grad()LossG.backward()self.opt_G.step()# print("update LossG")G_loss+=LossGend_time = time.time()# **********计时*********************run_time = round(end_time - beg_time)# print('lalala')print(f'Epoch: [{epoch + 1:0>{len(str(num_epochs))}}/{num_epochs}]',f'Step: [{i + 1:0>{len(str(len(self.dataloader)))}}/{len(self.dataloader)}]',f'Loss-D: {LossD.item():.4f}',f'Loss-G: {LossG.item():.4f}',f'D(x): {D_x_item:.4f}',f'D(G(z)): [{D_gz1_item:.4f}/{D_gz2_item:.4f}]',f'Time: {run_time}s',end='\r\n')# print("lalalal2")# Save Losses for plotting laterself.G_loss_list.append(LossG.item())self.D_loss_list.append(LossD.item())# Save D(X) and D(G(z)) for plotting laterself.D_x_list.append(D_x_item)self.D_z_list.append(D_gz2_item)# # Save the Best Model# if LossG < loss_tep:#     torch.save(self.G.state_dict(), 'model.pt')#     loss_tep = LossGif not os.path.exists(self.model_save_path):os.makedirs(self.model_save_path)torch.save(self.D.state_dict(), self.model_save_path + 'disc_{}.pth'.format(epoch))torch.save(self.G.state_dict(), self.model_save_path + 'gen_{}.pth'.format(epoch))# Check how the generator is doing by saving G's output on fixed_noisewith torch.no_grad():fake = self.G(self.fixed_noise).detach().cpu()self.img_list.append(utils.make_grid(fake * 0.5 + 0.5, nrow=10))print()if not os.path.exists(self.figure_save_path):os.makedirs(self.figure_save_path)plt.figure(1,figsize=(8, 4))plt.title("Generator and Discriminator Loss During Training")plt.plot(self.G_loss_list[::10], label="G")plt.plot(self.D_loss_list[::10], label="D")plt.xlabel("iterations")plt.ylabel("Loss")plt.axhline(y=0, label="0", c="g")  # asymptoteplt.legend()plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'loss.jpg', bbox_inches='tight')plt.figure(2,figsize=(8, 4))plt.title("D(x) and D(G(z)) During Training")plt.plot(self.D_x_list[::10], label="D(x)")plt.plot(self.D_z_list[::10], label="D(G(z))")plt.xlabel("iterations")plt.ylabel("Probability")plt.axhline(y=0.5, label="0.5", c="g")  # asymptoteplt.legend()plt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'D(x)D(G(z)).jpg', bbox_inches='tight')fig = plt.figure(3,figsize=(5, 5))plt.axis("off")ims = [[plt.imshow(item.permute(1, 2, 0), animated=True)] for item in self.img_list]ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)HTML(ani.to_jshtml())# ani.to_html5_video()ani.save(self.figure_save_path + str(num_epochs) + 'epochs_' + 'generation.gif')plt.figure(4,figsize=(8, 4))# Plot the real imagesplt.subplot(1, 2, 1)plt.axis("off")plt.title("Real Images")real = next(iter(self.dataloader))  # real[0]image,real[1]labelplt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))# Load the Best Generative Model# self.G.load_state_dict(#     torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))self.G.eval()# Generate the Fake Imageswith torch.no_grad():fake = self.G(self.fixed_noise).cpu()# Plot the fake imagesplt.subplot(1, 2, 2)plt.axis("off")plt.title("Fake Images")fake = utils.make_grid(fake[:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0)plt.imshow(fake)# Save the comparation resultplt.savefig(self.figure_save_path + str(num_epochs) + 'epochs_' + 'result.jpg', bbox_inches='tight')plt.show()def test(self,epoch):# Size of the Figureplt.figure(figsize=(8, 4))# Plot the real imagesplt.subplot(1, 2, 1)plt.axis("off")plt.title("Real Images")real = next(iter(self.dataloader))#real[0]image,real[1]labelplt.imshow(utils.make_grid(real[0][:self.num_showimage] * 0.5 + 0.5, nrow=10).permute(1, 2, 0))# Load the Best Generative Modelself.G.load_state_dict(torch.load(self.model_save_path + 'disc_{}.pth'.format(epoch), map_location=torch.device(self.device)))self.G.eval()# Generate the Fake Imageswith torch.no_grad():fake = self.G(self.fixed_noise.to(self.device))# Plot the fake imagesplt.subplot(1, 2, 2)plt.axis("off")plt.title("Fake Images")fake = utils.make_grid(fake * 0.5 + 0.5, nrow=10)plt.imshow(fake.permute(1, 2, 0))# Save the comparation resultplt.savefig(self.figure_save_path+'result.jpg', bbox_inches='tight')plt.show()

3.4 主函数(main.py)

(1)导入文件

from data import ReadData
from model import Discriminator, Generator, weights_init
from net import DCGAN
import torch

(2)定义超参数

ngpu=1
ngf=64
ndf=64
nc=3
nz=100
lr=0.003
beta1=0.5
batch_size=100
num_showimage=100data_path="./oxford17_class"
model_save_path="./models/"
figure_save_path="./figures/"device = torch.device('cuda:0' if (torch.cuda.is_available() and ngpu > 0) else 'cpu')

(3)实例化

dataset=ReadData(data_path)
dataloader=dataset.getdataloader(batch_size=batch_size)G = Generator(nz,ngf,nc).apply(weights_init)
print(G)
D = Discriminator(ndf,nc).apply(weights_init)
print(D)dcgan=DCGAN( lr,beta1,nz,batch_size,num_showimage,device, model_save_path,figure_save_path,G, D, dataloader)

(4)进行训练

dcgan.train(num_epochs=20)

4.训练过程

4.1 Generator和Discriminator的Loss损失曲线图

训练过程中Generator和Discriminator的Loss曲线图(以200个epoch为例):
Generator和Discriminator的Loss损失曲线图

4.2 D(x)和D(G(z))曲线图

训练过程中Discriminator输出(以200个epoch为例):
D(x)和D(G(z))曲线图

4.3最终生成结果图

训练结束后生成图片(以5个epoch为例):
最终生成结果图

5.完整代码

链接:https://pan.baidu.com/s/15J6sZL3rCPLm2jZFEuyzNw
提取码:DGAN

6.引用参考

https://blog.csdn.net/qq_42951560/article/details/112199229
https://blog.csdn.net/qq_42951560/article/details/110308336

7.问题反馈

如果运行有问题,欢迎给我私信留言!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/256524.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

【Pytorch学习】复现DCGAN训练生成动漫头像

先看一下结果&#xff1a; 1&#xff0c;环境安装指令 conda create -n pytorch python3.7 activate pytorch conda install pytorch torchvision torchaudio cudatoolkit11.3 -c pytorch pip install matplotlib pip install IPython pip install opencv-python 2&#xff0…

利用python+百度智能云为人物头像动漫化(附API代码及SDK代码)

文章目录 前言1.2、打开第一个搜索结果1.3、点击立即使用1.4、创建应用1.5、获取APPID等参数 二、API与SDK的使用1.API代码2.SDK使用2.1.首先下载python 的SDK&#xff1a;2.2.用编译软件打开aip-python-sdk-4.15.1文件夹并在aip目录下创建py文件2.3.SDK代码 总结 前言 利用py…

java基于ssm的卡通动漫网站

本系统设计为卡通动漫网站系统管理&#xff0c;主要功能是前台展示网站新闻信息&#xff0c;具有分类展示功能及在线留言和对文章的在线评论等功能&#xff0c;网站用户的注册&#xff0c;系统简介等。管理员后台的管理&#xff0c;管理员登录后台后可对现有管理员进行增加删除…

Python使用AI photo2cartoon制作属于你的漫画头像

Python使用AI photo2cartoon制作属于你的漫画头像 1. 效果图2. 原理3. 源码参考 git clone https://github.com/minivision-ai/photo2cartoon.git cd ./photo2cartoon python test.py --photo_path images/photo_test.jpg --save_path images/cartoon_result.png1. 效果图 官方…

php 照片变成卡通照片,怎么把照片做成q版卡通 照片变q版卡通人物 q版卡通头像制作...

想要把自己的头像变成真人q版卡通漫画&#xff0c;偷偷问了一个漫画家怎么制作的&#xff0c;他说用电脑手绘&#xff0c;得有画画基础才行&#xff0c;小编这下子就打了退堂鼓了&#xff0c;照片制作视频容易&#xff0c;但是自己画画太难了。有什么比较容易方法把照片做成q版…

刘诗诗吴奇隆大婚热吻头像

吴奇隆刘诗诗巴厘岛甜蜜完婚,现场布置鲜花簇拥,碧海蓝天,宛若仙境。想知道他们的两人结婚头像吗?小编为你采撷一些婚礼头像&#xff0c;重点新娘子美的不要不要的&#xff01;

Python实用案例,Python脚本实现快速卡通化人物头像,让我想起了QQ秀时光!

往期回顾 Python脚本实现天气查询应用 Python实现自动监测Github项目并打开网页 Python实现文件自动归类 Python实现帮你选择双色球号码 Python实现每日更换“必应图片”为“桌面壁纸” Python实现批量加水印 Python实现破译zip压缩包 Python实现批量下载百度图片 前言…

怎么制作真人qq秀_一分钟简单制作一个专属于自己的卡通头像

点击蓝字关注我们 制作一个专属于自己的卡通头像很简单&#xff0c;我们常用的美图秀秀软件就可以轻松制作。 首选在应用市场搜索美图秀秀下载后选择工具箱打开&#xff0c;找到实用工具中的动漫化身这个选项&#xff1b; 打开后点击绘制动漫形象&#xff0c;可以在相册中选择一…

带你读AI论文丨ACGAN-动漫头像生成

摘要&#xff1a;ACGAN-动漫头像生成是一个十分优秀的开源项目。 本文分享自华为云社区《【云驻共创】AI论文精读会&#xff1a;ACGAN-动漫头像生成》&#xff0c;作者&#xff1a;SpiderMan。 1.论文及算法介绍 1.1基本信息 • 论文题目&#xff1a;《Conditional Image Sy…

Docker镜像更新通知器DIUN

什么是 DIUN ? Docker Image Update Notifier 是一个用 Go 编写的 CLI 应用程序&#xff0c;可作为单个可执行文件和 Docker 映像交付&#xff0c;用于当 Docker 映像在 Docker registry中更新时接收通知。 和老苏之前介绍过的 watchtower 不同&#xff0c;DIUN 只是通知&…

idea连接Linux服务器

一、 介绍 配置idea的ssh会话和sftp可以实现对linux远程服务器的访问和文件上传下载&#xff0c;是替代Xshell的理想方式。这样我们就能在idea里面编写文件并轻松的将文件上传到linux服务器中。而且还能远程编辑linux服务器上的文件。掌握并熟练使用&#xff0c;能够大大提高我…

聊聊企业无线网络安全

新钛云服已累计为您分享749篇技术干货 不知不觉无线网络已经成为了办公网主流。最早接触无线网络的时候是2001年&#xff0c;那时候笔记本电脑还比较少见&#xff0c;标配也不支持无线网络&#xff0c;要使用无线网络需要另外加一块PCMIA接口的无线网卡。第一次体验无线网络的时…

千牛中文件已存在于服务器上,千牛登陆在云服务器上

千牛登陆在云服务器上 内容精选 换一换 如果Windows操作系统云服务器未安装密码重置插件&#xff0c;可以参见本节内容重新设置密码。本节操作介绍的方法仅适用于修改Windows本地账户密码&#xff0c;不能修改域账户密码。Linux操作系统请参见重置Linux云服务器密码(未安装重置…

mac安装旺旺启动台找不到_如何正确安装和卸载Mac软件?

Windows和Mac是两个截然不同的系统&#xff0c;很多操作逻辑都有本质上的区别&#xff0c;管家针对刚接触Mac系统的朋友做了一份简单的“Mac软件的安装和卸载”教程&#xff0c;希望对大家有所帮助。 1 如何安装软件&#xff1f; Mac系统安装软件的方法有两种&#xff0c;一种是…

获取千牛聊天记录(此方法新版千牛已失效,7.1之前的版本应该有效,各位自行测试咯)...

分析UI: 分析千牛UI控件,我们用Visual Studio自带的SPY++查找窗口,得到聊天记录的控件信息发现 窗口类名:Aef_RenderWidgetHostHWND ,上网搜了一下说是Chrominum 的窗口。确定一下我们直接选中千牛的聊天窗口按F12,发现会弹出Chrome的开发者工具。到此我们确定了千牛的聊天窗…

千牛2015卖家版官方电脑版

千牛2015卖家版 v2.08 官方电脑版 软件大小&#xff1a;54.9MB 软件语言&#xff1a;简体中文 软件类别&#xff1a;管理工具 软件授权&#xff1a;免费版 更新时间&#xff1a;2015-01-06 应用平台&#xff1a;/Win8/Win7/WinXP 千牛2015卖家版是阿里巴巴专为淘宝、天猫卖家量…

千牛文件在服务器上,千牛挂在云服务器

千牛挂在云服务器 内容精选 换一换 云耀云服务器(Halo Elastic Cloud Server&#xff0c;HECS)是可以快速搭建简单应用的新一代云服务器&#xff0c;具备独立、完整的操作系统和网络功能。提供快速地应用部署和简易的管理能力&#xff0c;适用于网站搭建、开发环境等低负载应用…

pc端网页唤起本地的咚咚和千牛

前段时间接手了一个需求,需求大概就是pc端的产品需要做一个点击按钮唤起咚咚和千牛,并且需要打开对应的顾客聊天窗口。 当时接到这个需求人都不好了,大牛们都没接触过这个需求,不知道咚咚和千牛的协议,去看淘宝和京东开发平台的文档也没发现什么有用的,然后就一直考古呀…

七牛云工具类

首先我们需要创建一个oss.properties文件存储七牛云的必须属性&#xff0c;可在七牛云官网查看 #qiniu.bucket xxx #qiniu.access_key xxx #qiniu.secret_key xxx #qiniu.base_url xxx七牛工具类QiniuOssUtils import com.google.gson.Gson; import com.qiniu.common.QiniuExce…

1688获得店铺的所有商品教程

onebound.1688.item_search_shop 获取key和secret API文档说明 完整返回数据 { "user": { "id": null, "nick": null, "good_num": "", "level": "", …