Paddle 实现DCGAN

传统GAN

传统的GAN可以看我的这篇文章:Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现-CSDN博客

DCGAN

DCGAN是适用于图像生成的GAN,它的特点是:

  • 只采用卷积层和转置卷积层,而不采用全连接层
  • 在每个卷积层或转置卷积层之间,插入一个批归一化层和ReLU激活函数

转置卷积层

转置卷积层执行的是转置卷积或反卷积的操作,即它是常规卷积层的反向操作。它接收一个低分辨率的输入,然后将其通过转置滤波器升采样到更高的分辨率。

对于一个卷积层,它的输出大小公式是:

o = \frac{i + 2p - k}{s} + 1

其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示卷积核大小(kernel_size),s表示步长(stride)。也就是说:输出大小 = (输入大小 - 卷积核大小 + 2 × 填充数) ÷ 步长 + 1

而对于一个转置卷积层,它的输出大小公式是:

o = s(i-1)-2p+k+u

 其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示反卷积核大小(kernel_size),s表示步长(stride),u表示输出填充(output padding)。也就是说:输出大小 = (输入大小 - 1) * 步长 - 2*填充 + 反卷积大小 + 输出填充

在paddle中,转置卷积层可以这么定义:

paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size, stride, padding)

像卷积层一样,反卷积层的in_channels表示输入通道数(如形如(3, 32, 32)的图片张量的通道数就是3),out_channels表示输出通道数(如把(64, 32, 32)变成3通道的彩色图像(3, 32, 32))。 

代码实现

这里我们采用NWPU-RESISC45数据集,从中选择“freeway”(高速公路)作为训练数据,让机器生成高速公路的图片。这个训练数据内有700张256x256的图片,但由于我的电脑显存不足,因此将图片大小设置为64x64.

先写dataset.py:

import paddle
import numpy as np
from PIL import Image
import osdef getAllPath(path):return [os.path.join(path, f) for f in os.listdir(path)]class FreewayDataset(paddle.io.Dataset):def __init__(self, transform=None):super().__init__()self.data = []for path in getAllPath('./freeway'):img = Image.open(path)img = img.resize((64, 64))img = np.array(img, dtype=np.float32).transpose((2, 1, 0))if transform is not None:img = transform(img)self.data.append(img)self.data = np.array(self.data, dtype=np.float32)def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)

然后写训练脚本:

from dataset import FreewayDataset
import paddle
from models import Generator, Discriminator
import numpy as npdataset = FreewayDataset()
dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')loss = paddle.nn.BCELoss()optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 1000for epoch in range(max_epoch):now_step = 0for step, data in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1, 1, 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1, 1, 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1, 1, 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 10 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

 最后编写图片生成脚本:

import paddle
from models import Generator
import matplotlib.pyplot as plt# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100, 1, 1], 'float32')img = netG(noise)img = img.numpy()[0].transpose((2, 1, 0))  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

经过数次训练,最终的效果如下:

这样看来,至少有点高速公路的感觉了。 

参考

通过DCGAN实现人脸图像生成-使用文档-PaddlePaddle深度学习平台

卷积层和反卷积层输出特征图大小计算_输出特征图大小的计算方法-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3029899.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

Oracle体系结构初探:闪回技术

在Oracle体系结构初探这个专栏中&#xff0c;已经写过了REDO、UNDO等内容。觉得可以开始写下有关备份恢复的内容。闪回技术 — Oracle数据库备份恢复机制的一种。它可以在一定条件下&#xff0c;高效快速的恢复因为逻辑错误&#xff08;误删误更新等&#xff09;导致的数据丢失…

数据库表自增主键超过代码Integer长度问题

数据库自增主键是 int(10) unsigned类型的字段&#xff0c;int(M) 中 M指示最大显示宽度&#xff0c;不代表存储长度&#xff0c;实际int(1)也是可以存储21.47亿长度的数字&#xff0c;如果是无符号类型的&#xff0c;那么可以从0~42.94亿。 我们的表主键自增到21.47亿后&#…

应用层协议之 DNS 协议

DNS 就是一个域名解析系统。域名就是网址&#xff0c;类似于 www.baidu.com。网络上的服务器想要访问它&#xff0c;就得需要它对应的 IP 地址&#xff0c;同时&#xff0c;每个域名对对应着一个 / N个 IP 地址&#xff08;即对应多台服务器&#xff09;。 因此&#xff0c;为了…

HarmonyOS开发案例:【生活健康app之实现打卡功能】(2)

实现打卡功能 首页会展示当前用户已经开启的任务列表&#xff0c;每条任务会显示对应的任务名称以及任务目标、当前任务完成情况。用户只可对当天任务进行打卡操作&#xff0c;用户可以根据需要对任务列表中相应的任务进行点击打卡。如果任务列表中的每个任务都在当天完成则为…

安装vmware station记录

想学一下linux,花了3个多小时&#xff0c;才配置好了&#xff0c;记录一下 安装vm12,已配置linux系统 报错&#xff0c;VMware Workstation 与 Device/Credential Guard 不兼容解决方案&#xff0c;网上说有不成功的&#xff0c;电脑蓝屏&#xff0c;选择装vm16试试 vm16 在…

【JVM】JVM规范作用及其核心

目录 认识JVM规范的作用 JVM规范定义的主要内容 认识JVM规范的作用 Java 虚拟机规范为不同的硬件平台提供了一种编译Java技术代码的规范。 Java虚拟机认得不是源文件&#xff0c;认得是编译过后的class文件&#xff0c;它是对这个class文件做要求、起作用的&#xff0c;而并…

