Skip to content

Commit dc16dda

Browse files
committed
Fix 2D inverse transform failures by removing no_channel from spatial_ndim normalization
After EnsureChannelFirst adds a channel dim, ORIGINAL_CHANNEL_DIM="no_channel" refers to the original file, not the current tensor. The no_channel flag caused _normalize_spatial_ndim to treat all dims as spatial (returning 3 instead of 2), breaking Resized_2d, Resized_longest_2d, and Zoomd_2d inverse transforms. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 9f359b0 commit dc16dda

File tree

1 file changed

+8
-19
lines changed

1 file changed

+8
-19
lines changed

monai/data/meta_tensor.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,37 +23,28 @@
2323
import monai
2424
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
2525
from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta
26-
from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata
26+
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2727
from monai.utils import look_up_option
2828
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
2929
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor
3030

3131
__all__ = ["MetaTensor", "get_spatial_ndim"]
3232

3333

34-
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int:
34+
def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int:
3535
"""Clamp spatial dims to a valid range for the current tensor shape."""
36-
limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1)
36+
limit = max(int(tensor_ndim) - 1, 1)
3737
return max(1, min(int(spatial_ndim), limit))
3838

3939

40-
def _has_explicit_no_channel(meta: Mapping | None) -> bool:
41-
return (
42-
isinstance(meta, Mapping)
43-
and MetaKeys.ORIGINAL_CHANNEL_DIM in meta
44-
and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM])
45-
)
46-
47-
4840
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
4941
"""Return the number of spatial dimensions assuming channel-first layout.
5042
5143
Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
5244
``img.ndim - 1``.
5345
"""
5446
if isinstance(img, MetaTensor):
55-
no_channel = _has_explicit_no_channel(img.meta)
56-
return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel)
47+
return _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
5748
return img.ndim - 1
5849

5950

@@ -201,11 +192,10 @@ def __init__(
201192
self.affine = self.get_default_affine()
202193
# Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape.
203194
# This cached value is kept in sync via the affine setter for hot-path performance.
204-
no_channel = _has_explicit_no_channel(self.meta)
205195
if spatial_ndim is not None:
206-
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel)
196+
self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim)
207197
elif self.affine.ndim == 2:
208-
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel)
198+
self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim)
209199

210200
# applied_operations
211201
if applied_operations is not None:
@@ -309,7 +299,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs):
309299
if hasattr(ret_meta, "__dict__"):
310300
ret.__dict__ = ret_meta.__dict__.copy()
311301
if _is_batch_only_index(full_idx):
312-
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False)
302+
ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim)
313303
# `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`.
314304
# But we only want to split the batch if the `unbind` is along the 0th dimension.
315305
elif func == torch.Tensor.unbind:
@@ -528,8 +518,7 @@ def affine(self, d: NdarrayTensor) -> None:
528518
a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)
529519
self.meta[MetaKeys.AFFINE] = a
530520
if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth)
531-
no_channel = _has_explicit_no_channel(self.meta)
532-
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel)
521+
self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim)
533522

534523
@property
535524
def spatial_ndim(self) -> int:

0 commit comments

Comments
 (0)