Skip to content

Commit 8d7610b

Browse files
committed
Add explicit spatial_ndim tracking to MetaTensor (Fixes #6397)
Fixes dimension-mismatch crashes when einops.rearrange() or other reshape operations change tensor ndim, by decoupling the spatial rank from the tensor shape.

- Add _spatial_ndim attribute to MetaObj, derived from the affine in MetaTensor
- Expose a spatial_ndim property with getter/setter and validation
- Sync spatial_ndim on affine assignment and propagate it through collation
- Update transforms to use spatial_ndim instead of the ndim-1 heuristic
- Add 18 new tests for spatial_ndim behavior

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent d53eb00 commit 8d7610b

File tree

16 files changed

+281
-53
lines changed

16 files changed

+281
-53
lines changed

monai/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
monai_to_itk_ddf,
7272
)
7373
from .meta_obj import MetaObj, get_track_meta, set_track_meta
74-
from .meta_tensor import MetaTensor
74+
from .meta_tensor import MetaTensor, get_spatial_ndim
7575
from .samplers import DistributedSampler, DistributedWeightedRandomSampler
7676
from .synthetic import create_test_image_2d, create_test_image_3d
7777
from .test_time_augmentation import TestTimeAugmentation

monai/data/meta_obj.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def __init__(self) -> None:
8484
self._applied_operations: list = MetaObj.get_default_applied_operations()
8585
self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops
8686
self._is_batch: bool = False
87+
self._spatial_ndim: int = 3 # default: 3 spatial dimensions
8788

8889
@staticmethod
8990
def flatten_meta_objs(*args: Iterable):

monai/data/meta_tensor.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,25 @@
2121
import torch
2222

2323
import monai
24-
from monai.config.type_definitions import NdarrayTensor
24+
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
2525
from monai.data.meta_obj import MetaObj, get_track_meta
2626
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2727
from monai.utils import look_up_option
2828
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
2929
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
3030

31-
__all__ = ["MetaTensor"]
31+
__all__ = ["MetaTensor", "get_spatial_ndim"]
32+
33+
34+
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
35+
"""Return the number of spatial dimensions assuming channel-first layout.
36+
37+
Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
38+
``img.ndim - 1``.
39+
"""
40+
if isinstance(img, MetaTensor):
41+
return img.spatial_ndim
42+
return img.ndim - 1
3243

3344

3445
@functools.lru_cache(None)
@@ -111,6 +122,7 @@ def __new__(
111122
meta: dict | None = None,
112123
applied_operations: list | None = None,
113124
*args,
125+
spatial_ndim: int | None = None,
114126
**kwargs,
115127
) -> MetaTensor:
116128
_kwargs = {"device": kwargs.pop("device", None), "dtype": kwargs.pop("dtype", None)} if kwargs else {}
@@ -123,6 +135,7 @@ def __init__(
123135
meta: dict | None = None,
124136
applied_operations: list | None = None,
125137
*_args,
138+
spatial_ndim: int | None = None,
126139
**_kwargs,
127140
) -> None:
128141
"""
@@ -134,6 +147,8 @@ def __init__(
134147
the list is typically maintained by `monai.transforms.TraceableTransform`.
135148
See also: :py:class:`monai.transforms.TraceableTransform`
136149
_args: additional args (currently not in use in this constructor).
150+
spatial_ndim: optional number of spatial dimensions. If ``None``, derived
151+
from the affine matrix clamped by the tensor shape.
137152
_kwargs: additional kwargs (currently not in use in this constructor).
138153
139154
Note:
@@ -158,6 +173,12 @@ def __init__(
158173
self.affine = self.meta[MetaKeys.AFFINE]
159174
else:
160175
self.affine = self.get_default_affine()
176+
# derive spatial_ndim from affine, clamped by tensor shape
177+
if spatial_ndim is not None:
178+
self.spatial_ndim = spatial_ndim
179+
elif self.affine.ndim == 2:
180+
self.spatial_ndim = min(self.affine.shape[-1] - 1, max(self.ndim - 1, 1))
181+
161182
# applied_operations
162183
if applied_operations is not None:
163184
self.applied_operations = applied_operations
@@ -468,14 +489,29 @@ def affine(self) -> torch.Tensor:
468489
@affine.setter
469490
def affine(self, d: NdarrayTensor) -> None:
470491
"""Set the affine."""
471-
self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
492+
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
493+
self.meta[MetaKeys.AFFINE] = a
494+
if a.ndim == 2: # non-batched: sync spatial_ndim
495+
self.spatial_ndim = a.shape[-1] - 1
496+
497+
@property
498+
def spatial_ndim(self) -> int:
499+
"""Get the number of spatial dimensions."""
500+
return getattr(self, "_spatial_ndim", 3)
501+
502+
@spatial_ndim.setter
503+
def spatial_ndim(self, val: int) -> None:
504+
"""Set the number of spatial dimensions."""
505+
if val < 1:
506+
raise ValueError(f"spatial_ndim must be >= 1, got {val}")
507+
self._spatial_ndim = val
472508

473509
@property
474510
def pixdim(self):
475511
"""Get the spacing"""
476512
if self.is_batch:
477-
return [affine_to_spacing(a) for a in self.affine]
478-
return affine_to_spacing(self.affine)
513+
return [affine_to_spacing(a, r=self.spatial_ndim) for a in self.affine]
514+
return affine_to_spacing(self.affine, r=self.spatial_ndim)
479515

480516
def peek_pending_shape(self):
481517
"""
@@ -490,7 +526,7 @@ def peek_pending_shape(self):
490526

