|
23 | 23 | import monai |
24 | 24 | from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor |
25 | 25 | from monai.data.meta_obj import _DEFAULT_SPATIAL_NDIM, MetaObj, get_track_meta |
26 | | -from monai.data.utils import affine_to_spacing, decollate_batch, is_no_channel, list_data_collate, remove_extra_metadata |
| 26 | +from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata |
27 | 27 | from monai.utils import look_up_option |
28 | 28 | from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys |
29 | 29 | from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_numpy, convert_to_tensor |
30 | 30 |
|
31 | 31 | __all__ = ["MetaTensor", "get_spatial_ndim"] |
32 | 32 |
|
33 | 33 |
|
34 | | -def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int, no_channel: bool = False) -> int: |
| 34 | +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: |
35 | 35 | """Clamp spatial dims to a valid range for the current tensor shape.""" |
36 | | - limit = max(int(tensor_ndim), 1) if no_channel else max(int(tensor_ndim) - 1, 1) |
| 36 | + limit = max(int(tensor_ndim) - 1, 1) |
37 | 37 | return max(1, min(int(spatial_ndim), limit)) |
38 | 38 |
|
39 | 39 |
|
40 | | -def _has_explicit_no_channel(meta: Mapping | None) -> bool: |
41 | | - return ( |
42 | | - isinstance(meta, Mapping) |
43 | | - and MetaKeys.ORIGINAL_CHANNEL_DIM in meta |
44 | | - and is_no_channel(meta[MetaKeys.ORIGINAL_CHANNEL_DIM]) |
45 | | - ) |
46 | | - |
47 | | - |
48 | 40 | def get_spatial_ndim(img: NdarrayOrTensor) -> int: |
49 | 41 | """Return the number of spatial dimensions assuming channel-first layout. |
50 | 42 |
|
51 | 43 | Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to |
52 | 44 | ``img.ndim - 1``. |
53 | 45 | """ |
54 | 46 | if isinstance(img, MetaTensor): |
55 | | - no_channel = _has_explicit_no_channel(img.meta) |
56 | | - return _normalize_spatial_ndim(img.spatial_ndim, img.ndim, no_channel=no_channel) |
| 47 | + return _normalize_spatial_ndim(img.spatial_ndim, img.ndim) |
57 | 48 | return img.ndim - 1 |
58 | 49 |
|
59 | 50 |
|
@@ -201,11 +192,10 @@ def __init__( |
201 | 192 | self.affine = self.get_default_affine() |
202 | 193 | # Initialize spatial_ndim from affine matrix (source of truth), clamped by tensor shape. |
203 | 194 | # This cached value is kept in sync via the affine setter for hot-path performance. |
204 | | - no_channel = _has_explicit_no_channel(self.meta) |
205 | 195 | if spatial_ndim is not None: |
206 | | - self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim, no_channel=no_channel) |
| 196 | + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) |
207 | 197 | elif self.affine.ndim == 2: |
208 | | - self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim, no_channel=no_channel) |
| 198 | + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) |
209 | 199 |
|
210 | 200 | # applied_operations |
211 | 201 | if applied_operations is not None: |
@@ -309,7 +299,7 @@ def _handle_batched(cls, ret, idx, metas, func, args, kwargs): |
309 | 299 | if hasattr(ret_meta, "__dict__"): |
310 | 300 | ret.__dict__ = ret_meta.__dict__.copy() |
311 | 301 | if _is_batch_only_index(full_idx): |
312 | | - ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim, no_channel=False) |
| 302 | + ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) |
313 | 303 | # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. |
314 | 304 | # But we only want to split the batch if the `unbind` is along the 0th dimension. |
315 | 305 | elif func == torch.Tensor.unbind: |
@@ -528,8 +518,7 @@ def affine(self, d: NdarrayTensor) -> None: |
528 | 518 | a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) |
529 | 519 | self.meta[MetaKeys.AFFINE] = a |
530 | 520 | if a.ndim == 2: # non-batched: sync spatial_ndim from affine (source of truth) |
531 | | - no_channel = _has_explicit_no_channel(self.meta) |
532 | | - self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim, no_channel=no_channel) |
| 521 | + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) |
533 | 522 |
|
534 | 523 | @property |
535 | 524 | def spatial_ndim(self) -> int: |
|
0 commit comments