A Survey of the Development of Mamba and Its Applications in Medical Imaging



Introduction

Recently, State Space Models (SSMs), exemplified by Mamba, have emerged as a promising approach.

They not only excel at modeling long-range interactions but also maintain linear computational complexity (compared with the quadratic complexity of Transformers).

Mamba, a recent selective structured state space model, excels in long sequence modeling, which is vital in the large model era.

Since January 2024, Mamba has been actively applied to diverse computer vision tasks, yielding numerous contributions.

Reference: https://export.arxiv.org/pdf/2404.18861

Related Work

CV:

CNN:

CNNs have linear computational complexity, but their receptive fields are restricted. This restriction limits their capability to capture larger spatial contexts, which is essential for comprehensively understanding scenes or complex spatial relations in tasks that demand global information.

ViTs:

Vision Transformers (ViTs), which apply self-attention over image patches, have demonstrated remarkable modeling capabilities across various visual tasks. However, the self-attention mechanism incurs a computational cost quadratic in the number of patches.

SSM:

The state space model is a concept that is widely adopted in various disciplines. Its core idea is connecting the input and output sequences using a latent state.

It takes different forms in different disciplines, such as the Markov decision process in reinforcement learning (Hafner et al., 2020), dynamic causal modeling in computational neuroscience (Friston et al., 2003), and Kalman filters in control (Kalman, 1960).

Recently, the state space model (SSM) has been introduced to deep learning for sequence modeling, with its parameters or mappings learned by gradient descent (Gu et al., 2021, "Combining recurrent, convolutional, and continuous-time models with linear state space layers").

Development:

Early SSMs demanded extensive computation and memory >> (reparameterizing the state matrices, i.e., S4) >> constant, input-independent sequence transitions still restrict context-based reasoning >> (integrating a selection mechanism, i.e., Mamba) >> selectively propagate or forget information along the sequence or scan path based on the current token.

Furthermore, to efficiently compute these selective SSMs, the authors develop a hardware-aware algorithm. Subsequently, the authors integrate these selective SSMs into a simplified neural network architecture, termed Mamba.


Mamba

Before getting into the nature of Mamba, let me recommend these blogs:

My favorite: 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba

Video: Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math

力压Transformer算法?首篇Mamba综述来了!

Mamba详细介绍和RNN、Transformer的架构可视化对比

Formulation

SSM (S4)

One of the most classical SSM formulations is S4, which is mathematically defined by linear ordinary differential equations (ODEs):

h'(t) = A h(t) + B x(t),
y(t) = C h(t),        (1)

where A ∈ R^(N×N) is the evolution parameter and B ∈ R^(N×1), C ∈ R^(1×N) are the projection parameters learned by the network. The term h'(t) denotes the derivative of h(t) with respect to time t; the system maps a 1D input signal x(t) ∈ R to a 1D output signal y(t) ∈ R through an N-dimensional latent state h(t) ∈ R^N.

To operate on discrete token sequences, the continuous parameters (A, B) are discretized with a step size Δ, typically via the zero-order hold (ZOH) rule:

Ā = exp(ΔA),   B̄ = (ΔA)^(-1) (exp(ΔA) − I) · ΔB.        (2)

After discretizing A, B to Ā, B̄, Eq. (1) can be reformulated as the recurrence (3):

h_t = Ā h_(t−1) + B̄ x_t,
y_t = C h_t.        (3)

In the CV field we use convolution kernels to capture features; likewise, the SSM can be unrolled into an equivalent convolutional form (4):

K̄ = (C B̄, C Ā B̄, …, C Ā^(L−1) B̄),   y = x ∗ K̄,        (4)

so instead of stepping through the recurrence token by token, the whole output sequence can be computed in parallel as a single (long) convolution with the kernel K̄.
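To make the equivalence concrete, here is a tiny numerical check (toy sizes; the variable names are my own) that the recurrent form (3) and the convolutional form (4) produce the same output:

PYTHON
import torch

N, L = 4, 8                            # state size, sequence length
A_bar = torch.randn(N, N) * 0.1        # discretized Ā
B_bar = torch.randn(N, 1)              # discretized B̄
C = torch.randn(1, N)
x = torch.randn(L)

# Recurrent form (3): h_t = Ā h_{t-1} + B̄ x_t, y_t = C h_t
h = torch.zeros(N, 1)
y_rec = torch.empty(L)
for t in range(L):
    h = A_bar @ h + B_bar * x[t]
    y_rec[t] = (C @ h).squeeze()

# Convolutional form (4): K̄ = (C B̄, C Ā B̄, ..., C Ā^{L-1} B̄), y = x ∗ K̄
K = torch.cat([C @ torch.matrix_power(A_bar, k) @ B_bar for k in range(L)]).squeeze()
y_conv = torch.stack([sum(K[k] * x[t - k] for k in range(t + 1)) for t in range(L)])

print(torch.allclose(y_rec, y_conv, atol=1e-5))  # True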

x can be referred to as the input signal;
A can be referred to as the stored historical information;
B can be referred to as how the input is written into the state;
C can be referred to as the transformation from state to output.

As the figure above shows, simply speaking, D is a skip connection. The system behaves like a linear time-invariant (LTI) signal feedback system, much like an RNN. A preserves the hidden information, so how can we construct a matrix A that stores long-sequence memories? The answer is HiPPO: Recurrent Memory with Optimal Polynomial Projections.
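For reference, here is a small sketch constructing the HiPPO-LegS matrix from the formula in the HiPPO paper (my own transcription; the official S4 code ships an equivalent helper):

PYTHON
import torch

def hippo_legs(N):
    # HiPPO-LegS: A[n, k] = sqrt((2n+1)(2k+1)) if n > k, n + 1 if n == k, else 0;
    # the SSM then uses -A so that the state evolves stably.
    n = torch.arange(N, dtype=torch.float32).unsqueeze(1)
    k = torch.arange(N, dtype=torch.float32).unsqueeze(0)
    A = torch.sqrt(2 * n + 1) * torch.sqrt(2 * k + 1)
    A = torch.where(n > k, A, torch.zeros_like(A))
    A = A + torch.diag(torch.arange(1, N + 1, dtype=torch.float32))
    return -A

print(hippo_legs(4))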

Selective SSM (S6)

paper: Mamba: Linear-Time Sequence Modeling with Selective State Spaces (https://arxiv.org/abs/2312.00752)

The equations above, as we have seen, remain invariant with respect to the input and temporal dynamics, which limits context-dependent modeling of long sequences. So Mamba introduces a selection mechanism that makes the parameters functions of the current input, schematically:

B_t = Linear(x_t),   C_t = Linear(x_t),   Δ_t = softplus(Linear(x_t)).

Here matrix A remains the same, because the state dynamics themselves are expected to stay fixed while the way the state is affected (via B and C) becomes dynamic. Making B and C selective allows finer-grained control over whether an input x_t is let into the state h_t and whether the state is let into the output y_t. But this raises a new question: with input-dependent matrices, the convolutional form no longer applies, so we cannot train these dynamic matrices in parallel the same way.

In fact, the sequence can still be computed in segments with the selective scan algorithm. This procedure also takes the SRAM and HBM of the GPU into account: the hidden state h is kept in SRAM (small but very fast), while A, B, and C are loaded from HBM.
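To make this concrete, here is a minimal sequential sketch of the S6 recurrence (illustrative only: the weight names W_B, W_C, W_delta are my own, and the real implementation fuses this loop into a hardware-aware parallel scan):

PYTHON
import torch
import torch.nn.functional as F

def selective_scan(x, A, W_B, W_C, W_delta):
    # x: (batch, L, D); A: (D, N) diagonal state matrix (negative real entries)
    # W_B, W_C: (D, N); W_delta: (D, D) -- input-dependent projections
    batch, L, D = x.shape
    h = x.new_zeros(batch, D, A.shape[1])
    ys = []
    for t in range(L):
        u = x[:, t]                                  # (batch, D)
        Bt = u @ W_B                                 # (batch, N), selected by input
        Ct = u @ W_C                                 # (batch, N), selected by input
        dt = F.softplus(u @ W_delta)                 # (batch, D), per-token step size
        dA = torch.exp(dt.unsqueeze(-1) * A)         # (batch, D, N), discretized Ā_t
        dB = dt.unsqueeze(-1) * Bt.unsqueeze(1)      # (batch, D, N), discretized B̄_t
        h = dA * h + dB * u.unsqueeze(-1)            # selective state update
        ys.append((h * Ct.unsqueeze(1)).sum(-1))     # y_t = C_t h_t, (batch, D)
    return torch.stack(ys, dim=1)                    # (batch, L, D)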

Backbone Network

Reproduction

Reference: Mamba神经网络架构~从0构建

I wrote an S4-style Mamba block, different from the reference above, which may still have some learning value:

PYTHON
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state, d_conv=None, use_conv=True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.use_conv = use_conv

        # Projection layers: the input is split into sequence features x
        # and a per-token step size delta
        self.in_proj = nn.Linear(d_model, 2 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Causal depthwise convolution (optional)
        if use_conv and d_conv is not None:
            self.conv = nn.Conv1d(d_model, d_model, kernel_size=d_conv,
                                  padding=d_conv - 1, groups=d_model)
        else:
            self.conv = None

        # S4-style core parameters; A is kept diagonal (stored as a log)
        # so that the discretization exp(delta * A) is an element-wise op
        self.A_log = nn.Parameter(torch.log(torch.arange(1, d_state + 1).float()))
        self.B = nn.Parameter(torch.randn(d_state, d_model) / d_model ** 0.5)
        self.C = nn.Parameter(torch.randn(d_model, d_state) / d_state ** 0.5)
        self.D = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        batch, seq_len, _ = x.shape

        # Projection: features plus a positive per-token step size
        x, delta = self.in_proj(x).chunk(2, dim=-1)
        delta = F.softplus(delta).mean(dim=-1)            # (batch, seq_len)

        # Convolution (if used); trim the right padding to stay causal
        if self.conv is not None:
            x = x.transpose(1, 2)
            x = self.conv(x)[:, :, :seq_len]
            x = x.transpose(1, 2)

        # Core recurrence: h_t = Ā_t h_{t-1} + B̄_t x_t, y_t = C h_t + D x_t
        A = -torch.exp(self.A_log)                        # (d_state,), negative => stable
        h = x.new_zeros(batch, self.d_state)
        ys = []
        for t in range(seq_len):
            u = x[:, t]                                   # (batch, d_model)
            dA = torch.exp(delta[:, t, None] * A)         # per-token discretized Ā
            h = dA * h + delta[:, t, None] * (u @ self.B.T)
            ys.append(h @ self.C.T + u * self.D)
        y = torch.stack(ys, dim=1)

        # Output projection
        return self.out_proj(y)

class MambaModel(nn.Module):
    def __init__(self, d_model, d_state, d_conv=None, num_layers=1, use_conv=True):
        super().__init__()
        self.layers = nn.ModuleList([MambaBlock(d_model, d_state, d_conv, use_conv) for _ in range(num_layers)])

    def forward(self, x):
        for layer in self.layers:
            x = x + layer(x)
        return x
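A quick smoke test (the shapes here are arbitrary):

PYTHON
model = MambaModel(d_model=64, d_state=16, d_conv=4, num_layers=2)
x = torch.randn(2, 128, 64)
print(model(x).shape)  # torch.Size([2, 128, 64])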

Challenges

Mamba-based models demonstrate adherence to scaling laws. Their advantages are evident in their efficiency and performance in dense prediction tasks, indicating their potential as next-generation visual networks. However, scaling visual Mamba to larger configurations requires further investigation. Additionally, there remains a need for more exploration of Mamba-based models specifically tailored for dense prediction tasks.


VMamba

Additionally, I would like to introduce VMamba, which brings the S6 block to CV.

As the figure shows, VMamba is built from the SS2D (2D selective scan) module wrapping the S6 block, whose outputs are then merged into the final result. One of the main questions is the scanning method: how can we handle 2D data without giving up input-dependent weights, while still capturing spatial information without fixed convolutions?

In the authors' words: "By adopting complementary 1D traversal paths, SS2D enables each pixel in the image to effectively integrate information from all other pixels in different directions, thereby facilitating the establishment of global receptive fields in the 2D space."
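The cross-scan idea can be sketched in isolation as follows (my own simplification, not the official VMamba code): the feature map is flattened along four complementary 1D paths, each path would then be fed through an S6 block, and the results are mapped back and merged:

PYTHON
import torch

def cross_scan(x):
    # x: (batch, C, H, W) -> (batch, 4, C, H*W), four complementary traversals
    row = x.flatten(2)                       # row-major scan
    col = x.transpose(2, 3).flatten(2)       # column-major scan
    return torch.stack([row, row.flip(-1), col, col.flip(-1)], dim=1)

def cross_merge(seqs, H, W):
    # inverse of cross_scan: undo each traversal, then sum the four branches
    b, _, c, _ = seqs.shape
    row = (seqs[:, 0] + seqs[:, 1].flip(-1)).view(b, c, H, W)
    col = (seqs[:, 2] + seqs[:, 3].flip(-1)).view(b, c, W, H).transpose(2, 3)
    return row + col

x = torch.randn(1, 8, 4, 4)
# each of the four branches carries a copy of the input, so merging recovers 4x
print(torch.allclose(cross_merge(cross_scan(x), 4, 4), 4 * x))  # True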


Applications in Medical Imaging

MambaIR

paper: MambaIR: A Simple Baseline for Image Restoration with State-Space Model

Guo et al. introduced a simple but effective baseline, named MambaIR, which adds both local enhancement and channel attention to improve vanilla Mamba, targeting image restoration (e.g., super-resolution).

Framework

However, the standard Mamba, which was designed for 1D sequential data in NLP, is not a natural fit for image restoration. First, since Mamba processes flattened 1D image sequences recursively, spatially close pixels can end up at very distant locations in the flattened sequence, causing the problem of local pixel forgetting; for example, in a row-major flattening of an H×W image, the pixel directly below position (i, j) appears W steps later in the sequence. Second, due to the requirement to memorize long sequence dependencies, the number of hidden states in the state space equations is typically large, which can lead to channel redundancy and thus hinder the learning of critical channel representations.

Reproduction

GitHub of the paper: MambaIR

environment

I suggest this approach:

BASH
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install causal_conv1d==1.0.0
pip install mamba_ssm==1.0.1
pip install timm==0.9.16

On my server, pip==21.2.3 was able to build mamba_ssm. If you run into trouble building it, try adjusting your pip version.

Get the Dataset

Since the authors uploaded the datasets to Google Drive, we can download them with the Python package gdown, which makes it easy to pull the data directly onto a server (especially an overseas one).

BASH
pip install gdown
gdown https://drive.google.com/uc?id=ID # replace ID with the share ID of your data

For example, when you click the DF2K download link, you can see the URL https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view , where the string 1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4 between /d/ and /view is the ID we need. You can then download the DF2K dataset with the following command:

BASH
gdown https://drive.google.com/uc?id=1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4

option

Open the file "/options/train/train_MambaIR_lightSR_x2.yml" (or another option file) and revise the dataset paths.

Then you can train it following their README on GitHub.

VM-UNet

paper: VM-UNet: Vision Mamba UNet for Medical Image Segmentation

Ruan et al. propose a U-shaped architecture for medical image segmentation.

Framework

Specifically, VM-UNet comprises a Patch Embedding layer, an encoder, a decoder, a Final Projection layer, and skip connections. The authors did not adopt a symmetrical structure but instead used an asymmetric design. For the skip connections, a straightforward addition operation is adopted without bells and whistles, so no additional parameters are introduced.

As shown, the Mamba module is used as the VSS block and merged into the UNet; a schematic sketch follows.
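Below is a schematic skeleton of this U-shaped layout (my own illustration: VSSBlock is reduced to a stub standing in for the SS2D-based block from VMamba, and the real channel dims and depths differ):

PYTHON
import torch
import torch.nn as nn

class VSSBlock(nn.Module):
    # Stub for the Vision State Space block; the real one applies SS2D (S6 over
    # four scan paths) instead of this placeholder 1x1 convolution.
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Sequential(nn.GroupNorm(1, dim), nn.Conv2d(dim, dim, 1))
    def forward(self, x):
        return x + self.body(x)

class VMUNetSketch(nn.Module):
    def __init__(self, in_ch=3, num_classes=1, dims=(32, 64, 128)):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_ch, dims[0], kernel_size=4, stride=4)
        self.enc = nn.ModuleList([VSSBlock(d) for d in dims])
        self.down = nn.ModuleList([nn.Conv2d(dims[i], dims[i + 1], 2, stride=2)
                                   for i in range(len(dims) - 1)])
        self.up = nn.ModuleList([nn.ConvTranspose2d(dims[i + 1], dims[i], 2, stride=2)
                                 for i in reversed(range(len(dims) - 1))])
        self.dec = nn.ModuleList([VSSBlock(d) for d in reversed(dims[:-1])])
        self.final_proj = nn.ConvTranspose2d(dims[0], num_classes, 4, stride=4)

    def forward(self, x):
        x = self.patch_embed(x)
        skips = []
        for i, blk in enumerate(self.enc):
            x = blk(x)
            if i < len(self.down):        # downsample between encoder stages
                skips.append(x)
                x = self.down[i](x)
        for up, blk, skip in zip(self.up, self.dec, reversed(skips)):
            x = up(x) + skip              # skip connection via plain addition
            x = blk(x)
        return self.final_proj(x)

print(VMUNetSketch()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 1, 64, 64])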

Summary