使用pytorch构建一个初级的无监督的GAN网络模型

在这个系列中将系统的构建GAN及其相关的一些变种模型,来了解GAN的基本原理。本片为此系列的第一篇,实现起来很简单,所以不要期待有很好的效果出来。

第一篇我们搭建一个无监督的可以生成数字 (0-9) 手写图像的 GAN,使用MINIST数据集,包含0-9的60000张手写数字图像,如图:
在这里插入图片描述

原理

首先简单讲一下GAN的工作原理,如下为前向传播的过程:
在这里插入图片描述
GAN网络有两个模型,分别是生成器generator和判别器discriminator。generator的作用是生成图片的,也就是我们想要的结果,通过输入随机噪声来生成图片;而discriminator是判断输入的图片是真实数据还是生成的假数据,输入生成的假数据或真实数据,输出真与假的概率值。

而反向传播过程其实是分开的,即generator和discriminator是分别进行梯度更新的。且交替进行训练的,一个模型训练,另一个模型就要保持不变,保持两个模型的能力要相当才能一起进步,否则如果判别器的性能要比生成器要好的话就很容易陷入模式崩溃mdoel collapse或梯度消失等。
下图为discriminator的反向传播的过程:
在这里插入图片描述
discriminator的工作是为了将生成的假数据判别为0,将真实的数据判别为1,即公正判别非黑即白,所以loss的计算为:
在这里插入图片描述

下图为generator的反向传播的过程:
在这里插入图片描述
而generator的工作是为了将生成的假数据让discriminator判别为1,即骗过discriminator颠倒黑白,所以loss的计算为:
在这里插入图片描述

代码

下面开始直接上代码,我在网上学习别人代码的习惯是先把所有代码跑起来再来仔细看每个代码模块,我在这也就先放上所有代码再分析各个模块。
model.py:

