|
23 | 23 | import monai |
24 | 24 | from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor |
25 | 25 | from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta |
26 | | -from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata |
| 26 | +from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata |
27 | 27 | from monai.utils import look_up_option |
28 | 28 | from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys |
29 | 29 | from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor |
30 | 30 |
|
31 | 31 | __all__ = ["MetaTensor", "get_spatial_ndim"] |
32 | 32 |
|
33 | 33 |
|
34 | | -def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: |
| 34 | +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: |
35 | 35 | """Clamp spatial dims to a valid range for the current tensor shape.""" |
36 | | - limit = max(int(tensor_ndim) - 1, 1) |
| 36 | + limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) |
37 | 37 | return max(1, min(int(spatial_ndim), limit)) |
38 | 38 |
|
39 | 39 |
|
| 40 | +def _has_explicit_no_channel(meta: Mapping | None) -> bool: |
| 41 | + return ( |
| 42 | + isinstance(meta, Mapping) |
| 43 | + and MetaKeys.ORIGINAL_CHANNEL_DIM in meta |
| 44 | + and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) |
| 45 | + ) |
| 46 | + |
| 47 | + |
40 | 48 | def get_spatial_ndim(img: NdarrayOrTensor) -> int: |
41 | 49 | """Return the number of spatial dimensions assuming channel-first layout. |
42 | 50 |
|
43 | 51 | Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to |
44 | | - ``img.ndim - 1``. |
| 52 | + ``img.ndim - 1``. Always assumes channel-first (``no_channel=False``) |
| 53 | + because callers run after ``EnsureChannelFirst`` has already added one. |
45 | 54 | """ |
46 | 55 | if isinstance(img, MetaTensor): |
47 | 56 | return _normalize_spatial_ndim(img.spatial_ndim, img.ndim) |
@@ -192,10 +201,11 @@ def __init__( |
192 | 201 | self.affine = self.get_default_affine() |
193 | 202 | # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. |
194 | 203 | # This cached value is kept in sync via the affine setter for hot-path performance. |
| 204 | + no_channel = _has_explicit_no_channel(self.meta) |
195 | 205 | if spatial_ndim is not None: |
196 | | - self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) |
| 206 | + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) |
197 | 207 | elif self.affine.ndim == 2: |
198 | | - self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) |
| 208 | + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) |
199 | 209 |
|
200 | 210 | # applied_operations |
201 | 211 | if applied_operations is not None: |
@@ -518,7 +528,8 @@ def affine(self, d: NdarrayTensor) -> None: |
518 | 528 | a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) |
519 | 529 | self.meta[MetaKeys.AFFINE] = a |
520 | 530 | if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) |
521 | | - self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) |
| 531 | + no_channel = _has_explicit_no_channel(self.meta) |
| 532 | + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) |
522 | 533 |
|
523 | 534 | @property |
524 | 535 | def spatial_ndim(self) -> int: |
|
0 commit comments