Skip to content

Commit 214ab53

Browse files
authored
Merge branch 'dev' into docs/dints-shape-constraints
2 parents 3420a85 + 252d26e commit 214ab53

12 files changed

Lines changed: 154 additions & 19 deletions

File tree

MANIFEST.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,5 @@ include monai/_version.py
33

44
include README.md
55
include LICENSE
6+
7+
prune tests

monai/networks/nets/swin_unetr.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ class SwinUNETR(nn.Module):
4747
Swin UNETR based on: "Hatamizadeh et al.,
4848
Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
4949
<https://arxiv.org/abs/2201.01266>"
50+
51+
Spatial Shape Constraints:
52+
Each spatial dimension of the input must be divisible by ``patch_size ** 5``.
53+
With the default ``patch_size=2``, this means each spatial dimension must be divisible by **32**
54+
(i.e., 2^5 = 32). This requirement comes from the patch embedding step followed by 4 stages
55+
of PatchMerging downsampling, each halving the spatial resolution.
56+
57+
For a custom ``patch_size``, the divisibility requirement is ``patch_size ** 5``.
58+
59+
Examples of valid 3D input sizes (with default ``patch_size=2``):
60+
``(32, 32, 32)``, ``(64, 64, 64)``, ``(96, 96, 96)``, ``(128, 128, 128)``, ``(64, 32, 192)``.
61+
62+
A ``ValueError`` is raised in ``forward()`` if the input spatial shape violates this constraint.
5063
"""
5164

5265
def __init__(
@@ -76,7 +89,8 @@ def __init__(
7689
Args:
7790
in_channels: dimension of input channels.
7891
out_channels: dimension of output channels.
79-
patch_size: size of the patch token.
92+
patch_size: size of the patch token. Input spatial dimensions must be divisible by
93+
``patch_size ** 5`` (e.g., divisible by 32 when ``patch_size=2``).
8094
feature_size: dimension of network feature size.
8195
depths: number of layers in each stage.
8296
num_heads: number of attention heads.
@@ -108,6 +122,10 @@ def __init__(
108122
# for 2D single channel input with size (96,96), 2-channel output and gradient checkpointing.
109123
>>> net = SwinUNETR(in_channels=3, out_channels=2, use_checkpoint=True, spatial_dims=2)
110124
125+
Raises:
126+
ValueError: When a spatial dimension of the input is not divisible by ``patch_size ** 5``.
127+
Use ``net._check_input_size(spatial_shape)`` to validate a shape before inference.
128+
111129
"""
112130

113131
super().__init__()

monai/transforms/signal/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
273273
data = convert_to_tensor(self.freqs * time_partial)
274274
sine_partial = self.magnitude * torch.sin(data)
275275

276-
loc = np.random.choice(range(length))
276+
loc = self.R.choice(range(length))
277277
signal = paste(signal, sine_partial, (loc,))
278278

279279
return signal
@@ -354,7 +354,7 @@ def __call__(self, signal: NdarrayOrTensor) -> NdarrayOrTensor:
354354
time_partial = np.arange(0, round(self.fracs * length), 1)
355355
squaredpulse_partial = self.magnitude * squarepulse(self.freqs * time_partial)
356356

357-
loc = np.random.choice(range(length))
357+
loc = self.R.choice(range(length))
358358
signal = paste(signal, squaredpulse_partial, (loc,))
359359

360360
return signal

