烹饪第一个U-Net进行图像分割

今天我们将学习如何准备计算机视觉中最重要的网络之一:U-Net。如果你没有代码和数据集也没关系,可以分别通过下面两个链接进行访问:

代码:

https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation?source=post_page-----e812e37e9cd0--------------------------------

Kaggle的MRI分割数据集:

https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation?source=post_page-----e812e37e9cd0--------------------------------

主要步骤:

1. 数据集的探索

2. 数据集和Dataloader类的创建

3. 架构的创建

4. 检查损失(DICE和二元交叉熵)

5. 结果

数据集的探索

我们得到了一组(255 x 255)的MRI扫描的2D图像,以及它们相应的必须将每个像素分类为0(健康)或1(肿瘤)。

这里有一些例子:

8f4f6fa5dd3eee62cf8537f20a64c9ca.jpeg

第一行:肿瘤,第二行:健康主题

数据集和Dataloader类

这是涉及神经网络的每个项目中都会找到的一步。

数据集类

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoaderclass BrainMriDataset(Dataset):def __init__(self, df, transforms):# df contains the paths to all filesself.df = df# transforms is the set of data augmentation operations we useself.transforms = transformsdef __len__(self):return len(self.df)def __getitem__(self, idx):image = cv2.imread(self.df.iloc[idx, 1])mask = cv2.imread(self.df.iloc[idx, 2], 0)augmented = self.transforms(image=image, mask=mask)image = augmented['image'] # Dimension (3, 255, 255)mask = augmented['mask']   # Dimension (255, 255)# We notice that the image has one more dimension (3 color channels), so we have to one one "artificial" dimension to the mask to match itmask = np.expand_dims(mask, axis=0) # Dimension (1, 255, 255)return image, mask

数据加载器

既然我们已经创建了Dataset类来重新整形张量,我们首先需要定义训练集(用于训练模型),验证集(用于监控训练并避免过拟合),以及测试集,最终评估模型在未见数据上的性能。

# Split df into train_df and val_df
train_df, val_df = train_test_split(df, stratify=df.diagnosis, test_size=0.1)
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)# Split train_df into train_df and test_df
train_df, test_df = train_test_split(train_df, stratify=train_df.diagnosis, test_size=0.15)
train_df = train_df.reset_index(drop=True)train_dataset = BrainMriDataset(train_df, transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = BrainMriDataset(val_df, transforms=transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)test_dataset = BrainMriDataset(test_df, transforms=transforms)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

U-Net架构

e528761129f5514c2493d1871166da60.jpeg

U-Net架构是用于图像分割任务的强大模型,是卷积神经网络(CNN)的一种类型,其名称来自其U形状的结构。U-Net最初由Olaf Ronneberger等人在2015年的论文中首次开发,标题为“U-Net:用于生物医学图像分割的卷积网络”。

其结构涉及编码(降采样)路径和解码(上采样)路径。U-Net至今仍然是一个非常成功的模型,其成功来自两个主要因素:

1. 对称结构(U形状)

2. 前向连接(图片上的灰色箭头)

前向连接的主要思想是,随着我们在层中越来越深入,我们会失去有关原始图像的一些信息。然而,我们的任务是对图像进行分割,我们需要精确的图像来对每个像素进行分类。这就是为什么我们在对称解码器层的每一层中重新注入图像的原因。以下是通过Pytorch实现的代码:

train_dataset = BrainMriDataset(train_df, transforms=transforms)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)val_dataset = BrainMriDataset(val_df, transforms=transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)test_dataset = BrainMriDataset(test_df, transforms=transforms)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)class UNet(nn.Module):def __init__(self):super().__init__()# Define convolutional layers# These are used in the "down" path of the U-Net,# where the image is successively downsampledself.conv_down1 = double_conv(3, 64)self.conv_down2 = double_conv(64, 128)self.conv_down3 = double_conv(128, 256)self.conv_down4 = double_conv(256, 512)# Define max pooling layer for downsamplingself.maxpool = nn.MaxPool2d(2)# Define upsampling layerself.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)# Define convolutional layers# These are used in the "up" path of the U-Net,# where the image is successively upsampledself.conv_up3 = double_conv(256 + 512, 256)self.conv_up2 = double_conv(128 + 256, 128)self.conv_up1 = double_conv(128 + 64, 64)# Define final convolution to output correct number of classes# 1 because there are only two classes (tumor or not tumor)self.last_conv = nn.Conv2d(64, 1, kernel_size=1)def forward(self, x):# Forward pass through the network# Down pathconv1 = self.conv_down1(x)x = self.maxpool(conv1)conv2 = self.conv_down2(x)x = self.maxpool(conv2)conv3 = self.conv_down3(x)x = self.maxpool(conv3)x = self.conv_down4(x)# Up pathx = self.upsample(x)x = torch.cat([x, conv3], dim=1)x = self.conv_up3(x)x = self.upsample(x)x = torch.cat([x, conv2], dim=1)x = self.conv_up2(x)x = self.upsample(x)x = torch.cat([x, conv1], dim=1)x = self.conv_up1(x)# Final outputout = self.last_conv(x)out = torch.sigmoid(out)return out

