《Unpaired Unsupervised CT Metal ArtifactReduction》代码讲解

论文讲解见上篇博客
        这篇论文的标题是《Unpaired Unsupervised CT Metal Artifact Reduction》,作者是Bo-Yuan Chen和Chu-Song Chen。这篇论文主要研究了如何使用深度学习技术来减少医学成像中由于金属植入物引起的CT图像伪影。

项目给出了几个不同的unet网络的实验,以pytorch_Net.py举例

train

1、参数如下

batch_size = 8 
num_epoch = 25
lr = 2e-5
channels = 3
img_size = 320
lmda_g = 0.05
lmda_dnn = 0.1
input_shape = (channels, img_size, img_size)

居然是3通道的,大家要用记者修改

2、获得患者信息

    train_patient_info_noise, train_patient_info_clear, train_noise_num, train_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_train, semi=True)test_patient_info_noise, test_patient_info_clear, test_noise_num, test_clear_num = get_patient_info(CT_dir, OMA_dir, patients_id_list_test, semi=True)
def get_patient_info(root, patients_id_list):patient_info_clear = list()patient_info_clear = pd.DataFrame(patient_info_clear, columns = ['name', 'path', 'class']) # clear : 0patient_info_noise = list()patient_info_noise = pd.DataFrame(patient_info_noise, columns = ['name', 'path', 'class']) # noise : 1noise_num = 0clear_num = 0for i, patient_id in enumerate(patients_id_list):patient_id_path = os.path.join(root, patient_id)f = open(os.path.join(patient_id_path, 'MA_slice_num.txt'))noisy_patients_No = list()for line in f.read().splitlines():noisy_patients_No.append(line)for item in os.listdir(patient_id_path):if ('.jpg' in item and item.split('_')[0] in noisy_patients_No):patient_info_noise = patient_info_noise.append({'name':item,'path': patient_id_path, 'class': 1}, ignore_index = True)noise_num += 1elif ('.jpg' in item and item.split('_')[0] not in noisy_patients_No):patient_info_clear = patient_info_clear.append({'name':item,'path': patient_id_path, 'class': 0}, ignore_index = True)clear_num += 1return patient_info_noise, patient_info_clear, noise_num, clear_num

包括CT是否是干净的,CT名,CT路径等

3、根据id划分训练、测试集

test_transform = transforms.Compose([  transforms.Resize((img_size, img_size)),                                 transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])train_set_noise1 = CTImg(transform = train_transform, patient_info = train_patient_info_noise,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)train_set_noise = ConcatDataset([train_set_noise1, train_set_noise1, train_set_noise1, train_set_noise1])train_set_noise = ConcatDataset([train_set_noise,train_set_noise])train_set_clear = CTImg(transform = train_transform, patient_info = train_patient_info_clear,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)test_set_noise =  CTImg(transform = test_transform, patient_info = test_patient_info_noise,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)test_set_clear =  CTImg(transform = test_transform, patient_info = test_patient_info_clear,CT_dir=CT_dir,OMA_dir=OMA_dir,Mask_dir=Mask_dir)train_noise_loader = DataLoader(train_set_noise, batch_size = batch_size, shuffle=True)train_clear_loader = DataLoader(train_set_clear, batch_size = batch_size, shuffle=True)test_noise_loader = DataLoader(test_set_noise, batch_size = batch_size, shuffle=False)test_clear_loader = DataLoader(test_set_clear, batch_size = batch_size, shuffle=False)

有CT也有noise 的数据

4、加载损失函数

g_loss = torch.nn.BCEWithLogitsLoss()g_r_loss = torch.nn.MSELoss()d_loss = torch.nn.BCEWithLogitsLoss()dnn_loss = torch.nn.MSELoss()dnn_r_loss = torch.nn.MSELoss()

5、两个生成器一个鉴别器

    Gen = Generator(input_shape)Dis = Discriminator(input_shape)Dnn = Denoiser_UNet(input_shape)

6、放入cuda,初始化权重、优化函数

