Skip to content

Commit ab2be3a

Browse files
committed
Fix get_spatial_ndim for 2D post-EnsureChannelFirst without regressing 3D
The previous commit removed no_channel from all call sites, which broke 3D no-channel tensors (spatial_ndim clamped to ndim-1=2 instead of 3). Targeted fix: only remove no_channel from get_spatial_ndim, since it runs after EnsureChannelFirst has already added a channel dim. The constructor and affine setter keep no_channel to correctly handle pre-channel tensors. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent dc16dda commit ab2be3a

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

monai/data/meta_tensor.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,34 @@
2323
import monai
2424
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
2525
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
26-
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
26+
from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
2727
from monai.utils import look_up_option
2828
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
2929
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
3030

3131
__all__ = ["MetaTensor", "get_spatial_ndim"]
3232

3333

34-
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int:
34+
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
3535
"""Clamp spatial dims to a valid range for the current tensor shape."""
36-
limit = max(int(tensor_ndim) - 1, 1)
36+
limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
3737
return max(1, min(int(spatial_ndim), limit))
3838

3939

40+
def _has_explicit_no_channel(meta: Mapping | None) -> bool:
41+
return (
42+
isinstance(meta, Mapping)
43+
and MetaKeys.ORIGINAL_CHANNEL_DIM in meta
44+
and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
45+
)
46+
47+
4048
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
4149
"""Return the number of spatial dimensions assuming channel-first layout.
4250
4351
Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
44-
``img.ndim - 1``.
52+
``img.ndim - 1``. Always assumes channel-first (``no_channel=False``)
53+
because callers run after ``EnsureChannelFirst`` has already added one.
4554
"""
4655
if isinstance(img, MetaTensor):
4756
return _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
@@ -192,10 +201,11 @@ def __init__(
192201
self.affine = self.get_default_affine()
193202
# Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
194203
# This cached value is kept in sync via the affine setter for hot-path performance.
204+
no_channel = _has_explicit_no_channel(self.meta)
195205
if spatial_ndim is not None:
196-
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim)
206+
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
197207
elif self.affine.ndim == 2:
198-
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim)
208+
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)
199209

200210
# applied_operations
201211
if applied_operations is not None:
@@ -518,7 +528,8 @@ def affine(self, d: NdarrayTensor) -> None:
518528
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
519529
self.meta[MetaKeys.AFFINE] = a
520530
if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth)
521-
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim)
531+
no_channel = _has_explicit_no_channel(self.meta)
532+
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)
522533

523534
@property
524535
def spatial_ndim(self) -> int:

0 commit comments

Comments
 (0)