损失和评估标准

与每个神经网络一样,都有一个目标函数,一种损失,我们通过梯度下降最小化它。我们还引入了评估标准,它帮助我们训练模型(如果它在连续的3个时期中没有改善,那么我们停止训练,因为模型正在过拟合)。从这一段中有两个主要要点:

1. 损失函数是两个损失函数的组合(DICE损失,二元交叉熵)

2. 评估函数是DICE分数,不要与DICE损失混淆

DICE损失:

c5b7a1b8fafaa01bdbfe5f549725a7ce.jpeg

DICE损失

备注:我们添加了一个平滑参数(epsilon)以避免除以零。

二元交叉熵损失:

9a4538b4bda0db62bfa673b13131def8.jpeg

BCE

于是,我们的总损失是:

9da84b914878f2bd0934f20ef2aae91c.jpeg

让我们一起实现它:

def dice_coef_loss(inputs, target):smooth = 1.0intersection = 2.0 * ((target * inputs).sum()) + smoothunion = target.sum() + inputs.sum() + smoothreturn 1 - (intersection / union)def bce_dice_loss(inputs, target):inputs = inputs.float()target = target.float()dicescore = dice_coef_loss(inputs, target)bcescore = nn.BCELoss()bceloss = bcescore(inputs, target)return bceloss + dicescore

评估标准(Dice系数):

我们使用的评估函数是DICE分数。它在0到1之间,1是最好的。

0be7e283b0e50c64a2fd1a935d03b85a.jpeg

Dice分数的图示

其数学实现如下:

3f035ea9fdf85f0e09abd6293c5bd257.jpeg

def dice_coef_metric(inputs, target):intersection = 2.0 * (target * inputs).sum()union = target.sum() + inputs.sum()if target.sum() == 0 and inputs.sum() == 0:return 1.0return intersection / union

训练循环

def train_model(model_name, model, train_loader, val_loader, train_loss, optimizer, lr_scheduler, num_epochs):  print(model_name)loss_history = []train_history = []val_history = []for epoch in range(num_epochs):model.train()  # Enter train mode# We store the training loss and dice scoreslosses = []train_iou = []if lr_scheduler:warmup_factor = 1.0 / 100warmup_iters = min(100, len(train_loader) - 1)lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)# Add tqdm to the loop (to visualize progress)for i_step, (data, target) in enumerate(tqdm(train_loader, desc=f"Training epoch {epoch+1}/{num_epochs}")):data = data.to(device)target = target.to(device)outputs = model(data)out_cut = np.copy(outputs.data.cpu().numpy())# If the score is less than a threshold (0.5), the prediction is 0, otherwise its 1out_cut[np.nonzero(out_cut < 0.5)] = 0.0out_cut[np.nonzero(out_cut >= 0.5)] = 1.0train_dice = dice_coef_metric(out_cut, target.data.cpu().numpy())loss = train_loss(outputs, target)losses.append(loss.item())train_iou.append(train_dice)# Reset the gradientsoptimizer.zero_grad()# Perform backpropagation to compute gradientsloss.backward()# Update the parameters with the computed gradientsoptimizer.step()if lr_scheduler:lr_scheduler.step()val_mean_iou = compute_iou(model, val_loader)loss_history.append(np.array(losses).mean())train_history.append(np.array(train_iou).mean())val_history.append(val_mean_iou)print("Epoch [%d]" % (epoch))print("Mean loss on train:", np.array(losses).mean(), "\nMean DICE on train:", np.array(train_iou).mean(), "\nMean DICE on validation:", val_mean_iou)return loss_history, train_history, val_history

