Skip to content

Commit 4964850

Browse files
committed
Register UMambaUNet in monai.networks.nets.__init__
1 parent 9b190cf commit 4964850

2 files changed

Lines changed: 111 additions & 0 deletions

File tree

monai/networks/nets/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,4 @@
144144
from .vnet import VNet
145145
from .voxelmorph import VoxelMorph, VoxelMorphUNet
146146
from .vqvae import VQVAE
147+
from .u_mamba import UMambaUNet

monai/networks/nets/u_mamba.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import torch
13+
import torch.nn as nn
14+
import torch.nn.functional as F
15+
16+
# Simple placeholder for the SSM (Mamba-like block)
17+
class SSMBlock(nn.Module):
18+
def __init__(self, dim):
19+
super().__init__()
20+
self.linear1 = nn.Linear(dim, dim)
21+
self.linear2 = nn.Linear(dim, dim)
22+
23+
def forward(self, x):
24+
# x: (B, L, C)
25+
return self.linear2(torch.silu(self.linear1(x)))
26+
27+
class UMambaBlock(nn.Module):
28+
def __init__(self, in_channels, hidden_channels):
29+
super().__init__()
30+
self.conv_res1 = nn.Sequential(
31+
nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
32+
nn.InstanceNorm3d(in_channels),
33+
nn.LeakyReLU(),
34+
)
35+
self.conv_res2 = nn.Sequential(
36+
nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
37+
nn.InstanceNorm3d(in_channels),
38+
nn.LeakyReLU(),
39+
)
40+
41+
self.layernorm = nn.LayerNorm(hidden_channels)
42+
self.linear1 = nn.Linear(in_channels, hidden_channels)
43+
self.linear2 = nn.Linear(hidden_channels, in_channels)
44+
self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
45+
self.ssm = SSMBlock(hidden_channels)
46+
47+
def forward(self, x):
48+
# x: (B, C, H, W, D)
49+
residual = x
50+
x = self.conv_res1(x)
51+
x = self.conv_res2(x) + residual
52+
53+
B, C, H, W, D = x.shape
54+
x_flat = x.view(B, C, -1).permute(0, 2, 1) # (B, L, C)
55+
x_norm = self.layernorm(x_flat)
56+
x_proj = self.linear1(x_norm)
57+
58+
x_silu = torch.silu(x_proj)
59+
x_ssm = self.ssm(x_silu)
60+
x_conv1d = self.conv1d(x_proj.permute(0, 2, 1)).permute(0, 2, 1)
61+
62+
x_combined = torch.silu(x_conv1d) * torch.silu(x_ssm)
63+
x_out = self.linear2(x_combined)
64+
x_out = x_out.permute(0, 2, 1).view(B, C, H, W, D)
65+
66+
return x + x_out # Residual connection
67+
68+
class ResidualBlock(nn.Module):
69+
def __init__(self, channels):
70+
super().__init__()
71+
self.block = nn.Sequential(
72+
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
73+
nn.BatchNorm3d(channels),
74+
nn.ReLU(),
75+
nn.Conv3d(channels, channels, kernel_size=3, padding=1),
76+
nn.BatchNorm3d(channels),
77+
)
78+
79+
def forward(self, x):
80+
return F.relu(x + self.block(x))
81+
82+
class UMambaUNet(nn.Module):
83+
def __init__(self, in_channels=1, out_channels=1, base_channels=32):
84+
super().__init__()
85+
self.enc1 = UMambaBlock(in_channels, base_channels)
86+
self.down1 = nn.Conv3d(base_channels, base_channels*2, kernel_size=3, stride=2, padding=1)
87+
88+
self.enc2 = UMambaBlock(base_channels*2, base_channels*2)
89+
self.down2 = nn.Conv3d(base_channels*2, base_channels*4, kernel_size=3, stride=2, padding=1)
90+
91+
self.bottleneck = UMambaBlock(base_channels*4, base_channels*4)
92+
93+
self.up2 = nn.ConvTranspose3d(base_channels*4, base_channels*2, kernel_size=2, stride=2)
94+
self.dec2 = ResidualBlock(base_channels*4)
95+
96+
self.up1 = nn.ConvTranspose3d(base_channels*2, base_channels, kernel_size=2, stride=2)
97+
self.dec1 = ResidualBlock(base_channels*2)
98+
99+
self.final = nn.Conv3d(base_channels, out_channels, kernel_size=1)
100+
101+
def forward(self, x):
102+
x1 = self.enc1(x)
103+
x2 = self.enc2(self.down1(x1))
104+
x3 = self.bottleneck(self.down2(x2))
105+
106+
x = self.up2(x3)
107+
x = self.dec2(torch.cat([x, x2], dim=1))
108+
x = self.up1(x)
109+
x = self.dec1(torch.cat([x, x1], dim=1))
110+
return self.final(x)

0 commit comments

Comments
 (0)