VM-UNet: 基于纯 Mamba 架构的医学图像分割模型
论文地址:https://arxiv.org/abs/2402.02491
项目地址:https://github.com/JCruan519/VM-UNet
Abstract
在医学图像分割领域,基于CNN和基于Transformer的模型都得到了广泛的探索。然而,CNN在远程建模能力方面表现出局限性,而Transformer则受到二次计算复杂性的阻碍。最近,以Mamba为例的状态空间模型(SSM)作为一种很有前途的方法出现了。它们不仅在远程相互作用建模方面表现优异,而且保持了线性计算复杂度。本文利用状态空间模型,提出了一种用于医学图像分割的U-shaped架构模型,命名为视觉Mamba UNet (VM-UNet)。具体来说,引入了视觉状态空间(VSS)块作为基础块来捕获广泛的上下文信息,并构造了一个非对称的编码器-解码器结构。我们在ISIC17, ISIC18和Synapse数据集上进行了全面的实验,结果表明VM-UNet在医学图像分割任务中具有竞争力。据我们所知,这是第一个基于纯SSM模型构建的医学图像分割模型。我们的目标是建立一个基线,并为未来更高效和有效的基于SSM的细分系统的发展提供有价值的见解。
1 Introduction
自动医学图像分割技术帮助医生更快地进行病理诊断,从而提高患者护理的效率。近年来,基于CNN的模型和基于transformer的模型在各种视觉任务中表现出了显著的性能,特别是在医学图像分割方面。UNet[27]作为基于CNN的模型的代表,以结构简单、可扩展性强而闻名,后续的许多改进都是基于这种U型架构[11,37,28,29,30] 。TransUnet[10]是基于Transformer的模型中的先驱,它在编码阶段首先使用Vision Transformer (ViT)[13]进行特征提取,在解码阶段使用CNN,显示出重要的全局信息获取能力。随后,TransFuse[36]采用了ViT和CNN的并行架构,同时捕获局部和全局特征。此外,Swin-Unet[9]将Swin Transformer[21]与u型架构相结合,首次引入了纯基于Transformer的U型模型。
然而,基于CNN的模型和基于Transformer的模型都有固有的局限性。基于CNN的模型受到其局部接受域的限制,极大地阻碍了它们捕捉远程信息的能力。这通常会导致提取不充分的特征,从而导致次优分割结果。尽管基于Transformer的模型在全局建模方面表现出优异的性能,但自注意机制在图像大小方面要求二次复杂度,导致计算负担很高[31,13],特别是对于需要密集预测的任务,如医学图像分割。目前这些模型的缺点迫使我们开发一种新的医学图像分割架构,能够捕获强远程信息并保持线性计算复杂度。
近年来,状态空间模型(SSM)引起了研究人员的极大兴趣。在经典SSM[18]研究的基础上,现代SSM(如Mamba[16])不仅建立了长距离依赖关系,而且在输入大小方面表现出线性复杂性。此外,基于SSM的模型在许多领域都得到了大量的研究,包括语言理解[17,16]、通用视觉[38,20]等。特别是,U-Mamba[24]最近引入了一种新的SSM-CNN混合模型,这标志着它在医学图像分割任务中的首次应用。SegMamba[35]在编码器部分采用SSM,而在解码器部分仍然使用CNN,提出了一种SSM-CNN混合模型用于三维脑肿瘤分割任务。虽然上述工作已经将SSM用于医学图像分割任务,但纯粹基于SSM的模型的性能还有待探索。
受VMamba[20]在图像分类任务中取得成功的影响,本文首次引入了Vision Mamba UNet (VM-UNet),这是一种纯粹基于SSM的模型,旨在展示其在医学图像分割任务中的潜力。具体来说,VM-UNet由三个主要部分组成:编码器、解码器和跳跃连接。编码器由VMamba的VSS块组成,用于特征提取,以及用于下采样的patch merging 操作。相反,解码器包括VSS块和patch expanding操作,以恢复分割结果的大小。对于跳跃连接模块,为了突出最原始的纯SSM模型的分割性能,我们采用了最简单的加法运算形式。
在器官分割和皮肤病变分割任务上进行了全面的实验,以证明纯SSM模型在医学图像分割中的潜力。具体来说,我们在Synapse[19]、ISIC17[8]和ISIC18[12]数据集上进行了大量的实验,结果表明VM-UNet可以达到有竞争力的性能。此外,重要的是要注意VM-UNet代表了纯基于SSM的分段模型的最基本形式,因为它不包括任何专门设计的模块。
本文的主要贡献如下:
- 提出了VM-UNet,首次探索了纯粹基于SSM的模型在医学图像分割中的潜在应用。
- 在三个数据集上进行了综合实验,结果表明VM-UNet具有相当的竞争力。
- 我们为纯SSM模型在医学图像分割任务中建立了基线,为开发更高效、更有效的基于SSM的分割方法提供了有价值的见解。
2 Preliminaries
在现代基于SSM的模型中,即结构化状态空间序列模型(S4)和Mamba都依赖于一个经典的连续系统,该系统通过中间隐式状态 h ( t ) ∈ R N h(t)∈R^N h(t)∈RN将一维输入函数或序列映射为 x ( t ) ∈ R x(t)∈R x(t)∈R到输出 y ( t ) ∈ R y(t)∈R y(t)∈R。上述过程可以表示为线性常微分方程(ODE):
h ′ ( t ) = A h ( t ) + B x ( t ) y ( t ) = C h ( t ) (1) \begin{aligned}h'(t)&=\mathbf{A}h(t)+\mathbf{B}x(t)\\y(t)&=\mathbf{C}h(t)\end{aligned} \tag{1} h′(t)y(t)=Ah(t)+Bx(t)=Ch(t)(1)
其中, A ∈ R N × N A∈R^{N×N} A∈RN×N表示状态矩阵, B ∈ R N × 1 B∈R^{N×1} B∈RN×1, C ∈ R N × 1 C∈R^{N×1} C∈RN×1表示投影参数。
S4和Mamba将这个连续系统离散化,使其更适合深度学习场景。具体来说,他们引入一个时间尺度参数∆,并使用固定的离散化规则将 A \mathbf{A} A和 B \mathbf{B} B转换为离散参数 A ˉ \mathbf{\bar{A}} Aˉ和 B ˉ \mathbf{\bar{B}} Bˉ。通常采用零阶保持器(ZOH)作为离散化规则,其定义如下:
A ‾ = exp ( Δ A ) B ‾ = ( Δ A ) − 1 ( exp ( Δ A ) − I ) ⋅ Δ B (2) \begin{aligned}\overline{\mathbf{A}}&=\exp(\boldsymbol\Delta\mathbf{A})\\\overline{\mathbf{B}}&=(\boldsymbol\Delta\mathbf{A})^{-1}(\exp(\boldsymbol\Delta\mathbf{A})-\mathbf{I})\cdot\boldsymbol\Delta\mathbf{B}\end{aligned} \tag{2} AB=exp(ΔA)=(ΔA)−1(exp(ΔA)−I)⋅ΔB(2)
离散化后,基于SSM的模型可以通过线性递归或全局卷积两种方式进行计算,分别定义为公式3和公式4。
h ′ ( t ) = A ‾ h ( t ) + B ‾ x ( t ) y ( t ) = C h ( t ) K ‾ = ( C B ‾ , C A B ‾ , … , C A ‾ L − 1 B ‾ ) y = x ∗ K ‾ \begin{align} &h^{\prime}(t)={\overline{{\mathbf{A}}}}h(t)+{\overline{{\mathbf{B}}}}x(t) \\ &y(t)=\mathbf{C}h(t) \tag{3}\\ &\overline{K}=(\mathbf{C}\overline{\mathbf{B}},\mathbf{C}\overline{\mathbf{AB}},\ldots,\mathbf{C}\overline{\mathbf{A}}^{L-1}\overline{\mathbf{B}}) \\ &y=x*\overline{\mathbf{K}} \tag{4} \end{align} h′(t)=Ah(t)+Bx(t)y(t)=Ch(t)K=(CB,CAB,…,CAL−1B)y=x∗K(3)(4)
式中, K ‾ ∈ R L \overline{K}∈R^L K∈RL表示一个结构化卷积核,L表示输入序列x的长度。
3 Methods
在本节中,我们首先介绍VM-UNet的总体结构。随后,我们详细阐述了核心组件VSS模块。最后,我们描述了在训练过程中使用的损失函数。
3.1 Vision Mamba UNet (VM-UNet)
如图1 (a)所示,展示了VM-UNet的总体架构。具体来说,VM-UNet包括patch Embedding层、编码器、解码器、最终投影层和跳跃连接。与以往的方法[9]不同,我们没有采用对称结构,而是采用了不对称设计。
Patch Embedding层将输入图像 x ∈ R H × W × 3 x∈R^{H×W×3} x∈RH×W×3划分为大小为4 × 4的不重叠的Patch,随后将图像的维数映射为C, C默认为96。此过程得到嵌入图像 x ′ ∈ R H 4 × W 4 × C x^′∈R^{\frac{H}{4} × \frac{W}{4}×C} x′∈R4H×4W×C。最后,我们使用Layer Normalization对 x ′ x' x′进行归一化[7],然后将其输入编码器进行特征提取。
编码器由四个阶段组成,在前三个阶段的末尾应用 patch merging操作,以降低输入特征的高度和宽度,同时增加通道数量。我们在四个阶段使用[2,2,2,2]个VSS区块,每个阶段的通道计数为[C, 2C, 4C, 8C]。
同样,解码器被组织成四个阶段。在最后三个阶段的开始,采用patch expanding 操作来减少特征通道的数量,增加特征通道的高度和宽度。在四个阶段中,我们使用[2,2,2,1]VSS块,每个阶段的通道计数为[8C, 4C, 2C, C]。
在解码器之后,使用Final Projection层来恢复特征的大小以匹配分割目标。具体来说,通过patch expanding进行4次上采样来恢复特征的高度和宽度,然后通过投影层来恢复通道的数量。
对于跳跃连接,采用直接的加法操作,不需要附加参数,因此不会引入任何额外参数。
3.2 VSS block
VMamaba[20]衍生的VSS块是VM-UNet的核心模块,如图1 (b)所示。经过Layer Normalization后,输入被分成两个分支。在第一个分支中,输入经过一个线性层,然后是一个激活函数。在第二个分支中,输入通过线性层、深度可分离卷积和激活函数进行处理,然后输入到2D选择性扫描(SS2D)模块中进行进一步的特征提取。随后,使用Layer Normalization对特征进行归一化,然后使用第一个分支的输出执行逐元素的生成,以合并两条路径。最后,使用线性层混合特征,并将此结果与残差连接相结合,形成VSS块的输出。本文默认采用SiLU[14]作为激活函数。
SS2D由三个部分组成:扫描expanding操作、S6块操作和扫描merging操作。如图2(a)所示,扫描expanding操作沿着四个不同的方向(左上到右下、左下到右上、右下到左上、右上到左下)将输入图像展开成序列。然后通过S6块对这些序列进行特征提取,确保各个方向的信息被彻底扫描,从而捕获不同的特征。随后,如图2(b)所示,扫描merging操作将来自四个方向的序列相加并合并,将输出图像恢复为与输入相同的大小。源自Mamba[16]的S6块在S4[17]之上引入了一种选择机制,通过根据输入调整SSM的参数。这使模型能够区分并保留相关信息,同时过滤掉不相关的信息。算法1给出了S6块的伪代码。
3.3 损失函数
VM-UNet的引入旨在验证纯SSM模型在医学图像分割任务中的应用潜力。因此,我们专门利用最基本的二元交叉熵和骰子损失(BceDice loss)和交叉熵和骰子损失(CeDice loss)分别作为二元和多类分割任务的损失函数,如公式5和6所示。
L B c e D i c e = λ 1 L B c e + λ 2 L D i c e L C e D i c e = λ 1 L C e + λ 2 L D i c e { L B c e = − 1 N ∑ i = 1 N [ y i log ( y ^ i ) + ( 1 − y i ) log ( 1 − y ^ i ) ] L C e = − 1 N ∑ i = 1 N ∑ c = 1 C y i , c log ( y ^ i , c ) L D i c e = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \begin{align} &L_{\mathrm{BceDice}}=\lambda_1L_{\mathrm{Bce}}+\lambda_2L_{\mathrm{Dice}}\tag{5}\\\\ &L_{\mathrm{CeDice}}=\lambda_1L_{\mathrm{Ce}}+\lambda_2L_{\mathrm{Dice}}\tag{6}\\\\ &\left.\left\{\begin{array}{l}L_{\mathrm{Bce}}=-\frac{1}{N}\sum\limits_{i=1}^N\left[y_i\log(\hat{y}_i)+(1-y_i)\log(1-\hat{y}_i)\right]\\L_{\mathrm{Ce}}=-\frac{1}{N}\sum\limits_{i=1}^N\sum\limits_{c=1}^Cy_{i,c}\log(\hat{y}_{i,c})\\L_{\mathrm{Dice}}=1-\frac{2|X\cap Y|}{|X|+|Y|}\end{array}\right.\right. \tag{7} \end{align} LBceDice=λ1LBce+λ2LDiceLCeDice=λ1LCe+λ2LDice⎩ ⎨ ⎧LBce=−N1i=1∑N[yilog(y^i)+(1−yi)log(1−y^i)]LCe=−N1i=1∑Nc=1∑Cyi,clog(y^i,c)LDice=1−∣X∣+∣Y∣2∣X∩Y∣(5)(6)(7)
式中,N为样本总数,C为类别总数。 Y i , y i Y_i, y_i Yi,yi分别表示真实标签和预测。 Y i , c Y_{i,c} Yi,c是一个指标,如果样本I属于c类,则等于1,否则等于0。 Y i , c Y_{i,c} Yi,c为模型预测样本i属于类别c的概率。|X|和|Y|分别代表真实值和预测值。 λ 1 λ_1 λ1, λ 2 λ_2 λ2为损失函数的权值,默认值均为1。
4 Experiments
在本节中,我们对VM-UNet进行皮肤病变和器官分割任务的综合实验。具体来说,我们在ISIC17、ISIC18和Synapse数据集上评估了VM-UNet在医学图像分割任务上的性能。
4.1 数据集
ISIC17和ISIC18数据集:国际皮肤成像协作2017年和2018年挑战数据集(ISIC17和ISIC18)[8,1,12,2]是两个公开可用的皮肤病变分割数据集,分别包含2,150和2,694张带有分割面具标签的皮肤镜图像。根据之前的工作[28],我们将数据集以7:3的比例分割作为训练集和测试集。具体来说,对于ISIC17数据集,训练集由1500张图像组成,测试集由650张图像组成。对于ISIC18数据集,训练集包含1886张图像,而测试集包含808张图像。对于这两个数据集,我们提供了几个指标的详细评估,包括平均交联(mIoU),骰子相似系数(DSC),准确性(Acc),灵敏度(Sen)和特异性(Spe)。
Synapse多器官分割数据集(Synapse):Synapse[19,3]是一个公开的多器官分割数据集,包含30例腹部CT病例3,779张轴向腹部临床CT图像,包括8种腹部器官(主动脉、胆囊、左肾、右肾、肝脏、胰腺、脾脏、胃)。按照之前作品[10,9]的设置,训练用18个案例,测试用12个案例。对于这个数据集,我们报告了骰子相似系数(DSC)和95%豪斯多夫距离(HD95)作为评估指标。
4.2 实现细节
根据之前的工作[28,9],我们将ISIC17和ISIC18数据集中的图像调整为256×256,将Synapse数据集中的图像调整为224×224。为了防止过拟合,采用了随机翻转和随机旋转等数据增强技术。ISIC17和ISIC18数据集采用BceDice损失函数,Synapse数据集采用CeDice损失函数。我们将批大小设置为32,并使用AdamW[23]优化器,初始学习率为1e-3。使用CosineAnnealingLR[22]作为调度程序,最大迭代次数为50次,最小学习率为1e-5。训练周期设置为300。对于VM-UNet,我们使用在ImageNet-1k上预训练的VMamba-S[20]的权重初始化编码器和解码器的权重。所有实验均在单个NVIDIA RTX A6000 GPU上进行。
4.3 主要结果
我们将VM-UNet与一些最先进的模型进行比较,给出了表1和表2中的实验结果。对于ISIC17和ISIC18数据集,我们的VM-UNet在mIoU, DSC和Acc指标方面优于其他模型。对于Synapse数据集,VM-UNet也取得了具有竞争力的性能。例如,我们的模型超过Swin-Unet,这是第一个纯基于Transformer的模型,在DSC和HD95指标上分别高出1.95%和2.34mm。结果证明了基于SSM的模型在医学图像分割任务中的优越性。
4.4 消融研究
在本节中,我们使用ISIC17和ISIC18数据集对VMUNet的初始化进行消融实验。我们分别用VMamba-T和VMamba-s预训练的权值初始化VM-UNet。实验结果如表3所示,预训练的权值越强,VM-UNet的下游性能就越好,说明预训练的权值对VM-UNet的影响很大。
5 Conclusions and Future works
结论:在本文中,我们首次以VM-UNet为基线,引入了一个纯粹基于SSM的医学图像分割模型。为了利用基于SSM的模型的功能,我们使用VSS块构建VM-UNet,并使用预训练的VMamba-S初始化其权重。在皮肤病变和多器官分割数据集上进行的综合实验表明,纯SSM模型在医学图像分割任务中具有很强的竞争力,值得未来深入探索。
未来的工作:1)基于SSM的机制,设计更适合分割任务的模块。2) VM-UNet的参数数约为30M,为通过手工设计或其他压缩策略精简SSM提供了机会,从而增强了SSM在现实医疗场景中的适用性。3)考虑到SSM在捕获长序列信息方面的优势,进一步研究SSM在更高分辨率下的分割性能是有价值的。4)探索SSM在其他医学成像任务中的应用,如检测、配准、重建等。