结果

让我们在一个带有肿瘤的主题上评估我们的模型:

326c72e0a3dc7e3cf4b716108a4a58e6.jpeg

结果看起来相当不错!我们可以看到模型明显学到了关于图像结构的一些有用信息。然而,它可能可以更好地细化分割,这可以通过我们将很快讨论的更先进的技术来实现。U-Net至今仍然广泛使用,但有一个著名的模型达到了最先进的性能,称为nn-UNet。

·  END  ·

HAPPY LIFE

0ce980e3d287c8e1d8cf10ae1f388a0f.png

本文仅供学习交流使用,如有侵权请联系作者删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://xiahunao.cn/news/2780221.html

如若内容造成侵权/违法违规/事实不符,请联系瞎胡闹网进行投诉反馈,一经查实,立即删除!

相关文章

知识价值2-什么是IDE?新手用哪个IDE比较好?

IDE是集成开发环境&#xff08;Integrated Development Environment&#xff09;的缩写&#xff0c;是一种软件应用程序&#xff0c;旨在提供集成的工具集&#xff0c;以方便开发人员进行软件开发。IDE通常包括代码编辑器、编译器、调试器和其他工具&#xff0c;以支持软件开发…

使用R语言fifer包进行分层采样

使用R语言fifer包中的stratified()函数用来进行分层采样非常方便&#xff0c;但fifer包已经从CRAN存储库中删除&#xff0c;需要从存档中下载可用的历史版本&#xff0c;下载链接&#xff1a;Index of /src/contrib/Archive/fifer (r-project.org)https://cran.r-project.org/s…

浅谈路由器交换结构

一、路由器技术概述 路由器&#xff08;Router&#xff09;是连接两个或多个网络的硬件设备&#xff0c;在网络间起网关的作用&#xff0c;是读取每一个数据包中的地址然后决定如何传送的专用智能性的网络设备。它能够理解不同的协议&#xff0c;例如某个局域网使用的以太网协议…

【算法】排序详解(快速排序,堆排序,归并排序,插入排序,希尔排序,选择排序,冒泡排序)

目录 排序的概念&#xff1a; 排序算法的实现&#xff1a; 插入排序&#xff1a; 希尔排序&#xff1a; 选择排序&#xff1a; 堆排序&#xff1a; 冒泡排序&#xff1a; 快速排序&#xff1a; 快速排序的基本框架&#xff1a; 1.Hoare法 2. 挖坑法 3.前后指针法 快…

口腔助手|口腔挂号预约小程序|基于微信小程序的口腔门诊预约系统的设计与实现(源码+数据库+文档)

口腔小程序目录 目录 基于微信小程序的口腔门诊预约系统的设计与实现 一、前言 二、系统功能设计 三、系统实现 1、小程序前台界面实现 2、后台管理员模块实现 四、数据库设计 1、实体ER图 2、具体的表设计如下所示&#xff1a; 五、核心代码 六、论文参考 七、最新…

ASP.NET Core Web API 流式返回,实现ChatGPT逐字显示

&#x1f3c6;作者&#xff1a;科技、互联网行业优质创作者 &#x1f3c6;专注领域&#xff1a;.Net技术、软件架构、人工智能、数字化转型、DeveloperSharp、微服务、工业互联网、智能制造 &#x1f3c6;欢迎关注我&#xff08;Net数字智慧化基地&#xff09;&#xff0c;里面…

关于创建vue项目报错command failed: npm install --loglevel error

一、首先 在这个目录下有个文件叫.vuerc 二、其次 进去之后把里面的"useTaobaoRegistry": false,修改下&#xff0c;我之前是true&#xff0c;后来改成了false才成功。

Linux下的多用户管理和认证:从入门到精通(附实例)

Linux操作系统以其强大的多用户管理和认证机制而著称。这种机制不仅允许多个用户同时登录并执行各种任务&#xff0c;还能确保每个用户的数据安全和隐私。本文将通过一系列实例&#xff0c;带你逐步掌握Linux下的多用户管理和认证。 一、Linux多用户管理的基础知识 在Linux中&…

EasyCaptcha,开源图形验证码新标杆!

引言&#xff1a; 随着互联网的普及&#xff0c;验证码已成为网站和应用程序中不可或缺的安全组件。它能够有效地防止自动化攻击、垃圾邮件和机器人活动。在众多验证码解决方案中&#xff0c;Easy-captcha以其简单易用和高度可定制的特点受到了开发者的青睐。本文将指导读者如…