monai/transforms/utility/array.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,19 +1049,34 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
10491049
which include TC (Tumor core), WT (Whole tumor) and ET (Enhancing tumor):
10501050
label 1 is the necrotic and non-enhancing tumor core, which should be counted under TC and WT subregion,
10511051
label 2 is the peritumoral edema, which is counted only under WT subregion,
1052-
label 4 is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1052+
the specified `et_label` (default 4) is the GD-enhancing tumor, which should be counted under ET, TC, WT subregions.
1053+
1054+
Args:
1055+
et_label: the label used for the GD-enhancing tumor (ET).
1056+
- Use 4 for BraTS 2018-2022.
1057+
- Use 3 for BraTS 2023.
1058+
Defaults to 4.
10531059
"""
10541060

10551061
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
10561062

1063+
def __init__(self, et_label: int = 4) -> None:
1064+
if et_label in (1, 2):
1065+
raise ValueError(f"et_label cannot be 1 or 2, as these are reserved. Got {et_label}.")
1066+
self.et_label = et_label
1067+
10571068
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
10581069
# if img has channel dim, squeeze it
10591070
if img.ndim == 4 and img.shape[0] == 1:
10601071
img = img.squeeze(0)
10611072

1062-
result = [(img == 1) | (img == 4), (img == 1) | (img == 4) | (img == 2), img == 4]
1063-
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
1064-
# label 4 is ET
1073+
result = [
1074+
(img == 1) | (img == self.et_label),
1075+
(img == 1) | (img == self.et_label) | (img == 2),
1076+
img == self.et_label,
1077+
]
1078+
# merge labels 1 (tumor non-enh) and self.et_label (tumor enh) and 2 (large edema) to WT
1079+
# self.et_label is ET (4 or 3)
10651080
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)
10661081

10671082

monai/transforms/utility/dictionary.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1297,19 +1297,27 @@ def __call__(self, data: Mapping[Hashable, Any]):
12971297
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
12981298
"""
12991299
Dictionary-based wrapper of :py:class:`monai.transforms.ConvertToMultiChannelBasedOnBratsClasses`.
1300-
Convert labels to multi channels based on brats18 classes:
1300+
Convert labels to multi channels based on brats classes:
13011301
label 1 is the necrotic and non-enhancing tumor core
13021302
label 2 is the peritumoral edema
1303-
label 4 is the GD-enhancing tumor
1303+
the specified `et_label` (default 4) is the GD-enhancing tumor
13041304
The possible classes are TC (Tumor core), WT (Whole tumor)
13051305
and ET (Enhancing tumor).
1306+
1307+
Args:
1308+
keys: keys of the corresponding items to be transformed.
1309+
et_label: the label used for the GD-enhancing tumor (ET).
1310+
- Use 4 for BraTS 2018-2022.
1311+
- Use 3 for BraTS 2023.
1312+
Defaults to 4.
1313+
allow_missing_keys: don't raise exception if key is missing.
13061314
"""
13071315

13081316
backend = ConvertToMultiChannelBasedOnBratsClasses.backend
13091317

1310-
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
1318+
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, et_label: int = 4):
13111319
super().__init__(keys, allow_missing_keys)
1312-
self.converter = ConvertToMultiChannelBasedOnBratsClasses()
1320+
self.converter = ConvertToMultiChannelBasedOnBratsClasses(et_label=et_label)
13131321

13141322
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
13151323
d = dict(data)

monai/transforms/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@
8585
cp, has_cp = optional_import("cupy")
8686
cp_ndarray, _ = optional_import("cupy", name="ndarray")
8787
exposure, has_skimage = optional_import("skimage.exposure")
88+
# NOTE: cucim is deliberately NOT imported at module level.
89+
# Module-level cucim imports caused very slow import times and other buggy behaviour.
90+
# Keep cucim imports inside the functions that need them.
8891

8992
__all__ = [
9093
"allow_missing_keys_mode",

monai/utils/module.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pdb
1818
import re
1919
import sys
20+
import traceback as traceback_mod
2021
import warnings
2122
from collections.abc import Callable, Collection, Hashable, Iterable, Mapping
2223
from functools import partial, wraps
@@ -368,8 +369,9 @@ def optional_import(
368369
OptionalImportError: from torch.nn.functional import conv1d (requires version '42' by 'min_version').
369370
"""
370371