491527
def peek_pending_affine(self):
492528
res = self.affine
493-
r = len(res) - 1
529+
r = res.shape[-1] - 1 if res.ndim >= 2 else self.spatial_ndim
494530
if r not in (2, 3):
495531
warnings.warn(f"Only 2d and 3d affine are supported, got {r}d input.")
496532
for p in self.pending_operations:
@@ -503,8 +539,10 @@ def peek_pending_affine(self):
503539
return res
504540

505541
def peek_pending_rank(self):
506-
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
507-
return 1 if a is None else int(max(1, len(a) - 1))
542+
if self.pending_operations:
543+
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
544+
return 1 if a is None else int(max(1, len(a) - 1))
545+
return self.spatial_ndim
508546

509547
def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
510548
"""

monai/data/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
432432
collated.meta = default_collate(meta_dicts)
433433
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
434434
collated.is_batch = True
435+
collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3) # assumes uniform spatial_ndim
435436
return collated
436437

437438

monai/transforms/croppad/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
from monai.config.type_definitions import NdarrayTensor
2424
from monai.data.meta_obj import get_track_meta
25-
from monai.data.meta_tensor import MetaTensor
25+
from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
2626
from monai.data.utils import to_affine_nd
2727
from monai.transforms.inverse import TraceableTransform
2828
from monai.transforms.utils import convert_pad_mode, create_translate
@@ -132,7 +132,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
132132
mode: the padding mode.
133133
kwargs: other arguments for the `np.pad` or `torch.pad` function.
134134
"""
135-
ndim = len(img.shape) - 1
135+
ndim = get_spatial_ndim(img)
136136
matrix_np = np.round(to_affine_nd(ndim, convert_to_numpy(translation_mat, wrap_sequence=True).copy()))
137137
matrix_np = to_affine_nd(len(spatial_size), matrix_np)
138138
cc = np.asarray(np.meshgrid(*[[0.5, x - 0.5] for x in spatial_size], indexing="ij"))

monai/transforms/intensity/array.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from monai.config import DtypeLike
2727
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
2828
from monai.data.meta_obj import get_track_meta
29+
from monai.data.meta_tensor import get_spatial_ndim
2930
from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap
3031
from monai.data.utils import get_random_patch, get_valid_patch_size
3132
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
@@ -1580,7 +1581,7 @@ def __init__(self, radius: Sequence[int] | int = 1) -> None:
15801581
def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
15811582
img = convert_to_tensor(img, track_meta=get_track_meta())
15821583
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float)
1583-
spatial_dims = img_t.ndim - 1
1584+
spatial_dims = get_spatial_ndim(img)
15841585
r = ensure_tuple_rep(self.radius, spatial_dims)
15851586
median_filter_instance = MedianFilter(r, spatial_dims=spatial_dims)
15861587
out_t: torch.Tensor = median_filter_instance(img_t)
@@ -1616,7 +1617,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
16161617
sigma = [torch.as_tensor(s, device=img_t.device) for s in self.sigma]
16171618
else:
16181619
sigma = torch.as_tensor(self.sigma, device=img_t.device)
1619-
gaussian_filter = GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx)
1620+
gaussian_filter = GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx)
16201621
out_t: torch.Tensor = gaussian_filter(img_t.unsqueeze(0)).squeeze(0)
16211622
out, *_ = convert_to_dst_type(out_t, dst=img, dtype=out_t.dtype)
16221623