from torch import nn
import torchdef get_generator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim),nn.BatchNorm1d(output_dim),nn.ReLU(inplace=True),)class Generator(nn.Module):def __init__(self, z_dim=10, im_dim=784, hidden_dim=128):super(Generator, self).__init__()self.gen = nn.Sequential(get_generator_block(z_dim, hidden_dim),get_generator_block(hidden_dim, hidden_dim * 2),get_generator_block(hidden_dim * 2, hidden_dim * 4),get_generator_block(hidden_dim * 4, hidden_dim * 8),nn.Linear(hidden_dim * 8, im_dim),nn.Sigmoid())def forward(self, noise):return self.gen(noise)def get_gen(self):return self.gendef get_discriminator_block(input_dim, output_dim):return nn.Sequential(nn.Linear(input_dim, output_dim), #Layer 1nn.LeakyReLU(0.2, inplace=True))class Discriminator(nn.Module):def __init__(self, im_dim=784, hidden_dim=128):super(Discriminator, self).__init__()self.disc = nn.Sequential(get_discriminator_block(im_dim, hidden_dim * 4),get_discriminator_block(hidden_dim * 4, hidden_dim * 2),get_discriminator_block(hidden_dim * 2, hidden_dim),nn.Linear(hidden_dim, 1))def forward(self, image):return self.disc(image)def get_disc(self):return self.disc

train.py:

import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST # Training dataset
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from model import Discriminator, Generator
torch.manual_seed(0) # Set for testing purposes, please do not change!def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):image_unflat = image_tensor.detach().cpu().view(-1, *size)image_grid = make_grid(image_unflat[:num_images], nrow=5)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()def get_noise(n_samples, z_dim, device='cpu'):return torch.randn(n_samples,z_dim,device=device)criterion = nn.BCEWithLogitsLoss()
n_epochs = 200
z_dim = 64
display_step = 500
batch_size = 128
lr = 0.00001
device = 'cuda'dataloader = DataLoader(MNIST('./', download=True, transform=transforms.ToTensor()),  # 已经下载过的可以改为False跳过下载batch_size=batch_size,shuffle=True)gen = Generator(z_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator().to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)def get_disc_loss(gen, disc, criterion, real, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake.detach())disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))disc_real_pred = disc(real)disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))disc_loss = (disc_fake_loss + disc_real_loss) / 2return disc_lossdef get_gen_loss(gen, disc, criterion, num_images, z_dim, device):fake_noise = get_noise(num_images, z_dim, device=device)fake = gen(fake_noise)disc_fake_pred = disc(fake)gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))return gen_losscur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0
gen_loss = False
error = False
for epoch in range(n_epochs):# Dataloader returns the batchesfor real, _ in tqdm(dataloader):cur_batch_size = len(real)# Flatten the batch of real images from the datasetreal = real.view(cur_batch_size, -1).to(device)### Update discriminator #### Zero out the gradients before backpropagationdisc_opt.zero_grad()# Calculate discriminator lossdisc_loss = get_disc_loss(gen, disc, criterion, real, cur_batch_size, z_dim, device)# Update gradientsdisc_loss.backward(retain_graph=True)# Update optimizerdisc_opt.step()### Update generator ###gen_opt.zero_grad()gen_loss = get_gen_loss(gen, disc, criterion, cur_batch_size, z_dim, device)gen_loss.backward()gen_opt.step()# Keep track of the average discriminator lossmean_discriminator_loss += disc_loss.item() / display_step# Keep track of the average generator lossmean_generator_loss += gen_loss.item() / display_step### Visualization code ###if cur_step % display_step == 0 and cur_step > 0:print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")fake_noise = get_noise(cur_batch_size, z_dim, device=device)fake = gen(fake_noise)show_tensor_images(fake)show_tensor_images(real)mean_generator_loss = 0mean_discriminator_loss = 0cur_step += 1

运行结果

运行后每隔500个epoch画出fake和real,刚开始的fake和real是这样的:
在这里插入图片描述
在这里插入图片描述
到后面的fake逐渐变成这样:
在这里插入图片描述

代码解释

网络模型

model.py里面存放了generator和discriminator的网络模型,神经元使用的是简单的全连接层,后面的文章再使用卷积。
在这里插入图片描述
生成器输出为784 = 28 * 28,因为使用的是MINIST手写字体数据集,每张图的shape是28 * 28 * 1(黑白图单通道),所以输出的假数据要与真实数据的shape一致,这样输入鉴别器才不会出错。
在这里插入图片描述
生成的图片(或真实数据)直接输入鉴别器,所以鉴别器的输入也是28*28,而输出为1,即输出判别结果为真或假。
在这里插入图片描述
每个优化器仅优化一个模型的参数,所以一个模型构建一个优化器。

图像显示

在这里插入图片描述
首先将图像的tensor转到cpu上,因为PyTorch中的大部分图像处理和显示函数都是在CPU上执行的,包括我们使用的imshow。
detach() 方法将张量从计算图中分离出来,但是仍指向原变量的存放位置,不同之处只是requirse_grad为false,得到的这个tensor永远不需要计算器梯度,不具有grad,这样做的目的是避免梯度计算的影响,因为在展示图像时通常不需要计算梯度。
Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系,类似这样:
在这里插入图片描述
一个网络模型就是一个计算图,在网络backward时候,需要用链式求导法则求出网络最后输出的梯度,然后再对网络进行优化,求导过程就如上图这样。
make_grid 函数用于将多个图像组成一个网格,方便显示。
在这里插入图片描述

在这里插入图片描述
然后每500个batch显示一次当前模型性能所能生成的图片以及当前batch的真实图片(虽然一个batch设置了128张,但是我们只展示25张),以及print出生成器和鉴别器的loss。

损失函数

在这里插入图片描述
损失函数的原理在上面的“原理”中有讲解,这里不再赘述。
在计算鉴别器的loss里,disc_fake_pred = disc(fake.detach())是对生成图片的判别,这里也使用 .detach() 的目的是将生成器产生的假数据与生成器的参数分离,使得在计算 disc_fake_pred 时不会对生成器的梯度进行传播。这是因为在训练鉴别器的阶段,我们只希望更新鉴别器的参数,而不希望更新生成器的参数(就如上面说的生成器的训练和鉴别器的应该要隔开分别训练、交替训练)。

反向传播

在这里插入图片描述
retain_graph=True参数是用来指示 PyTorch 在反向传播时保留计算图。这个参数的作用是为了在一次反向传播之后保留计算图的状态,以便后续再次调用 backward() 函数时能够继续使用这个计算图进行梯度计算。
Pytoch构建的计算图是动态图,为了节约内存,所以每次一轮迭代完之后计算图就被在内存释放,所以当你想要多次backward时候就会报如下错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

而在GAN中一次的迭代需要先更新鉴别器的参数,然后再更新生成器的参数;在更新生成器的参数时,我们仍然需要使用鉴别器来鉴别real or fake,只要使用到鉴别器就需要他的计算图。因此,我们需要在调用 disc_loss.backward() 后保留计算图,以便后续调用 gen_loss.backward() 时能够继续使用相同的计算图进行梯度计算。而对于生成器的梯度更新 gen_loss.backward(),不需要显式指定 retain_graph=True。
所以,在同一个计算图上多次调用 backward() 函数时才需要使用它。

下一篇构建DCGAN。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2904851.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

就业班 第二阶段 2401--3.27 day8 shell之循环控制

七、shell编程-循环结构 shell循环-for语句 for i in {取值范围} # for 关键字 i 变量名 in 关键字 取值范围格式 1 2 3 4 5 do # do 循环体的开始循环体 done # done 循环体的结束 #!/usr/bin/env bash # # Author: # Date: 2019/…

kubernetes K8s的监控系统Prometheus升级Grafana,来一个酷炫的Node监控界面(二)

上一篇文章《kubernetes K8s的监控系统Prometheus安装使用(一)》中使用的监控界面总感觉监控的节点数据太少,不能快算精准的判断出数据节点运行的状况。 今天我找一款非常酷炫的多维度数据监控界面,能够非常有把握的了解到各节点的数据,以及运…

HarmonyOS 应用开发之显式Want与隐式Want匹配规则

在启动目标应用组件时,会通过显式 Want 或者隐式 Want 进行目标应用组件的匹配,这里说的匹配规则就是调用方传入的 want 参数中设置的参数如何与目标应用组件声明的配置文件进行匹配。 显式Want匹配原理 显式 Want 匹配原理如下表所示。 名称类型匹配…

【leetcode】环形链表的约瑟夫问题

大家好,我是苏貝,本篇博客带大家刷题,如果你觉得我写的还不错的话,可以给我一个赞👍吗,感谢❤️ 点击查看题目 首先我们要明确一点,题目要求我们要用环形链表,所以用数组等是不被允…

某某消消乐增加步数漏洞分析

一、漏洞简介 1) 漏洞所属游戏名及基本介绍:某某消消乐,三消游戏,类似爱消除。 2) 漏洞对应游戏版本及平台:某某消消乐Android 1.22.22。 3) 漏洞功能:增加游戏步数。 4&#xf…