371-
tb = None
372+
had_exception = False
372373
exception_str = ""
374+
tb_str = ""
373375
if name:
374376
actual_cmd = f"from {module} import {name}"
375377
else:
@@ -384,8 +386,12 @@ def optional_import(
384386
if name: # user specified to load class/function/... from the module
385387
the_module = getattr(the_module, name)
386388
except Exception as import_exception: # any exceptions during import
387-
tb = import_exception.__traceback__
389+
tb_str = "".join(
390+
traceback_mod.format_exception(type(import_exception), import_exception, import_exception.__traceback__)
391+
)
392+
import_exception.__traceback__ = None
388393
exception_str = f"{import_exception}"
394+
had_exception = True
389395
else: # found the module
390396
if version_args and version_checker(pkg, f"{version}", version_args):
391397
return the_module, True
@@ -394,7 +400,7 @@ def optional_import(
394400

395401
# preparing lazy error message
396402
msg = descriptor.format(actual_cmd)
397-
if version and tb is None: # a pure version issue
403+
if version and not had_exception: # a pure version issue
398404
msg += f" (requires '{module} {version}' by '{version_checker.__name__}')"
399405
if exception_str:
400406
msg += f" ({exception_str})"
@@ -407,10 +413,9 @@ def __init__(self, *_args, **_kwargs):
407413
+ "\n\nFor details about installing the optional dependencies, please visit:"
408414
+ "\n https://monai.readthedocs.io/en/latest/installation.html#installing-the-recommended-dependencies"
409415
)
410-
if tb is None:
411-
self._exception = OptionalImportError(_default_msg)
412-
else:
413-
self._exception = OptionalImportError(_default_msg).with_traceback(tb)
416+
if tb_str:
417+
_default_msg += f"\n\nOriginal traceback:\n{tb_str}"
418+
self._exception = OptionalImportError(_default_msg)
414419

415420
def __getattr__(self, name):
416421
"""

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def get_cmds():
144144
setup(
145145
version=versioneer.get_version(),
146146
cmdclass=get_cmds(),
147-
packages=find_packages(exclude=("docs", "examples", "tests")),
147+
packages=find_packages(exclude=("docs", "examples", "tests", "tests.*")),
148148
zip_safe=False,
149149
package_data={"monai": ["py.typed", *jit_extension_source]}, # type: ignore[arg-type]
150150
ext_modules=get_extensions(),

tests/networks/nets/test_swin_unetr.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,17 @@ def test_ill_arg(self):
9090
with self.assertRaises(ValueError):
9191
SwinUNETR(in_channels=1, out_channels=3, feature_size=24, norm_name="instance", drop_rate=-1)
9292

93+
@skipUnless(has_einops, "Requires einops")
94+
def test_invalid_input_shape(self):
95+
# spatial dims not divisible by patch_size**5 (default patch_size=2, so must be divisible by 32)
96+
net = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=3)
97+
with self.assertRaises(ValueError):
98+
net(torch.randn(1, 1, 33, 64, 64)) # 33 is not divisible by 32
99+
100+
net_2d = SwinUNETR(in_channels=1, out_channels=2, feature_size=24, spatial_dims=2)
101+
with self.assertRaises(ValueError):
102+
net_2d(torch.randn(1, 1, 48, 33)) # 33 is not divisible by 32
103+
93104
def test_patch_merging(self):
94105
dim = 10
95106
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))

tests/transforms/test_convert_to_multi_channel.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from tests.test_utils import TEST_NDARRAYS, assert_allclose
2121

2222
TESTS = []
23+
TESTS_ET_LABEL_3 = []
24+
25+
# Tests for default et_label = 4
2326
for p in TEST_NDARRAYS:
2427
TESTS.extend(
2528
[
@@ -46,6 +49,23 @@
4649
]
4750
)
4851

52+
# Tests for et_label = 3
53+
for p in TEST_NDARRAYS:
54+
TESTS_ET_LABEL_3.extend(
55+
[
56+
[
57+
p([[0, 1, 2], [1, 2, 3], [0, 1, 3]]),
58+
p(
59+
[
60+
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
61+
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
62+
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
63+
]
64+
),
65+
]
66+
]
67+
)
68+
4969

5070
class TestConvertToMultiChannel(unittest.TestCase):
5171
@parameterized.expand(TESTS)
@@ -54,6 +74,18 @@ def test_type_shape(self, data, expected_result):
5474
assert_allclose(result, expected_result)
5575
self.assertTrue(result.dtype in (bool, torch.bool))
5676

77+
@parameterized.expand(TESTS_ET_LABEL_3)
78+
def test_type_shape_et_label_3(self, data, expected_result):
79+
result = ConvertToMultiChannelBasedOnBratsClasses(et_label=3)(data)
80+
assert_allclose(result, expected_result)
81+
self.assertTrue(result.dtype in (bool, torch.bool))
82+
83+
def test_invalid_et_label(self):
84+
with self.assertRaises(ValueError):
85+
ConvertToMultiChannelBasedOnBratsClasses(et_label=1)
86+
with self.assertRaises(ValueError):
87+
ConvertToMultiChannelBasedOnBratsClasses(et_label=2)
88+
5789

5890
if __name__ == "__main__":
5991
unittest.main()

0 commit comments

Comments
 (0)