@@ -1673,7 +1674,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
16731674
if not self._do_transform:
16741675
return img
16751676

1676-
sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=img.ndim - 1)
1677+
sigma = ensure_tuple_size(vals=(self.x, self.y, self.z), dim=get_spatial_ndim(img))
16771678
return GaussianSmooth(sigma=sigma, approx=self.approx)(img)
16781679

16791680

@@ -1723,7 +1724,7 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
17231724
img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32)
17241725

17251726
gf1, gf2 = (
1726-
GaussianFilter(img_t.ndim - 1, sigma, approx=self.approx).to(img_t.device)
1727+
GaussianFilter(get_spatial_ndim(img), sigma, approx=self.approx).to(img_t.device)
17271728
for sigma in (self.sigma1, self.sigma2)
17281729
)
17291730
blurred_f = gf1(img_t.unsqueeze(0))
@@ -1811,8 +1812,9 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
18111812

18121813
if self.x2 is None or self.y2 is None or self.z2 is None or self.a is None:
18131814
raise RuntimeError("please call the `randomize()` function first.")
1814-
sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=img.ndim - 1)
1815-
sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=img.ndim - 1)
1815+
_sp = get_spatial_ndim(img)
1816+
sigma1 = ensure_tuple_size(vals=(self.x1, self.y1, self.z1), dim=_sp)
1817+
sigma2 = ensure_tuple_size(vals=(self.x2, self.y2, self.z2), dim=_sp)
18161818
return GaussianSharpen(sigma1=sigma1, sigma2=sigma2, alpha=self.a, approx=self.approx)(img)
18171819

18181820