if cuda:Gen = Gen.cuda()Dis = Dis.cuda()Dnn = Dnn.cuda()g_loss.cuda()d_loss.cuda()dnn_loss.cuda()# Initialize weightsGen.apply(weights_init_normal)Dis.apply(weights_init_normal)Dnn.apply(weights_init_normal)# Optimizersoptimizer_Gen = torch.optim.Adam(Gen.parameters(), lr=lr, betas=(0.5, 0.999))optimizer_Dis = torch.optim.Adam(Dis.parameters(), lr=lr/2, betas=(0.5, 0.999))optimizer_Dnn = torch.optim.Adam(Dnn.parameters(), lr=lr, betas=(0.5, 0.999))# Input tensor typeTensor = torch.cuda.FloatTensor if cuda else torch.Tensorfix_batch_sample_z = Tensor(get_random_sample(([batch_size] + list(input_shape)), method = 'uniform'))

7、开始训练,训练鉴别器,先生成个噪音g_noise,然后再与干净数据结合,提取特征DIS,计算损失real_loss、fake_loss,返回梯度。

            """ Train D """optimizer_Dis.zero_grad()batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))g_img = g_noise + Variable(clear_img).cuda()noisy_real = diff(Variable(noise_img).cuda())noisy_fake = diff(g_img)#if i ==0:#    print(f"shape of noisy_real: {noisy_real.shape}, shape of noisy_fake: {noisy_fake.shape}")real_logit = Dis(noisy_real.detach())fake_logit = Dis(noisy_fake.detach())real_label = Variable(noise_cls.float().cuda()) #1fake_label = Variable(clear_cls.float().cuda()) #0real_loss = d_loss(real_logit, real_label)fake_loss = d_loss(fake_logit, fake_label)loss_D = (real_loss + fake_loss) / 2loss_D.backward()optimizer_Dis.step()

