文章目录
- 网络介绍
- 训练
- 推理结果
网络介绍
Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。
下面介绍本网络用到的生成器、判别器:
生成器G用到的是U-Net结构,输入的轮廓图 𝑥编码再解码成真是图片,
判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图𝑥的条件下,对于生成的图片 𝐺(𝑥)判断为假,对于真实判断为真。
训练
训练分为两个主要部分:训练判别器和训练生成器。
训练判别器的目的是最大程度地提高判别图像真伪的概率。
训练生成器是希望能产生更好的虚假图像。
在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。部分实现如下:
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensorepoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100def get_lr():lrs = [lr] * dataset_size * n_epochslr_epoch = 0for epoch in range(n_epochs_decay):lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decaylrs += [lr_epoch] * dataset_sizelrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)return Tensor(np.array(lrs).astype(np.float32))dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()def forword_dis(reala, realb):lambda_dis = 0.5fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)pred1 = net_discriminator(reala, realb)loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))loss_dis = loss_d * lambda_disreturn loss_disdef forword_gan(reala, realb):lambda_gan = 0.5lambda_l1 = 100fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)loss_1 = loss_f(pred0, ops.ones_like(pred0))loss_2 = l1_loss(fakeb, realb)loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1return loss_gand_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())def train_step(reala, realb):loss_dis, d_grads = grad_d(reala, realb)loss_gan, g_grads = grad_g(reala, realb)d_opt(d_grads)g_opt(g_grads)return loss_dis, loss_ganif not os.path.isdir(ckpt_dir):os.makedirs(ckpt_dir)g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)for epoch in range(epoch_num):for i, data in enumerate(data_loader):start_time = datetime.datetime.now()input_image = Tensor(data["input_images"])target_image = Tensor(data["target_images"])dis_loss, gen_loss = train_step(input_image, target_image)end_time = datetime.datetime.now()delta = (end_time - start_time).microsecondsif i % 2 == 0:print("ms per step:{:.2f} epoch:{}/{} step:{}/{} Dloss:{:.4f} Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))d_losses.append(dis_loss.asnumpy())g_losses.append(gen_loss.asnumpy())if (epoch + 1) == epoch_num:mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")
推理结果
此章节学习到此结束,感谢昇思平台。