monai/transforms/inverse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def track_transform_meta(
213213
orig_affine = data_t.peek_pending_affine()
214214
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
215215
try:
216-
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
216+
affine = orig_affine @ to_affine_nd(orig_affine.shape[-1] - 1, affine, dtype=torch.float64)
217217
except RuntimeError as e:
218218
if orig_affine.ndim > 2:
219219
if data_t.is_batch:

monai/transforms/lazy/functional.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -256,9 +256,11 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
256256
if not pending:
257257
return data, []
258258

259+
_rank = data.spatial_ndim if isinstance(data, MetaTensor) else 3
260+
259261
cumulative_xform = affine_from_pending(pending[0])
260-
if cumulative_xform.shape[0] == 3:
261-
cumulative_xform = to_affine_nd(3, cumulative_xform)
262+
if cumulative_xform.shape[0] < _rank + 1:
263+
cumulative_xform = to_affine_nd(_rank, cumulative_xform)
262264

263265
cur_kwargs = kwargs_from_pending(pending[0])
264266
override_kwargs: dict[str, Any] = {}
@@ -283,8 +285,8 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
283285
data = resample(data.to(device), cumulative_xform, _cur_kwargs)
284286

285287
next_matrix = affine_from_pending(p)
286-
if next_matrix.shape[0] == 3:
287-
next_matrix = to_affine_nd(3, next_matrix)
288+
if next_matrix.shape[0] < _rank + 1:
289+
next_matrix = to_affine_nd(_rank, next_matrix)
288290

289291
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
290292
cur_kwargs.update(new_kwargs)

monai/transforms/post/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from monai.config.type_definitions import NdarrayOrTensor
2525
from monai.data.meta_obj import get_track_meta
26-
from monai.data.meta_tensor import MetaTensor
26+
from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
2727
from monai.networks import one_hot
2828
from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
2929
from monai.transforms.inverse import InvertibleTransform
@@ -624,7 +624,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
624624
"""
625625
img = convert_to_tensor(img, track_meta=get_track_meta())
626626
img_: torch.Tensor = convert_to_tensor(img, track_meta=False)
627-
spatial_dims = len(img_.shape) - 1
627+
spatial_dims = get_spatial_ndim(img)
628628
img_ = img_.unsqueeze(0) # adds a batch dim
629629
if spatial_dims == 2:
630630
kernel = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], dtype=torch.float32)
@@ -1104,7 +1104,7 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
11041104
image_tensor = convert_to_tensor(image, track_meta=get_track_meta())
11051105

11061106
# Check/set spatial axes
1107-
n_spatial_dims = image_tensor.ndim - 1 # excluding the channel dimension
1107+
n_spatial_dims = get_spatial_ndim(image_tensor)
11081108
valid_spatial_axes = list(range(n_spatial_dims)) + list(range(-n_spatial_dims, 0))
11091109

11101110
# Check gradient axes to be valid

monai/transforms/spatial/array.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from monai.config.type_definitions import NdarrayOrTensor
2828
from monai.data.box_utils import BoxMode, StandardMode
2929
from monai.data.meta_obj import get_track_meta, set_track_meta
30-
from monai.data.meta_tensor import MetaTensor
30+
from monai.data.meta_tensor import MetaTensor, get_spatial_ndim
3131
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
3232
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
3333
from monai.networks.utils import meshgrid_ij
@@ -848,12 +848,14 @@ def __call__(
848848
anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing
849849
anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma
850850

851-
input_ndim = img.ndim - 1 # spatial ndim
851+
input_ndim = get_spatial_ndim(img)
852852
if self.size_mode == "all":
853853
output_ndim = len(ensure_tuple(self.spatial_size))
854854
if output_ndim > input_ndim:
855855
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
856856
img = img.reshape(input_shape)
857+
if isinstance(img, MetaTensor):
858+
img.spatial_ndim = output_ndim
857859
elif output_ndim < input_ndim:
858860
raise ValueError(
859861
"len(spatial_size) must be greater or equal to img spatial dimensions, "
@@ -1034,7 +1036,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
10341036
out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0]
10351037
if isinstance(out, MetaTensor):
10361038
affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False)
1037-
mat = to_affine_nd(len(affine) - 1, transform_t)
1039+
mat = to_affine_nd(out.spatial_ndim, transform_t)
10381040
out.affine @= convert_to_dst_type(mat, affine)[0]
10391041
return out
10401042

@@ -1131,7 +1133,7 @@ def __call__(
11311133
during initialization for this call. Defaults to None.
11321134
"""
11331135
img = convert_to_tensor(img, track_meta=get_track_meta())
1134-
_zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim
1136+
_zoom = ensure_tuple_rep(self.zoom, get_spatial_ndim(img))
11351137
_mode = self.mode if mode is None else mode
11361138
_padding_mode = padding_mode or self.padding_mode
11371139
_align_corners = self.align_corners if align_corners is None else align_corners
@@ -1519,7 +1521,7 @@ def randomize(self, data: NdarrayOrTensor) -> None:
15191521
super().randomize(None)
15201522
if not self._do_transform:
15211523
return None
1522-
self._axis = self.R.randint(data.ndim - 1)
1524+
self._axis = self.R.randint(get_spatial_ndim(data))
15231525

15241526
def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor:
15251527
"""
@@ -1629,13 +1631,14 @@ def randomize(self, img: NdarrayOrTensor) -> None:
16291631
super().randomize(None)
16301632
if not self._do_transform:
16311633
return None
1634+
_sp = get_spatial_ndim(img)
16321635
self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)]
16331636
if len(self._zoom) == 1:
16341637
# to keep the spatial shape ratio, use same random zoom factor for all dims
1635-
self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1)
1636-
elif len(self._zoom) == 2 and img.ndim > 3:
1638+
self._zoom = ensure_tuple_rep(self._zoom[0], _sp)
1639+
elif len(self._zoom) == 2 and _sp > 2:
16371640
# if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim
1638-
self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1])
1641+
self._zoom = ensure_tuple_rep(self._zoom[0], _sp - 1) + ensure_tuple(self._zoom[-1])
16391642

16401643
def __call__(
16411644
self,
@@ -2350,7 +2353,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
23502353
out.meta = data.meta # type: ignore
23512354
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
23522355
xform, *_ = convert_to_dst_type(
2353-
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
2356+
Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine
23542357
)
23552358
out.affine @= xform
23562359
return out
@@ -2619,7 +2622,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
26192622
out.meta = data.meta # type: ignore
26202623
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
26212624
xform, *_ = convert_to_dst_type(
2622-
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
2625+
Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine
26232626
)
26242627
out.affine @= xform
26252628
return out
@@ -3032,7 +3035,7 @@ def __call__(
30323035
raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`")
30333036

30343037
all_ranges = []
3035-
num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1)
3038+
num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img))
30363039
if isinstance(img, MetaTensor) and img.pending_operations:
30373040
warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.")
30383041
for dim_idx, dim_size in enumerate(img.shape[1:]):

0 commit comments

Comments (0)