Skip to content

Commit 90e4873

Browse files
committed
Fix spatial_ndim drift for sliced MetaTensor 2D paths
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 7fbde8a commit 90e4873

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

monai/data/meta_tensor.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,25 @@
3131
__all__ = ["MetaTensor", "get_spatial_ndim"]
3232

3333

34+
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int:
35+
"""Clamp spatial dims to a valid range for the current tensor shape."""
36+
return max(1, min(int(spatial_ndim), max(int(tensor_ndim) - 1, 1)))
37+
38+
3439
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
3540
"""Return the number of spatial dimensions assuming channel-first layout.
3641
3742
Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
3843
``img.ndim - 1``.
3944
"""
4045
if isinstance(img, MetaTensor):
41-
return img.spatial_ndim
46+
inferred = _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
47+
shape_spatial = max(img.ndim - 1, 1)
48+
# For non-batched tensors, preserve explicit higher-rank shape information
49+
# (e.g., invalid 4D spatial inputs should still be reported as rank 4).
50+
if not img.is_batch and shape_spatial > inferred:
51+
return shape_spatial
52+
return inferred
4253
return img.ndim - 1
4354

4455

@@ -175,9 +186,9 @@ def __init__(
175186
self.affine = self.get_default_affine()
176187
# derive spatial_ndim from affine, clamped by tensor shape
177188
if spatial_ndim is not None:
178-
self.spatial_ndim = spatial_ndim
189+
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim)
179190
elif self.affine.ndim == 2:
180-
self.spatial_ndim = min(self.affine.shape[-1] - 1, max(self.ndim - 1, 1))
191+
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim)
181192

182193
# applied_operations
183194
if applied_operations is not None:
@@ -243,6 +254,8 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
243254
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
244255
if is_batch:
245256
ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs)
257+
if func == torch.Tensor.__getitem__:
258+
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim)
246259
out.append(ret)
247260
# if the input was a tuple, then return it as a tuple
248261
return tuple(out) if isinstance(rets, tuple) else out
@@ -492,7 +505,7 @@ def affine(self, d: NdarrayTensor) -> None:
492505
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
493506
self.meta[MetaKeys.AFFINE] = a
494507
if a.ndim == 2: # non-batched: sync spatial_ndim
495-
self.spatial_ndim = a.shape[-1] - 1
508+
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim)
496509

497510
@property
498511
def spatial_ndim(self) -> int:

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
432432
collated.meta = default_collate(meta_dicts)
433433
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
434434
collated.is_batch = True
435-
collated.spatial_ndim = getattr(batch[0], "spatial_ndim", 3) # assumes uniform spatial_ndim
435+
collated.spatial_ndim = min(getattr(batch[0], "spatial_ndim", 3), max(collated.ndim - 1, 1))
436436
return collated
437437

438438

tests/data/meta_tensor/test_spatial_ndim.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import torch
2020
from parameterized import parameterized
2121

22-
from monai.data import MetaTensor
22+
from monai.data import MetaTensor, get_spatial_ndim
2323
from monai.data.utils import collate_meta_tensor_fn, decollate_batch
24-
from monai.transforms import Affine, RandAffine, Resize, Rotate, SqueezeDim
24+
from monai.transforms import Affine, LabelToContour, RandAffine, RandZoom, Resize, Rotate, SqueezeDim
2525
from monai.transforms.utility.array import SplitDim
2626
from monai.utils import optional_import
2727

@@ -121,6 +121,30 @@ def test_lazy_apply_pending_2d(self):
121121
self.assertIsInstance(result, MetaTensor)
122122
self.assertEqual(len(applied), 1)
123123

124+
def test_batch_slice_clamps_spatial_ndim(self):
125+
t = MetaTensor(torch.randn(10, 6, 5, 7), affine=torch.eye(4))
126+
self.assertEqual(t.spatial_ndim, 3)
127+
sliced = t[0]
128+
self.assertEqual(sliced.shape, (6, 5, 7))
129+
self.assertEqual(sliced.spatial_ndim, 2)
130+
self.assertEqual(get_spatial_ndim(sliced), 2)
131+
132+
def test_label_to_contour_batch_slice_2d(self):
133+
t = MetaTensor(torch.randint(0, 2, (10, 6, 5, 7)).float(), affine=torch.eye(4))
134+
sliced = t[0]
135+
out = LabelToContour()(sliced)
136+
self.assertEqual(out.shape, sliced.shape)
137+
138+
def test_rand_zoom_batch_slice_2d(self):
139+
t = MetaTensor(torch.randn(10, 1, 64, 64), affine=torch.eye(4))
140+
sliced = t[0]
141+
zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=1.2)
142+
zoom.set_random_state(seed=0)
143+
zoom.randomize(sliced)
144+
self.assertEqual(len(zoom._zoom), 2)
145+
out = zoom(sliced)
146+
self.assertEqual(out.ndim, sliced.ndim)
147+
124148
@skipUnless(has_einops, "Requires einops")
125149
def test_einops_rearrange_then_resize(self):
126150
"""Reproduce the exact #6397 bug: einops.rearrange -> Resize."""

0 commit comments

Comments (0)