推荐系统|召回05_矩阵补充、最近邻查找

文章目录 矩阵补充Matrix Completion模型结构模型训练模型存储 矩阵补充Matrix Completion 模型结构 通过用户ID和物品ID分别找到对应的向量&#xff0c;然后去做内积&#xff0c;内积的数值可以去衡量匹配的程度。 不共享参数的意思是指用户ID和物品ID使用不同的Embedding L…

【计算几何】给定一组点的多边形面积

目录 一、说明二、有序顶点集三、无序顶点集3.1 凸多边形3.2 非凸多边形 四、结论 ​ 一、说明 计算多边形面积的方法有很多种。众所周知的多边形&#xff08;如三角形、矩形、正方形、梯形等&#xff09;的面积可以使用简单的数学公式计算。在这篇文章中&#xff0c;我将讨论…

《UE5_C++多人TPS完整教程》学习笔记2 ——《P3 多人游戏概念(Multiplayer Concept)》

本文为B站系列教学视频 《UE5_C多人TPS完整教程》 —— 《P3 多人游戏概念&#xff08;Multiplayer Concept&#xff09;》 的学习笔记&#xff0c;该系列教学视频为 Udemy 课程 《Unreal Engine 5 C Multiplayer Shooter》 的中文字幕翻译版&#xff0c;UP主&#xff08;也是译…

图灵日记--MapSet字符串常量池反射枚举Lambda表达式泛型

目录 搜索树概念实现性能分析和 java 类集的关系 搜索概念及场景模型 Map的使用Map常用方法 Set的说明常见方法说明 哈希表冲突-避免-负载因子调节冲突-解决-闭散列冲突-解决-开散列/哈希桶冲突严重时的解决办法 实现和 java 类集的关系 字符串常量池String对象创建intern方法 …

SpringCloud-Eureka服务注册中心测试实践

5. Eureka服务注册中心 5.1 什么是Eureka Netflix在涉及Eureka时&#xff0c;遵循的就是API原则.Eureka是Netflix的有个子模块&#xff0c;也是核心模块之一。Eureka是基于REST的服务&#xff0c;用于定位服务&#xff0c;以实现云端中间件层服务发现和故障转移&#xff0c;服…

Junit5基础教程

文章目录 一&#xff0c;导入依赖二&#xff0c;基本功能一、常用断言二、执行顺序和常用注解1、通过BeforeAll类的注解来保证顺序2、通过order注解来保证执行顺序 三、依赖测试四、参数化测试五、测试套件SelectPackages、IncludePackages、SelectClasses、IncludeTags等注解的…

C语言printf函数详解..

1.printf函数解析 前面我们有讲过printf函数的格式为&#xff1a; printf(“占位1 占位2 占位3……”, 替代1, 替代2, 替代3……); 今天我们进一步深入的解析一下这个函数 2.printf函数的特点 1.printf函数是一个变参函数(即参数的数量和类型都不确定) 2.printf函数的第一个…

【MySQL】——数值函数的学习

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-Z1fAnfrxGD7I5gqp {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

bugku 1

Flask_FileUpload 文件上传 先随便传个一句话木马 看看回显 果然不符合规定 而且发现改成图片什么的都不行 查看页面源代码&#xff0c;发现提示 那应该就要用python命令才行 试试ls 类型要改成图片 cat /flag 好像需要密码 bp爆破 根据提示&#xff0c;我们先抓包 爆破 …

ChatGPT高效提问—prompt常见用法(续篇九)

ChatGPT高效提问—prompt常见用法(续篇九) ​ 如何准确地向大型语言模型提出问题,使其更好地理解我们的意图,从而得到期望的答案呢?编写有效的prompt的技巧,精心设计的prompt,获得期望的的答案。 1.1 增加条件 ​ 在各种prompt技巧中,增加条件是最常用的。在prompt中…

MOMENTUM: 1

攻击机 192.168.223.128 目标机 192.168.223.146 主机发现 nmap -sP 192.168.223.0/24 端口扫描 nmap -sV -p- -A 192.168.223.146 开启了22 80端口 看一下web界面 随便打开看看 发现这里有个参数id&#xff0c;sql尝试无果&#xff0c;发现写入什么&#xff0c;网页显示…