算法设计与分析 动态规划/回溯

1.最大子段和 int a[N]; int maxn(int n) {int tempa[0];int ans0;ansmax(temp,ans);for(int i1;i<n;i){if(temp>0){tempa[i];}else tempa[i];ansmax(temp,ans);}return ans; } int main() {int n,ans0;cin>>n;for(int i0;i<n;i) cin>>a[i];ansmaxn(n);co…

LeetCode例题讲解:876.链表的中间结点

给你单链表的头结点 head &#xff0c;请你找出并返回链表的中间结点。 如果有两个中间结点&#xff0c;则返回第二个中间结点。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[3,4,5] 解释&#xff1a;链表只有一个中间结点&#xff0c;值为 3 。…

kubernetes删除命名空间下所有资源

kubernetes强制删除命名空间下所有资源 在 Kubernetes 中&#xff0c;当一个命名空间处于 Terminating 状态但不会完成删除过程时&#xff0c;通常是因为内部资源没有被正确清理。要强制删除这个命名空间及其所有资源&#xff0c;你可以采取以下步骤&#xff1a; 1. 确认命名空…

渲染农场评测:6大热门云渲染平台全面比较

在3D行业中&#xff0c;选择一个合适的云渲染平台可能会令许多专业人士感到难以抉择。为此&#xff0c;我们精心准备了6家流行云渲染平台的详尽评测&#xff0c;旨在为您的决策过程提供实用的参考和支持。 目前&#xff0c;市面上主要的3D网络渲染平台包括六大服务商&#xff0…

张驰咨询六西格玛黑带培训,上海开班,质量精英的摇篮!

一、课程背景与意义 在当今竞争激烈的市场环境中&#xff0c;企业要想立于不败之地&#xff0c;就必须不断提升自身的核心竞争力。而六西格玛作为一种先进的质量管理工具和方法&#xff0c;已经被越来越多的企业所采纳。通过六西格玛黑带培训&#xff0c;学员们可以系统地掌握…

【c++算法篇】双指针(下)

&#x1f525;个人主页&#xff1a;Quitecoder &#x1f525;专栏&#xff1a;算法笔记仓 朋友们大家好啊&#xff0c;本篇文章我们来到算法的双指针的第二部分 目录 1.有效三角形的个数2.查找总价格为目标值的两个商品3.三数之和4.四数之和5.双指针常见场景总结 1.有效三角形…

CAPL如何实现TLS握手认证

CAPL有专门的章节介绍如何实现TLS握手认证的函数: CAPL调用哪些函数实现TLS握手认证,需要了解TLS在整个通信过程的哪个阶段。 首先TCP需要建立连接,这是TLS握手的前提。当TLS握手认证完成后,可以传输数据。 所以TLS握手开始前需要确保TCP建立连接,TCP传输数据前需要确保…

基于SSM的文化遗产的保护与旅游开发系统(有报告)。Javaee项目。ssm项目。

演示视频&#xff1a; 基于SSM的文化遗产的保护与旅游开发系统&#xff08;有报告&#xff09;。Javaee项目。ssm项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;…

适合年轻人的恋爱交友脱单软件有哪些?中国十大社交软件排行榜分享

交友始祖&#xff1a;Tinder 一直很受欢迎&#xff0c;可以向上扫给 super like (每日有一次免费机会)。如果双方互相 like&#xff0c;代表配对成功&#xff0c;就可以开始聊天。另外&#xff0c;每日有 10 个 top picks 供选择&#xff0c;你可以免费选一位 主力编外&#xf…

K8s源码分析(二)-K8s调度队列介绍

本文首发在个人博客上&#xff0c;欢迎来踩&#xff01; 本次分析参考的K8s版本是 文章目录 调度队列简介调度队列源代码分析队列初始化QueuedPodInfo元素介绍ActiveQ源代码介绍UnschedulableQ源代码介绍**BackoffQ**源代码介绍队列弹出待调度的Pod队列增加新的待调度的Podpod调…

数据分析:基于sparcc的co-occurrence网络

介绍 Sparcc是基于16s或metagenomics数据等计算组成数据之间关联关系的算法。通常使用count matrix数据。 安装Sparcc软件 git clone gitgithub.com:JCSzamosi/SparCC3.git export PATH/path/SparCC3:$PATHwhich SparCC.py导入数据 注&#xff1a;使用rarefy抽平的count ma…

牛客小白月赛93

B交换数字 题目&#xff1a; 思路&#xff1a;我们可以知道&#xff0c;a*b% mod (a%mod) * (b%mod) 代码&#xff1a; void solve(){int n;cin >> n;string a, b;cin >> a >> b;for(int i 0;i < n;i )if(a[i] > b[i])swap(a[i], b[i]);int num1…

图片无损压缩工具-VIKY

一、前言 Viky v3.4是一款功能强大的图片压缩工具&#xff0c;它能够提供高效的图片无损压缩服务。通过使用独特的压缩算法&#xff0c;该软件在显著减小图片文件大小的同时&#xff0c;还保持了图像的清晰度和色彩饱和度&#xff0c;确保了图像质量的优异表现。 二、软件特点…

AS-VJ900实时视频拼接系统产品介绍:两画面视频拼接方法和操作

目录 一、实时视频拼接系统介绍 &#xff08;一&#xff09;实时视频拼接的定义 &#xff08;二&#xff09;无缝拼接 &#xff08;三&#xff09;AS-VJ900功能介绍 1、功能 2、拼接界面介绍 二、拼接前的准备 &#xff08;一&#xff09;摄像机选择 &#xff08;二&a…