Spark-Scala语言实战(6)

在之前的文章中,我们学习了如何在scala中定义与使用类和对象,并做了几道例题。想了解的朋友可以查看这篇文章。同时,希望我的文章能帮助到你,如果觉得我的文章写的不错,请留下你宝贵的点赞,谢谢。 Spark-S…

智能设备配网保姆级教程

设备配网 简单来说,配网就是将物联网(IoT)设备连接并注册到云端,使其拥有与云端远程通信的能力。配网后,智能设备才能被手机应用或者项目管理后台控制,依托于智能场景创造价值。本文介绍了配网的相关知识&…

Linux环境安装Redis

Linux环境安装Redis 一,软件安装准备 服务器连接软件 Redis数据库连接软件 这是Windows软件,用于连接Linux服务器使用。推荐使用。 二,下载Redis 下载地址:Index of /releases/ 截止编稿Redis版本已经到7.2.4了,如果…

如何使用Windows电脑部署Lychee私有图床网站并实现无公网IP远程管理本地图片

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-MSVdVLkQMnY9Y2HW {font-family:"trebuchet ms",verdana,arial,sans-serif;f…

什么是RISC-V?开源 ISA 如何重塑未来的处理器设计

RISC-V代表了处理器架构的范式转变,特点是其开源模型简化了设计理念并促进了全球community-driven的开发。RISC-V导致了处理器技术发展前进方式的重大转变,提供了一个不受传统复杂性阻碍的全新视角。 RISC-V起源于加州大学伯克利分校的学术起点&#xff…

腾讯云服务器多少钱一年?2024年最新价格整理

2024年腾讯云4核8G服务器租用优惠价格:轻量应用服务器4核8G12M带宽646元15个月,CVM云服务器S5实例优惠价格1437.24元买一年送3个月,腾讯云4核8G服务器活动页面 txybk.com/go/txy 活动链接打开如下图: 腾讯云4核8G服务器优惠价格 轻…

设计模式 - 简单工厂模式

文章目录 前言 大家好,今天给大家介绍一下23种常见设计模式中的一种 - 工厂模式 1 . 问题引入 请用C、Java、C#或 VB.NET任意一种面向对象语言实现一个计算器控制台程序,要求输入两个数和运算符 号,得到结果。 下面的代码实现默认认为两个操作数为Inte…

阿里云CentOS7安装Hadoop3伪分布式

ECS准备 开通阿里云ECS 略 控制台设置密码 连接ECS 远程连接工具连接阿里云ECS实例,这里远程连接工具使用xshell 根据提示接受密钥 根据提示写用户名和密码 用户名:root 密码:在控制台设置的密码 修改主机名 将主机名从localhost改为需要…

excel中批量插入分页符

excel中批量插入分页符,实现按班级打印学生名单。 1、把学生按照学号、班级排序好。 2、选择班级一列,点击数据-分类汇总。汇总方式选择计数,最后三个全部勾选。汇总结果一定要显示在数据的下发,如果显示在上方,后期…

操作教程|在MeterSphere中通过SSH登录服务器的两种方法

MeterSphere开源持续测试平台拥有非常强大的插件集成机制,用户可以通过插件实现平台能力的拓展,借助插件或脚本实现多种功能。在测试过程中,测试人员有时需要通过SSH协议登录至服务器,以获取某些配置文件和日志文件,或…

Python爬虫:爬虫常用伪装手段

目录 前言 一、设置User-Agent 二、设置Referer 三、使用代理IP 四、限制请求频率 总结 前言 随着互联网的快速发展,爬虫技术在网络数据采集方面发挥着重要的作用。然而,由于爬虫的使用可能会对被爬取的网站造成一定的压力,因此&#…

HarmonyOS实战开发-实现带有卡片的电影应用

介绍 本篇Codelab基于元服务卡片的能力,实现带有卡片的电影应用,介绍卡片的开发过程和生命周期实现。需要完成以下功能: 元服务卡片,用于在桌面上添加2x2或2x4规格元服务卡片。关系型数据库,用于创建、查询、添加、删…

SQL,group by分组后分别计算组内不同值的数量

SQL,group by分组后分别计算组内不同值的数量 如现有一张购物表shopping 先要求小明和小红分别买了多少笔和多少橡皮,形成以下格式 SELECT name,COUNT(*) FROM shopping GROUP BY name;SELECT name AS 姓名,SUM( CASE WHEN cargo 笔 THEN 1 ELSE 0 END)…

Prometheus +Grafana +node_exporter可视化监控Linux虚机

1、介绍 待补充 2、架构图 待补充 Prometheus :主要是负责存储、抓取、聚合、查询方面。 node_exporter :主要是负责采集物理机、中间件的信息。 3、搭建过程 配置要求:1台主服务器 n台从服务器 (被监控的linux虚机&am…

百度智能云千帆,产业创新新引擎

本文整理自 3 月 21 日百度副总裁谢广军的主题演讲《百度智能云千帆,产业创新新引擎》。 各位领导、来宾、媒体朋友们,大家上午好。很高兴今天在石景山首钢园,和大家一起沟通和探讨大模型的发展趋势,以及百度最近一段时间的思考和…