|
1 | 1 | import unittest |
2 | 2 | import torch |
3 | | -from monai.networks.nets import UMamba |
| 3 | +from monai.networks.nets import UMambaUNet |
4 | 4 |
|
5 | 5 | class TestUMamba(unittest.TestCase): |
6 | 6 | def test_forward_shape(self): |
7 | 7 | # Set up input dimensions and model |
8 | | - input_tensor = torch.randn(2, 1, 64, 64) # (batch_size, channels, H, W) |
9 | | - model = UMamba(in_channels=1, out_channels=2) # example args |
10 | | - |
11 | | - # Forward pass |
| 8 | + input_tensor = torch.randn(2, 1, 16, 64, 64) |
| 9 | + model = UMambaUNet(in_channels=1, out_channels=2) |
12 | 10 | output = model(input_tensor) |
13 | | - |
14 | | - # Assert output shape matches expectation |
15 | | - self.assertEqual(output.shape, (2, 2, 64, 64)) # adjust if necessary |
| 11 | + self.assertEqual(output.shape, (2, 2, 16, 64, 64)) |
16 | 12 |
|
17 | 13 | def test_script(self): |
18 | 14 | # Test JIT scripting if supported |
19 | | - model = UMamba(in_channels=1, out_channels=2) |
| 15 | + model = UMambaUNet(in_channels=1, out_channels=2) |
20 | 16 | scripted = torch.jit.script(model) |
21 | 17 | x = torch.randn(1, 1, 64, 64) |
22 | 18 | out = scripted(x) |
|
0 commit comments