论文讲解见上篇博客
这篇论文的标题是《Unpaired Unsupervised CT Metal Artifact Reduction》,作者是Bo-Yuan Chen和Chu-Song Chen。这篇论文主要研究了如何使用深度学习技术来减少医学成像中由于金属植入物引起的CT图像伪影。
项目提供了基于几种不同 U-Net 网络的实验脚本,下文以 pytorch_Net.py 为例进行讲解。
train
1、参数如下
# Training schedule.
batch_size, num_epoch = 8, 25
# Base Adam learning rate.
lr = 2e-5
# Input geometry: channels x img_size x img_size (adjust channels for grayscale data).
channels, img_size = 3, 320
# Weights of the semi-supervised reconstruction terms for Gen and Dnn.
lmda_g, lmda_dnn = 0.05, 0.1
input_shape = (channels,) + (img_size,) * 2
注意:这里默认按 3 通道输入处理(channels = 3)。如果你的 CT 数据是单通道灰度图,记得根据自己的数据修改 channels。
2、获得患者信息
# Per-split slice tables (noise / clear) and the number of slices in each.
# NOTE(review): this 4-argument form (CT dir, OMA dir, ids, semi flag) does not
# match the 2-argument get_patient_info(root, patients_id_list) quoted below;
# pytorch_Net.py presumably uses a different version of it -- verify.
(train_patient_info_noise, train_patient_info_clear,
 train_noise_num, train_clear_num) = get_patient_info(
    CT_dir, OMA_dir, patients_id_list_train, semi=True)
(test_patient_info_noise, test_patient_info_clear,
 test_noise_num, test_clear_num) = get_patient_info(
    CT_dir, OMA_dir, patients_id_list_test, semi=True)
def get_patient_info(root, patients_id_list):
    """Split every patient's ``.jpg`` slices into a noise table and a clear table.

    For each id in *patients_id_list* the directory ``root/<id>`` is scanned.
    Its ``MA_slice_num.txt`` lists, one per line, the slice numbers to treat
    as noise. A file whose name contains ``.jpg`` is classified as noise
    (class 1) when the part of its name before the first ``_`` is in that
    list, and as clear (class 0) otherwise. Other files are ignored.

    Args:
        root: Directory containing one sub-directory per patient.
        patients_id_list: Patient ids (sub-directory names) to scan.

    Returns:
        ``(patient_info_noise, patient_info_clear, noise_num, clear_num)``:
        two DataFrames with columns ``name``, ``path``, ``class`` (rows in
        directory-listing order) and the number of rows in each.

    Raises:
        FileNotFoundError: if a patient directory or its
            ``MA_slice_num.txt`` does not exist.
    """
    columns = ['name', 'path', 'class']
    noise_rows = []  # class 1
    clear_rows = []  # class 0
    for patient_id in patients_id_list:
        patient_id_path = os.path.join(root, patient_id)
        # `with` guarantees the file is closed; a set gives O(1) lookups per slice.
        with open(os.path.join(patient_id_path, 'MA_slice_num.txt'),
                  encoding='utf-8') as f:
            noisy_patients_No = set(f.read().splitlines())
        for item in os.listdir(patient_id_path):
            if '.jpg' not in item:
                continue
            if item.split('_')[0] in noisy_patients_No:
                noise_rows.append({'name': item, 'path': patient_id_path, 'class': 1})
            else:
                clear_rows.append({'name': item, 'path': patient_id_path, 'class': 0})
    # Build each frame once: DataFrame.append no longer exists in pandas >= 2.0,
    # and growing a frame row by row is quadratic.
    patient_info_noise = pd.DataFrame(noise_rows, columns=columns)
    patient_info_clear = pd.DataFrame(clear_rows, columns=columns)
    return patient_info_noise, patient_info_clear, len(noise_rows), len(clear_rows)
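get_patient_info 会读取每个病人目录下的 MA_slice_num.txt(其中列出属于 noise 的切片编号),据此把该病人的 .jpg 切片分成两张表:含伪影的 noise 表(class = 1)和干净的 clear 表(class = 0)。表中记录每张 CT 切片的文件名(name)和所在路径(path),函数同时返回两类切片的数量。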
3、构建训练集、测试集和对应的 DataLoader(训练/测试的划分已在上一步按病人 id 完成)
# Evaluation preprocessing: resize, to tensor, normalise with the standard
# ImageNet mean/std.
test_transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def _make_ct_set(transform, patient_info):
    """Wrap one patient-info table in a CTImg dataset over the shared directories."""
    return CTImg(transform=transform, patient_info=patient_info,
                 CT_dir=CT_dir, OMA_dir=OMA_dir, Mask_dir=Mask_dir)


# The noisy training set is concatenated 4x, and that result 2x again,
# so every noisy slice appears 8 times per pass over train_set_noise.
train_set_noise1 = _make_ct_set(train_transform, train_patient_info_noise)
train_set_noise = ConcatDataset([train_set_noise1] * 4)
train_set_noise = ConcatDataset([train_set_noise] * 2)
train_set_clear = _make_ct_set(train_transform, train_patient_info_clear)
test_set_noise = _make_ct_set(test_transform, test_patient_info_noise)
test_set_clear = _make_ct_set(test_transform, test_patient_info_clear)

# Training loaders shuffle; test loaders keep a fixed order.
train_noise_loader, train_clear_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=True)
    for ds in (train_set_noise, train_set_clear)
)
test_noise_loader, test_clear_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=False)
    for ds in (test_set_noise, test_set_clear)
)
训练集和测试集都分别包含含伪影(noise)和干净(clear)两类数据;其中训练用的 noise 数据集被重复拼接了 8 次(ConcatDataset 先拼 4 份、再拼 2 份),推测是为了平衡两类样本的数量。
4、加载损失函数
# Adversarial criteria: BCEWithLogitsLoss applies the sigmoid itself,
# so it expects raw (pre-sigmoid) scores.
g_loss, d_loss = (torch.nn.BCEWithLogitsLoss() for _ in range(2))
# Reconstruction criteria: plain mean-squared error.
g_r_loss, dnn_loss, dnn_r_loss = (torch.nn.MSELoss() for _ in range(3))
5、构建三个网络:生成器 Gen(生成叠加到干净图像上的伪影噪声)、判别器 Dis,以及去噪网络 Dnn(Denoiser_UNet)
# Build the three networks, in the order Gen, Dis, Dnn.
Gen, Dis, Dnn = Generator(input_shape), Discriminator(input_shape), Denoiser_UNet(input_shape)
6、将网络和损失函数放到 GPU 上,初始化权重,并创建 Adam 优化器(判别器的学习率为 lr/2)
if cuda:
    # Move the networks to the GPU, re-binding the names to the returned modules.
    Gen, Dis, Dnn = Gen.cuda(), Dis.cuda(), Dnn.cuda()
    # Move the adversarial and denoising criteria as well.
    for _criterion in (g_loss, d_loss, dnn_loss):
        _criterion.cuda()

# Initialise every network with weights_init_normal.
for _net in (Gen, Dis, Dnn):
    _net.apply(weights_init_normal)

# Adam optimisers; the discriminator runs at half the base learning rate.
_betas = (0.5, 0.999)
optimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=lr, betas=_betas)
optimizer_Dis = torch.optim.Adam(Dis.parameters(), lr=lr / 2, betas=_betas)
optimizer_Dnn = torch.optim.Adam(Dnn.parameters(), lr=lr, betas=_betas)

# Tensor constructor for the active device.
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor
# Fixed uniform latent batch, reused at validation time so samples are comparable.
fix_batch_sample_z = Tensor(get_random_sample([batch_size, *input_shape], method='uniform'))
7、开始训练。首先训练判别器:生成器根据随机噪声和干净图像生成噪声 g_noise,叠加到干净图像上得到合成的含伪影图像 g_img;真实含伪影图像和合成图像分别经过 diff 处理后送入判别器 Dis,计算 real_loss 和 fake_loss,取平均后反向传播并更新判别器参数。
# ---- Train D (discriminator) ----
optimizer_Dis.zero_grad()
clear_gpu = Variable(clear_img).cuda()
# One uniform latent map per clear image, stacked with it on the channel axis.
batch_sample_z = Tensor(get_random_sample([len(clear_img), *input_shape], method='uniform'))
g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(), clear_gpu), 1))
g_img = g_noise + clear_gpu

noisy_real = diff(Variable(noise_img).cuda())
noisy_fake = diff(g_img)
# Detach both inputs so this step only produces gradients for Dis.
real_logit = Dis(noisy_real.detach())
fake_logit = Dis(noisy_fake.detach())

# Targets come from the loaders' class labels: noise_cls -> 1, clear_cls -> 0.
real_label = Variable(noise_cls.float().cuda())
fake_label = Variable(clear_cls.float().cuda())
real_loss = d_loss(real_logit, real_label)
fake_loss = d_loss(fake_logit, fake_label)
loss_D = (real_loss + fake_loss) / 2

loss_D.backward()
optimizer_Dis.step()
然后训练生成器和去噪网络:生成器的损失为对抗损失加上半监督部分的重建损失(权重 lmda_g);去噪网络的损失为去噪结果与干净图像的 MSE 加上半监督重建损失(权重 lmda_dnn)。
# ---- Train the generator (Gen) and the denoiser (Dnn) ----
optimizer_Gen.zero_grad()
optimizer_Dnn.zero_grad()
batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method='uniform'))
g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(), Variable(clear_img).cuda()), 1))

# Semi-supervised part: for samples whose `supervised` flag is set, the paired
# label nl gives the target residual ni - nl for both Gen and Dnn.
loss_g_r, loss_dnn_r = 0, 0
spl = 0  # number of supervised samples in this batch
for ni, s, nl in zip(noise_img, supervised, noise_label):
    # Drawn for every sample; only used when the sample is supervised.
    b_s_z = Tensor(get_random_sample(([1] + list(input_shape)), method='uniform'))
    if s:
        spl += 1
        nl_gpu = Variable(nl[None]).cuda()
        target_noise = Variable(ni[None]).cuda() - nl_gpu
        g_n_GT = Gen(torch.cat((Variable(b_s_z).cuda(), nl_gpu), 1))
        loss_g_r += g_r_loss(g_n_GT, target_noise)
        # NOTE(review): Dnn is fed the generated residual g_n_GT here rather
        # than a noisy image (unlike the unsupervised term below) -- verify.
        dnn_p_GT = Dnn(g_n_GT.detach())
        # Accumulate over all supervised samples so dividing by spl gives a mean.
        loss_dnn_r += dnn_r_loss(dnn_p_GT, target_noise)
if spl != 0:
    loss_g_r /= spl
    loss_dnn_r /= spl

# Adversarial term: Gen wants Dis to score its synthetic images as noisy (1).
g_img = g_noise + Variable(clear_img).cuda()
noisy_fake = diff(g_img)
fake_logit = Dis(noisy_fake)
loss_G = g_loss(fake_logit, torch.ones((len(clear_img))).cuda()) + lmda_g * loss_g_r
loss_G.backward()
optimizer_Gen.step()

# Denoiser: predict the residual from the synthetic noisy image and subtract
# it, the same way validation does (out = noise_img - Dnn(noise_img)).
# Feeding g_noise itself would only teach Dnn the identity mapping.
dnn_pred = Dnn(g_img.detach())
out = g_img.detach() - dnn_pred
loss_Dnn = dnn_loss(out, Variable(clear_img).cuda()) + lmda_dnn * loss_dnn_r
loss_Dnn.backward()
optimizer_Dnn.step()
8、验证并保存可视化结果(在测试集上计算 PSNR、MAE,并把第一个 batch 的结果拼图保存)
with torch.no_grad():
    # NOTE(review): unless the caller switches Gen/Dnn to eval() before this
    # block, layers such as Gen's dropout=0.5 blocks stay in training mode here.
    psnr = PSNR()
    mae = MAE()
    # Running sums of per-batch mean metrics: N_* compare the noisy input with
    # its label, DN_* compare the (clipped) denoised output with the label.
    # The SSIM sums are currently never updated.
    N_GT_psnr, DN_GT_psnr, N_GT_mae, DN_GT_mae, N_GT_ssim, DN_GT_ssim = 0, 0, 0, 0, 0, 0
    # zip stops at the shorter of the two test loaders.
    for i, ((noise_img, _, _, noise_label, _), (clear_img, _, _, clear_label, _)) in enumerate(
            zip(test_noise_loader, test_clear_loader)):
        clear_gpu = Variable(clear_img).cuda()
        noise_gpu = Variable(noise_img).cuda()
        label_gpu = Variable(noise_label).cuda()

        # Gen: synthesise residuals on the clear images with the fixed latent.
        # The last batch may be smaller than batch_size, so slice the latent to
        # match; otherwise torch.cat fails on mismatched batch dimensions.
        z = Variable(fix_batch_sample_z[:len(clear_img)]).cuda()
        g_noise = Gen(torch.cat((z, clear_gpu), 1))
        g_img = g_noise + clear_gpu

        # Dnn: predict the residual and subtract it from the noisy input.
        dnn_pred = Dnn(noise_gpu)
        out = noise_gpu - dnn_pred

        batch_len = len(out)
        for noise, label in zip(noise_gpu, label_gpu):
            N_GT_psnr += psnr(noise, label) / batch_len
            N_GT_mae += mae(noise, label) / batch_len
        for denoise, label in zip(out, label_gpu):
            DN_GT_psnr += psnr(clp(denoise), label) / batch_len
            DN_GT_mae += mae(clp(denoise), label) / batch_len

        if i == 0:
            # One row per stage: clear input, generated residual, synthetic
            # noisy image, real noisy image, predicted residual, denoised output.
            fig = plt.figure(figsize=[8 * 6, 8 * 4])
            axes = [fig.add_subplot(6, 1, r + 1) for r in range(0, 6)]
            for ax in axes:
                ax.axis('off')
            plt.gca().xaxis.set_major_locator(plt.NullLocator())
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
            plt.margins(0, 0)
            for ax, batch in zip(axes, (clear_img, g_noise, g_img, noise_img, dnn_pred, out)):
                ax.imshow(torchvision.utils.make_grid(batch.cpu(), nrow=8).permute(1, 2, 0))
            fig.savefig("results/SS_DNN2UNet/cv{:02d}ep{:02d}.png".format(idx + 1, epoch),
                        bbox_inches='tight', pad_inches=0)
            plt.close(fig)
            print("saving...")
model
class Generator(nn.Module):
    """U-Net style generator producing a residual that is added to a clean image.

    Args:
        input_shape: ``(channels, height, width)`` of one image.
        cat: If True (default), the input is the channel-wise concatenation of
            a latent map and an image, so the first layer sees ``2 * channels``
            input channels; otherwise it sees ``channels``.

    The output has ``channels`` channels (so it can be added to the input
    image) and values in ``[-1, 1]`` from the final Tanh.
    """

    def __init__(self, input_shape, cat=True):
        super(Generator, self).__init__()
        channels, _, _ = input_shape
        # Output channels follow the image, not a hard-coded 3; otherwise
        # `g_noise + clear_img` breaks for non-3-channel (e.g. grayscale) data.
        out_channels = channels
        in_channels = channels * 2 if cat else channels
        self.down1 = G_Down(in_channels, 32, normalize=False)
        self.down2 = G_Down(32, 32)
        self.down3 = G_Down(32, 64, pooling=True, dropout=0.5)
        self.down4 = G_Down(64, 64)
        self.down5 = G_Down(64, 128, pooling=True, dropout=0.5)
        self.down6 = G_Down(128, 128, normalize=False)
        self.up1 = G_Up(256, 64, uppooling=True, dropout=0.5)  # 256 = cat(d6, d5)
        self.up2 = G_Up(64, 64)
        self.up3 = G_Up(128, 32, uppooling=True, dropout=0.5)  # 128 = cat(u2, d4)
        self.up4 = G_Up(32, 32)
        self.up5 = G_Up(32, out_channels)
        self.final = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        # Shapes follow the original comments (64x64 example), with C = channels:
        # x is [B, 2C, H, W] when cat=True.
        d1 = self.down1(x)    # [B, 32, H, W]
        d2 = self.down2(d1)   # [B, 32, H, W]
        d3 = self.down3(d2)   # [B, 64, H/2, W/2]
        d4 = self.down4(d3)   # [B, 64, H/2, W/2]
        d5 = self.down5(d4)   # [B, 128, H/4, W/4]
        d6 = self.down6(d5)   # [B, 128, H/4, W/4]
        # Decoder with skip connections from the encoder.
        cat1 = torch.cat((d6, d5), 1)  # [B, 256, H/4, W/4]
        u1 = self.up1(cat1)   # [B, 64, H/2, W/2]
        u2 = self.up2(u1)     # [B, 64, H/2, W/2]
        cat2 = torch.cat((u2, d4), 1)  # [B, 128, H/2, W/2]
        u3 = self.up3(cat2)   # [B, 32, H, W]
        u4 = self.up4(u3)     # [B, 32, H, W]
        u5 = self.up5(u4)     # [B, C, H, W]
        return self.final(u5)  # [B, C, H, W]
class Discriminator(nn.Module):
    """Image-level discriminator returning one raw score (logit) per sample.

    Args:
        input_shape: ``(channels, height, width)`` of one image. The first
            convolution takes ``2 * channels`` input channels.

    Four Conv(k=4, s=2, p=1) blocks reduce the spatial size to
    ``(height // 16, width // 16)``. A linear layer then maps the flattened
    features to a single score. No sigmoid is applied because the training
    code scores these outputs with ``BCEWithLogitsLoss``, which applies the
    sigmoid itself.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        channels, height, width = input_shape
        self.input_shape = (channels * 2, height, width)
        # PatchGAN-style size; kept as an attribute, not used by forward().
        self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)

        def discriminator_block(in_filters, out_filters, normalization=True):
            """Return Conv(4, 2, 1) [+ BatchNorm] + LeakyReLU layers; halves H and W."""
            layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]
            if normalization:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels * 2, 16, normalization=False),  # [B, 16, H/2, W/2]
            *discriminator_block(16, 32),                                 # [B, 32, H/4, W/4]
            *discriminator_block(32, 128),                                # [B, 128, H/8, W/8]
            *discriminator_block(128, 128),                               # [B, 128, H/16, W/16]
        )
        # Derive the flattened size from the input instead of hard-coding
        # 128 * 20 * 20, which only matched 320x320 inputs.
        feat_h, feat_w = height // 16, width // 16
        # Linear stays at index 0, so existing state_dicts (final.0.*) still load.
        # No Sigmoid: outputs are logits for BCEWithLogitsLoss.
        self.final = nn.Sequential(
            nn.Linear(128 * feat_h * feat_w, 1),
        )

    def forward(self, img):
        conv = self.model(img)
        conv = conv.view(conv.shape[0], -1)
        return self.final(conv).view(-1)  # [B] logits
综上,代码实现与论文描述的框架基本一致,整体结构比较直接。