A Survey of Mamba's Development and Its Applications in Medical Imaging
Introduction
Recently, State Space Models (SSMs), exemplified by Mamba, have emerged as a promising approach.
They not only excel at modeling long-range interactions but also keep computational complexity linear in the sequence length (compared with the quadratic complexity of the Transformer).
Mamba, a recent selective structured state space model, is particularly strong at long-sequence modeling, which is vital in the era of large models.
Since January 2024, Mamba has been actively applied to diverse computer vision tasks, yielding numerous contributions.
Reference: https://export.arxiv.org/pdf/2404.18861
Related Work
CV:
CNN:
CNNs have linear computational complexity, but their receptive field is restricted. This restriction limits their capability to capture larger spatial contexts, which is essential for comprehensively understanding scenes or complex spatial relations in tasks that demand global information.
ViTs:
Vision Transformers (ViTs), which use self-attention to process image patches, have demonstrated remarkable modeling capabilities across various visual tasks. However, the cost of the self-attention mechanism grows quadratically with the number of patches.
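To make the quadratic growth concrete, here is a small back-of-the-envelope sketch. It assumes square images and a 16x16 patch size (a common ViT setting, not something fixed by the text), and the function names are illustrative only.

```python
def num_patches(height, width, patch_size):
    """Number of non-overlapping square patches (tokens) in an image."""
    return (height // patch_size) * (width // patch_size)


def attention_entries(n_tokens):
    """Entries in one self-attention score matrix: every token attends to every token."""
    return n_tokens * n_tokens


for side in (224, 448, 896):
    n = num_patches(side, side, patch_size=16)
    print(f"{side}x{side}: {n} tokens, {attention_entries(n)} attention entries")
```

Doubling the image side multiplies the number of tokens by 4 and the size of the attention matrix by 16, whereas a model with linear complexity would only see the 4x increase.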
SSM:
The state space model is a concept that is widely adopted in various disciplines. Its core idea is connecting the input and output sequences using a latent state.
It takes different forms in different disciplines, such as Markov decision processes in reinforcement learning (Hafner et al, 2020), dynamic causal modeling in computational neuroscience (Friston et al, 2003) and Kalman filters in control theory (Kalman, 1960).
Recently, the state space model (SSM) has been introduced to deep learning for sequence modeling, with its parameters or mappings learned by gradient descent (Gu et al, 2021, "Combining Recurrent, Convolutional, and Continuous-time Models with Linear State Space Layers").
Development:
Early SSMs had extensive computational and memory requirements, which were addressed by reparameterizing the state matrices. The remaining limitation was that constant (input-independent) sequence transitions restrict context-based reasoning. This was addressed by integrating a selection mechanism, so that the model can selectively propagate or forget information along the sequence or scan path based on the current token.
Furthermore, to efficiently compute these selective SSMs, the authors develop a hardware-aware algorithm. Subsequently, the authors integrate these selective SSMs into a simplified neural network architecture, termed Mamba.
Mamba
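Before going into how Mamba works, I recommend these resources:
My favorite: 一文通透想颠覆Transformer的Mamba:从SSM、HiPPO、S4到Mamba (in Chinese; roughly "A thorough guide to Mamba, the model that wants to overturn the Transformer: from SSM, HiPPO and S4 to Mamba")
Video: Mamba and S4 Explained: Architecture, Parallel Scan, Kernel Fusion, Recurrent, Convolution, Math
Mamba详细介绍和RNN、Transformer的架构可视化对比 (in Chinese; roughly "A detailed introduction to Mamba, with a visual comparison against the RNN and Transformer architectures")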
Formulation
SSM (S4)
One of the most classical SSMs is S4, which is formulated mathematically as a system of linear ordinary differential equations (ODEs):
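For reference, in the standard notation of the S4 and Mamba papers, this continuous-time system (Eq. (1) in the text) can be written as follows. The optional skip term D x(t) is left out here.

```latex
\begin{equation}
\begin{aligned}
h'(t) &= A\,h(t) + B\,x(t), \\
y(t)  &= C\,h(t).
\end{aligned}
\tag{1}
\end{equation}
```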
where A ∈ R^{N×N} is the evolution parameter, and B ∈ R^{N×1} and C ∈ R^{1×N} are the projection parameters; in deep learning all of them are learned by the network.
The term h′(t) denotes the derivative of h(t) with respect to time t. The system maps a 1D input signal x(t) ∈ R to a 1D output signal y(t) ∈ R through an N-dimensional latent state h(t) ∈ R^N.
Eq. (2) is the discretization rule applied to Eq. (1). After discretizing A and B into Ā and B̄ with a step size Δ, Eq. (1) can be rewritten in the discrete form of Eq. (3).
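Assuming the zero-order hold (ZOH) rule used in the Mamba paper, and following the equation numbering of the text, Eqs. (2) and (3) would read as below, where Δ is the step size and I is the identity matrix.

```latex
\begin{equation}
\bar{A} = \exp(\Delta A), \qquad
\bar{B} = (\Delta A)^{-1}\bigl(\exp(\Delta A) - I\bigr)\,\Delta B
\tag{2}
\end{equation}

\begin{equation}
\begin{aligned}
h_t &= \bar{A}\,h_{t-1} + \bar{B}\,x_t, \\
y_t &= C\,h_t.
\end{aligned}
\tag{3}
\end{equation}
```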
In computer vision, convolution kernels are used to extract features, and an SSM can likewise be written in a convolutional form, as in Eq. (4):
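In the notation of the Mamba paper, unrolling the discrete recurrence from a zero initial state gives the following kernel, where L is the sequence length and * denotes a causal convolution.

```latex
\begin{equation}
\begin{aligned}
\bar{K} &= \bigl(C\bar{B},\; C\bar{A}\bar{B},\; \dots,\; C\bar{A}^{L-1}\bar{B}\bigr), \\
y &= x * \bar{K}.
\end{aligned}
\tag{4}
\end{equation}
```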
Because the kernel does not depend on the input, the outputs for all positions of x can be computed in parallel with a single convolution.
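As a rough illustration of why the recurrent and convolutional forms agree, the following sketch runs a toy SSM both ways. It assumes a diagonal Ā and made-up parameter values; the function names are illustrative and not taken from any library.

```python
def ssm_recurrent(A_bar, B_bar, C, xs):
    """Run h_t = A_bar * h_{t-1} + B_bar * x_t, y_t = C . h_t with a diagonal A_bar."""
    h = [0.0] * len(A_bar)
    ys = []
    for x in xs:
        h = [a * h_n + b * x for a, b, h_n in zip(A_bar, B_bar, h)]
        ys.append(sum(c * h_n for c, h_n in zip(C, h)))
    return ys


def ssm_kernel(A_bar, B_bar, C, length):
    """Kernel entries K_k = C . (A_bar^k * B_bar) for k = 0 .. length - 1."""
    return [sum(c * (a ** k) * b for a, b, c in zip(A_bar, B_bar, C))
            for k in range(length)]


def ssm_convolutional(A_bar, B_bar, C, xs):
    """Causal convolution y_t = sum_k K_k * x_{t-k}; each y_t is independent of the others."""
    K = ssm_kernel(A_bar, B_bar, C, len(xs))
    return [sum(K[k] * xs[t - k] for k in range(t + 1)) for t in range(len(xs))]


# Hypothetical toy parameters: 2-dimensional state, diagonal A_bar.
A_bar, B_bar, C = [0.9, 0.5], [1.0, 0.5], [0.3, -0.2]
xs = [1.0, 0.0, 2.0, -1.0]
y_rec = ssm_recurrent(A_bar, B_bar, C, xs)
y_conv = ssm_convolutional(A_bar, B_bar, C, xs)
print(all(abs(a - b) < 1e-9 for a, b in zip(y_rec, y_conv)))
```

The recurrent form needs only the previous state at each step, which suits inference, while the convolutional form has no dependency between output positions, which suits parallel training.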

x is the input signal
A stores the historical information (it governs how the state evolves)
B controls how the input is written into the state
C maps the state to the output
As the figure above shows, D is simply a skip connection. The system is a linear time-invariant (LTI) feedback system, much like an RNN. A holds the hidden (historical) information, so how can we construct a matrix A that retains memory over long sequences? The answer is HiPPO, introduced in "HiPPO: Recurrent Memory with Optimal Polynomial Projections".
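For a concrete picture, the sketch below builds the HiPPO-LegS matrix in the form given in the S4 paper: with 0-based indices, A[n][k] is -sqrt(2n+1)·sqrt(2k+1) below the diagonal, -(n+1) on the diagonal, and 0 above it. The function name is illustrative, and this is only one of several HiPPO variants.

```python
import math


def hippo_legs(n_state):
    """Build the N x N HiPPO-LegS matrix (0-based indices n, k)."""
    A = [[0.0] * n_state for _ in range(n_state)]
    for n in range(n_state):
        for k in range(n_state):
            if n > k:
                A[n][k] = -math.sqrt(2 * n + 1) * math.sqrt(2 * k + 1)
            elif n == k:
                A[n][k] = -(n + 1.0)
    return A


for row in hippo_legs(3):
    print([round(v, 3) for v in row])
```

The matrix is lower triangular with a negative diagonal, so every state component decays over time while being driven by the lower-order components.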
Selective SSM (S6)
paper: Mamba: Linear-Time Sequence Modeling with Selective State Spaces
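As we have seen, the parameters in the equations above do not change with the input or over time, which limits the model on long sequences that call for content-based reasoning. Mamba therefore introduces a selection mechanism:
Here the matrix A itself stays input-independent because the state dynamics are expected to remain static, but the way the state is affected (via B and C) is dynamic. In Mamba the step size Δ is also a function of the input, so the discretized Ā still varies from token to token. Making B and C selective allows finer-grained control over whether to let an input x_t into the state h_t, or the state into the output y_t. This raises a new problem: because the parameters now change at every step, the system is no longer time-invariant, so the convolutional form no longer applies and training cannot be parallelized that way.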
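The following is a minimal single-channel sketch of such a selective recurrence, not the Mamba implementation. It assumes a diagonal A, scalar inputs, linear maps from x_t to B_t, C_t and Δ_t, and the simplified rule B̄ = Δ·B; all parameter names (w_B, w_C, w_delta, b_delta) are hypothetical.

```python
import math


def softplus(v):
    """Smooth positive function used to keep the step size above zero."""
    return math.log1p(math.exp(v))


def selective_scan(xs, A, w_B, w_C, w_delta, b_delta):
    """Single-channel selective SSM with a diagonal A (all names are illustrative).

    B_t, C_t and delta_t are recomputed from every input x_t, so the
    recurrence is no longer time-invariant.
    """
    h = [0.0] * len(A)
    ys = []
    for x in xs:
        delta = softplus(w_delta * x + b_delta)    # input-dependent step size
        B_t = [w * x for w in w_B]                 # input-dependent write vector
        C_t = [w * x for w in w_C]                 # input-dependent read vector
        A_bar = [math.exp(delta * a) for a in A]   # exp(delta * A) for a diagonal A
        B_bar = [delta * b for b in B_t]           # simplified (Euler-style) rule for B
        h = [a * h_n + b * x for a, b, h_n in zip(A_bar, B_bar, h)]
        ys.append(sum(c * h_n for c, h_n in zip(C_t, h)))
    return ys


ys = selective_scan([1.0, 0.0, 0.5], A=[-1.0, -2.0],
                    w_B=[1.0, 0.5], w_C=[0.3, -0.2], w_delta=1.0, b_delta=0.0)
print(ys)
```

In this toy setup a zero input writes nothing into the state and reads nothing out of it, which is the kind of per-token gating that a fixed B and C cannot express.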
Instead, the sequence can be computed in segments with the selective scan algorithm, which takes the GPU memory hierarchy into account: the state h is kept in SRAM (small but fast), while A, B and C are stored in HBM (large but slower) and loaded as needed.
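The idea of computing the recurrence in segments can be sketched as follows. This is plain Python for a scalar recurrence h_t = a_t·h_{t-1} + b_t, not the fused CUDA kernel; the function names and the chunking scheme are illustrative assumptions. The key fact is that composing two steps is associative, so each chunk can be reduced on its own and only one carried state crosses a chunk boundary.

```python
def combine(first, second):
    """Compose two steps h -> a*h + b; 'first' is applied before 'second'."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)


def sequential_scan(steps, h0=0.0):
    """Reference: apply h_t = a_t * h_{t-1} + b_t one step at a time."""
    h, out = h0, []
    for a, b in steps:
        h = a * h + b
        out.append(h)
    return out


def chunked_scan(steps, chunk_size, h0=0.0):
    """Scan chunk by chunk; only one carried state crosses each chunk boundary."""
    out, carry = [], h0
    for start in range(0, len(steps), chunk_size):
        chunk = steps[start:start + chunk_size]
        # Prefix compositions inside a chunk do not depend on other chunks,
        # so on a GPU they could be computed in parallel in fast memory.
        prefix, acc = [], (1.0, 0.0)
        for step in chunk:
            acc = combine(acc, step)
            prefix.append(acc)
        out.extend(a * carry + b for a, b in prefix)
        carry = out[-1]
    return out


steps = [(0.9, 1.0), (0.5, 2.0), (0.8, -1.0), (0.7, 0.5), (0.6, 0.0)]
seq = sequential_scan(steps)
chunked = chunked_scan(steps, chunk_size=2)
print(all(abs(x - y) < 1e-12 for x, y in zip(seq, chunked)))
```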
Backbone_Network
Reproduction
Reference: Mamba神经网络架构~从0构建 (in Chinese; roughly "Mamba neural network architecture, built from scratch")
I wrote an S4-style Mamba block. It differs from the reference above, but may still be useful for study:
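```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MambaBlock(nn.Module):
    """Simplified S4-style block with an input-dependent step size and input gate.

    The state is shared across channels and updated in a plain Python loop,
    so this is meant for study, not for speed.
    """

    def __init__(self, d_model, d_state, d_conv=None, use_conv=True):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.use_conv = use_conv
        # Projection layers: in_proj yields the SSM input, the step size and an input gate
        self.in_proj = nn.Linear(d_model, 2 * d_model + d_state)
        self.out_proj = nn.Linear(d_model, d_model)
        # Optional depthwise convolution, made causal by trimming in forward()
        if use_conv and d_conv is not None:
            self.conv = nn.Conv1d(d_model, d_model, kernel_size=d_conv,
                                  padding=d_conv - 1, groups=d_model)
        else:
            self.conv = None
        # Core SSM parameters (random init; a HiPPO-style init would be more stable)
        self.A = nn.Parameter(torch.randn(d_state, d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_model, d_state))
        self.D = nn.Parameter(torch.randn(d_model))

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        batch, seq_len, d_model = x.shape
        # Project, then split into the SSM input, step size and input gate
        x_proj = self.in_proj(x)
        x, delta, gate = x_proj.split([self.d_model, self.d_model, self.d_state], dim=-1)
        # Convolution (if used); drop the extra right-hand outputs to stay causal
        if self.conv is not None:
            x = x.transpose(1, 2)
            x = self.conv(x)[:, :, :seq_len]
            x = x.transpose(1, 2)
        # One positive step size per token, shaped (batch, seq_len, 1, 1)
        delta = F.softplus(delta).mean(dim=-1).view(batch, seq_len, 1, 1)
        # Discretized transition exp(delta * A): (batch, seq_len, d_state, d_state)
        dA = torch.matrix_exp(delta * self.A)
        y = torch.zeros(batch, seq_len, d_model, device=x.device, dtype=x.dtype)
        h = torch.zeros(batch, self.d_state, device=x.device, dtype=x.dtype)
        for t in range(seq_len):
            u = x[:, t]                                         # (batch, d_model)
            h = h + (u @ self.B.T) * gate[:, t]                 # gated write into the state
            y[:, t] = h @ self.C.T + u * self.D                 # readout plus skip connection
            h = torch.bmm(h.unsqueeze(1), dA[:, t]).squeeze(1)  # apply the state transition
        # Output projection
        return self.out_proj(y)


class MambaModel(nn.Module):
    def __init__(self, d_model, d_state, d_conv=None, num_layers=1, use_conv=True):
        super().__init__()
        self.layers = nn.ModuleList(
            [MambaBlock(d_model, d_state, d_conv, use_conv) for _ in range(num_layers)]
        )

    def forward(self, x):
        # Residual connection around every block
        for layer in self.layers:
            x = x + layer(x)
        return x
```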
Challenges
Mamba-based models demonstrate adherence to scaling laws. Their advantages are evident in their efficiency and performance in dense prediction tasks, indicating their potential as next-generation visual networks. However, scaling visual Mamba to larger configurations requires further investigation. Additionally, there remains a need for more exploration of Mamba-based models specifically tailored for dense prediction tasks.
VMamba
Next, I introduce VMamba, which applies S6 to computer vision.
As the figure shows, VMamba is built around the SS2D (2D selective scan) module: the feature map is unfolded along several scan paths, each resulting sequence is processed by an S6 block, and the results are merged into the final output. The main question is the scanning method: how can 2D data be handled while keeping the weights input-dependent, and how can spatial information be captured without fixed convolutions?
In the authors' words: By adopting complementary 1D traversal paths, SS2D enables each pixel in the image to effectively integrate information from all other pixels in different directions, thereby facilitating the establishment of global receptive fields in the 2D space.
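A rough sketch of such complementary traversal paths is given below. It assumes four paths (row-major, column-major, and the reverse of each) and a merge that sums the per-path outputs at each pixel; the function names are illustrative and the S6 processing of each path is left out.

```python
def cross_scan_paths(height, width):
    """Return four 1D traversal orders over an H x W grid as lists of (row, col)."""
    row_major = [(r, c) for r in range(height) for c in range(width)]
    col_major = [(r, c) for c in range(width) for r in range(height)]
    return [row_major, col_major, row_major[::-1], col_major[::-1]]


def cross_merge(per_path_values, paths, height, width):
    """Sum the per-path outputs back onto the 2D grid (assumed merge rule)."""
    grid = [[0.0] * width for _ in range(height)]
    for values, path in zip(per_path_values, paths):
        for value, (r, c) in zip(values, path):
            grid[r][c] += value
    return grid


paths = cross_scan_paths(2, 3)
for path in paths:
    print(path)
```

Since a causal scan only sees what came earlier in its own path, combining a forward and a backward pass along both rows and columns is what lets every pixel receive information from every other pixel.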
Applications in Medical Imaging
MambaIR
paper: MambaIR: A Simple Baseline for Image Restoration with State-Space Model
Guo et al. introduced a simple but effective baseline, named MambaIR, which adds both local enhancement and channel attention to vanilla Mamba and targets image restoration tasks.
Framework
However, the standard Mamba, which was designed for 1D sequential data in NLP, is not a natural fit for image restoration scenarios. First, since Mamba processes flattened 1D image sequences in a recursive manner, it can result in spatially close pixels being found at very distant locations in the flattened sequences, resulting in the problem of local pixel forgetting. Second, due to the requirement to memorize the long sequence dependencies, the number of hidden states in the state space equations is typically large, which can lead to channel redundancy, thus hindering the learning of critical channel representations.
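The first problem is easy to quantify. The sketch below assumes row-major flattening and a hypothetical feature-map width of 64; the function names are illustrative.

```python
def flat_index(row, col, width):
    """Position of pixel (row, col) in a row-major flattened sequence."""
    return row * width + col


def sequence_distance(p, q, width):
    """Number of sequence steps between two pixels after flattening."""
    return abs(flat_index(p[0], p[1], width) - flat_index(q[0], q[1], width))


width = 64  # hypothetical feature-map width
print(sequence_distance((10, 20), (10, 21), width))  # horizontal neighbour
print(sequence_distance((10, 20), (11, 20), width))  # vertical neighbour
```

A horizontal neighbour stays one step away, but a vertical neighbour ends up a full row (here 64 steps) away, so the recurrent state has to carry its information across that whole span.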
Reproduction
Github of the paper: MambaIR
environment
I suggest this approach:
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install causal_conv1d==1.0.0
pip install mamba_ssm==1.0.1
pip install timm==0.9.16
I built mamba_ssm successfully on my server with pip==21.2.3. If the build fails for you, try changing your pip version.
Getting the dataset
The authors host the datasets on Google Drive, so the Python package gdown makes it easy to download them directly onto a server (especially an overseas one).
pip install gdown
gdown https://drive.google.com/uc?id=ID # replace the ID with your data share ID
For example, when you click the DF2K download link, the URL is https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view , where the string 1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4 between /d/ and /view is the ID we need. You can then download the DF2K dataset with the following command:
gdown https://drive.google.com/uc?id=1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4
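If there are several links to process, the ID extraction described above can be scripted. This is a small helper sketch using only the standard library; the function names are hypothetical, and it only builds the command string rather than calling gdown itself.

```python
import re


def drive_file_id(share_url):
    """Extract the file ID found between '/d/' and the next '/' in a Drive share link."""
    match = re.search(r"/d/([^/]+)", share_url)
    if match is None:
        raise ValueError(f"no file ID found in {share_url!r}")
    return match.group(1)


def gdown_command(share_url):
    """Build the gdown command line for a Drive share link."""
    return f"gdown https://drive.google.com/uc?id={drive_file_id(share_url)}"


url = "https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view"
print(gdown_command(url))
```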
option
Open "/options/train/train_MambaIR_lightSR_x2.yml" (or another option file) and change the dataset paths to match your setup.
Then train the model by following the README on GitHub.
VM-UNet
paper: VM-UNet: Vision Mamba UNet for Medical Image Segmentation
Ruan et al. propose a U-shaped architecture for medical image segmentation.
Framework
Specifically, VM-UNet comprises a Patch Embedding layer, an encoder, a decoder, a Final Projection layer, and skip connections. They have not adopted a symmetrical structure but instead utilized an asymmetric design. For the skip connections, a straightforward addition operation is adopted without bells and whistles, thereby not introducing any additional parameters.
As we have seen, the Mamba module is used as the VSS block and integrated into the U-Net.