训练生成器,

            optimizer_Gen.zero_grad()optimizer_Dnn.zero_grad()batch_sample_z = Tensor(get_random_sample(([len(clear_img)] + list(input_shape)), method = 'uniform'))g_noise = Gen(torch.cat((Variable(batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))# semi-partloss_g_r, loss_dnn_r = 0, 0spl = 0for li, (ni,s,nl) in enumerate(zip(noise_img, supervised, noise_label)):b_s_z = Tensor(get_random_sample(([1] + list(input_shape)), method = 'uniform'))if s:spl += 1g_n_GT = Gen(torch.cat((Variable(b_s_z).cuda(),Variable(nl[None]).cuda()), 1))loss_g_r += g_r_loss(g_n_GT, Variable(ni[None]).cuda() -Variable(nl)[None].cuda())dnn_p_GT = Dnn(g_n_GT.detach())loss_dnn_r = dnn_r_loss(dnn_p_GT, Variable(ni[None]).cuda() -Variable(nl[None]).cuda())if spl != 0:loss_g_r /= splloss_dnn_r /= splg_img = g_noise + Variable(clear_img).cuda()noisy_fake = diff(g_img)fake_logit = Dis(noisy_fake)         loss_G = g_loss(fake_logit, torch.ones((len(clear_img))).cuda()) + lmda_g * loss_g_r loss_G.backward()optimizer_Gen.step()dnn_pred = Dnn(g_noise.detach())                out = g_img.detach() - dnn_pred               loss_Dnn = dnn_loss(out,Variable(clear_img).cuda()) + lmda_dnn * loss_dnn_r loss_Dnn.backward()optimizer_Dnn.step()      

8、验证+保存

        with torch.no_grad():psnr = PSNR()mae = MAE()N_GT_psnr, DN_GT_psnr, N_GT_mae, DN_GT_mae, N_GT_ssim, DN_GT_ssim = 0, 0, 0, 0, 0, 0for i, ((noise_img, _,_,noise_label,_), (clear_img,_,_,clear_label,_)) in enumerate(zip(test_noise_loader, test_clear_loader)):'''Gen'''g_noise = Gen(torch.cat((Variable(fix_batch_sample_z).cuda(),Variable(clear_img).cuda()), 1))            g_img = g_noise + Variable(clear_img).cuda()'''Dnn'''dnn_pred = Dnn(Variable(noise_img).cuda())out = Variable(noise_img).cuda() - dnn_predbatch_len = len(out)for (noise,label) in zip(Variable(noise_img).cuda(),Variable(noise_label).cuda()): N_GT_psnr += psnr(noise, label)/batch_len#N_GT_ssim += compare_ssim(noise,label)/batch_lenN_GT_mae += mae(noise,label)/batch_lenfor (denoise,label) in zip(out,Variable(noise_label).cuda()): DN_GT_psnr += psnr(clp(denoise), label)/batch_len#DN_GT_ssim += compare_ssim(denoise,label)/batch_lenDN_GT_mae += mae(clp(denoise), label)/batch_lenif  i == 0:                fig = plt.figure(figsize=[8*6,8*4])axes = [fig.add_subplot(6, 1, r+1 ) for r in range(0, 6)]for ax in axes:ax.axis('off')plt.gca().xaxis.set_major_locator(plt.NullLocator())plt.gca().yaxis.set_major_locator(plt.NullLocator())plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)plt.margins(0,0) axes[0].imshow(torchvision.utils.make_grid(clear_img.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(clear_img.cpu(), './samples/origin_clear_ep{:02d}-{:04d}.png'.format(epoch, i))               axes[1].imshow(torchvision.utils.make_grid(g_noise.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(g_noise.cpu(), './samples/gen_noise_ep{:02d}-{:04d}.png'.format(epoch, i))                axes[2].imshow(torchvision.utils.make_grid(g_img.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(g_img.cpu(), './samples/gen_img_ep{:02d}-{:04d}.png'.format(epoch, i))                                                         axes[3].imshow(torchvision.utils.make_grid(noise_img.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(noise_img.cpu(), './samples/origin_noise_ep{:02d}-{:04d}.png'.format(epoch, i))axes[4].imshow(torchvision.utils.make_grid(dnn_pred.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(dnn_pred.cpu(),  './samples/dnn_noise_ep{:02d}-{:04d}.png'.format(epoch, i))axes[5].imshow(torchvision.utils.make_grid(out.cpu(), nrow=8).permute(1, 2, 0))#torchvision.utils.save_image(out.cpu(), './samples/denoised_img_ep{:02d}-{:04d}.png'.format(epoch, i))fig.savefig("results/SS_DNN2UNet/cv{:02d}ep{:02d}.png".format(idx+1,epoch),bbox_inches = 'tight',pad_inches = 0)plt.close(fig)print("saving...")

model

class Generator(nn.Module):def __init__(self, input_shape, cat=True):super(Generator, self).__init__()channels, _, _ = input_shapeif cat:channels*=2 self.down1 = G_Down(channels, 32, normalize=False) self.down2 = G_Down(32, 32) self.down3 = G_Down(32, 64, pooling=True, dropout=0.5) self.down4 = G_Down(64, 64)         self.down5 = G_Down(64, 128, pooling=True, dropout=0.5) self.down6 = G_Down(128, 128, normalize=False) self.up1 = G_Up(256, 64, uppooling=True, dropout=0.5)self.up2 = G_Up(64, 64)self.up3 = G_Up(128, 32, uppooling=True, dropout=0.5)self.up4 = G_Up(32, 32)self.up5 = G_Up(32, 3)self.final = nn.Sequential(nn.Conv2d(3, 3, kernel_size = 3,stride=1, padding=1),nn.Tanh())def forward(self, x):               #[batchsize,   6, 64, 64]# U-Net generator with skip connections from encoder to decoderd1 = self.down1(x)              #[batchsize,  32, 64, 64]d2 = self.down2(d1)             #[batchsize,  32, 64, 64]d3 = self.down3(d2)             #[batchsize,  64, 32, 32]d4 = self.down4(d3)             #[batchsize,  64, 32, 32]d5 = self.down5(d4)             #[batchsize, 128, 16, 16]d6 = self.down6(d5)             #[batchsize, 128, 16, 16]cat1 = torch.cat((d6, d5), 1)   #[batchsize, 256, 16, 16]u1 = self.up1(cat1)             #[batchsize,  64, 32, 32]u2 = self.up2(u1)               #[batchsize,  64, 32, 32]cat2 = torch.cat((u2, d4), 1)   #[batchsize, 128, 32, 32]u3 = self.up3(cat2)             #[batchsize,  32, 64, 64]    u4 = self.up4(u3)               #[batchsize,  32, 64, 64]u5 = self.up5(u4)               #[batchsize,   3, 64, 64]return self.final(u5)           #[batchsize,   3, 64, 64]

 

class Discriminator(nn.Module):def __init__(self, input_shape):super(Discriminator, self).__init__()channels, height, width = input_shapeself.input_shape = (channels*2, height, width)                        #[batchsize,   3, 64, 64]# Calculate output of image discriminator (PatchGAN)self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)def discriminator_block(in_filters, out_filters, normalization=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, 2, 1)]if normalization:layers.append(nn.BatchNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(channels*2, 16, normalization=False),      #[batchsize,   64, 32, 32]*discriminator_block(16, 32),                                  #[batchsize,  128, 16, 16]*discriminator_block(32, 128),                                 #[batchsize,  256,  8,  8]*discriminator_block(128, 128),                                 #[batchsize,  512,  4,  4])self.final = nn.Sequential(nn.Linear(128 * 20 * 20, 1),nn.Sigmoid(),)def forward(self, img):# Concatenate image and condition image by channels to produce inputconv = self.model(img)conv = conv.view(conv.shape[0], -1)return self.final(conv).view(-1)

 

综上,与论文框架描述一致,没有弯弯绕绕

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/3226614.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

美国商超入驻Homedepot,会成为传统家织厂家跨境赛道吗?

近年来,随着全球化步伐的加快和电子商务的蓬勃发展,越来越多的企业开始寻求跨境拓展的机会。在这样的背景下,美国知名的家居用品零售商超——Homedepot成为了许多国内外家织厂家关注的焦点。那么,美国商超入驻Homedepot究竟如何呢…

ArcGis将同一图层的多个面要素合并为一个面要素

这里写自定义目录标题 1.加载面要素的shp数据 2.点击菜单栏的地理处理–融合,如下所示: 3.将shp面要素输入,并设置输出,点击确定即可合并。合并后的属性表就只有一个数据了。

神经网络构成、优化、常用函数+激活函数

Iris分类 数据集介绍,共有数据150组,每组包括长宽等4个输入特征,同时给出输入特征对应的Iris类别,分别用0,1,2表示。 从sklearn包datasets读入数据集。 from sklearn import darasets from pandas impor…

Python 视频的色彩转换

这篇教学会介绍使用OpenCV 的cvtcolor() 方法,将视频的色彩模型从RGB 转换为灰阶、HLS、HSV...等。 因为程式中的OpenCV 会需要使用镜头或GPU,所以请使用本机环境( 参考:使用Python 虚拟环境) 或使用Anaconda Jupyter 进行实作( 参考&#x…

【数据结构】--- 堆

​ 个人主页:星纭-CSDN博客 系列文章专栏 :数据结构 踏上取经路,比抵达灵山更重要!一起努力一起进步! 目录 一.堆的介绍 二.堆的实现 1.向下调整算法 2.堆的创建 3.堆的实现 4.堆的初始化和销毁 5.堆的插入 5.1扩容…

Bad substitution 奇怪的问题

记得之前写过一篇文章是关于shell 脚本的,这里,当时的系统是 CentOS 的,最近公司把所有的服务器系统都更换为 Ubuntu 了, 结果以前写的那个脚本无法执行了,错误就是 Bad substitution,网上搜索基本都是 {}…

[C++初阶]list类的初步理解

一、标准库的list类 list的底层是一个带哨兵位的双向循环链表结构 对比forward_list的单链表结构,list的迭代器是一个双向迭代器 与vector等顺序结构的容器相比,list在任意位置进行插入删除的效率更好,但是不支持任意位置的随机访问 list是一…

【EIScopus稳检索-高录用】第五届大数据与社会科学国际学术会议(ICBDSS 2024)

大会官网:www.icbdss.org 大会时间:2024年8月16-18日 大会地点:中国-上海 接受/拒稿通知:投稿后1-2周内 收录检索:EI,Scopus *所有参会者现场均可获取参会证明,会议通知(邀请函)&…

二维码生成需知:名片二维码尺寸多少合适?电子名片二维码制作方法?

随着数字化时代的到来,二维码在各个领域的应用越来越广泛,名片作为商业交流的重要工具之一,也开始逐渐融入二维码的元素。通过在名片上添加二维码,我们可以轻松实现信息的快速传递和分享。然而,名片二维码的尺寸选择成…

【割点 C++BFS】2556. 二进制矩阵中翻转最多一次使路径不连通

本文涉及知识点 割点 图论知识汇总 CBFS算法 LeetCode2556. 二进制矩阵中翻转最多一次使路径不连通 给你一个下标从 0 开始的 m x n 二进制 矩阵 grid 。你可以从一个格子 (row, col) 移动到格子 (row 1, col) 或者 (row, col 1) ,前提是前往的格子值为 1 。如…

国产口碑最好的骨传导耳机有哪些?优选五大高口碑机型推荐!

作为一名有着多年工作经验的数码测评师,可以说对骨传导耳机或者蓝牙耳机等数码产品有着深入的了解,近期,有很多粉丝,或者身边的朋友经常向我咨询关于骨传导耳机的问题。确实如此,优质的骨传导耳机能在保护听力、保持环…

HKT DICT解决方案,为您量身打造全方位的一站式信息管理服务

随着大数据时代的到来,企业对现代化管理、数据整合与呈现的解决方案需求不断增长。为满足更多企业客户的多元化信息管理发展需求,香港电讯(HKT)强势推出全面、高效、安全、可靠的一站式DICT(Digital Information and C…

【Python系列】深入解析 Python 中的 JSON 处理工具

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

IDEA常用技巧荟萃:精通开发利器的艺术

1 概述 在现代软件开发的快节奏环境中,掌握一款高效且功能全面的集成开发环境(IDE)是提升个人和团队生产力的关键。IntelliJ IDEA,作为Java开发者的首选工具之一,不仅提供了丰富的编码辅助功能,还拥有高度…

【NLP学习笔记】transformers中的tokenizer切词时是否返回token_type_ids

结论 先说结论: 是否返回token_type_ids,可以在切词时通过 return_token_type_idsTrue/False指定,指定了True就肯定会返回,指定False,不一定就不返回。 分析 Doc地址 https://huggingface.co/docs/transformers/main…

【电脑应用技巧】如何寻找电脑应用的安装包华为电脑、平板和手机资源交换共享

电脑的初学者可能会直接用【百度】搜索电脑应用程序的安装包,但是这样找到的电脑应用程序安装包经常会被加入木马或者强制捆绑一些不需要的应用装入电脑。 今天告诉大家一个得到干净电脑应用程序安装包的方法,就是用【联想的应用商店】。联想电脑我是一点…

看到指针就头疼?这篇文章让你对指针有更全面的了解!

文章目录 1.什么是指针2.指针和指针类型2.1 指针-整数2.2 指针的解引用 3.野指针3.1为什么会有野指针3.2 如何规避野指针 4.指针运算4.1 指针-整数4.2 指针减指针4.3 指针的关系运算 5.指针与数组6.二级指针7.指针数组 1.什么是指针 指针的两个要点 1.指针是内存中的一个最小单…

智能雷达AI小程序源码系统 销售名片+企业商城+公司动态 带完整的安装代码包以及搭建教程

系统概述 智能雷达AI小程序源码系统是基于先进的AI技术和小程序框架开发的全能型企业级应用。它不仅整合了个人销售名片的便捷分享,还融入了功能丰富的企业商城和实时更新的公司动态展示,实现了从品牌形象塑造到产品销售,再到客户关系维护的…

TransIT-VirusGEN® Transfection Reagent

Mirus转染试剂TransIT-VirusGEN Transfection Reagent,该产品旨在增强载体转染到 贴壁或悬浮的HEK 293细胞的转染效率,并增加重组腺相关病毒或慢病毒的产量。 使用TransIT-VirusGEN转染试剂转染悬浮或贴壁HEK293细胞可获得最高的转染效率。使用不同的转…

【Flask从入门到精通:第一课:flask的基本介绍、flask快速搭建项目并运行】

从0开始入手到上手一个新的框架,应该怎么展开?flask这种轻量级的框架与django这种的重量级框架的区别?针对web开发过程中,常见的数据库ORM的操作。跟着学习flask的过程中,自己去学习和了解一个新的框架(San…