|
31 | 31 | __all__ = ["MetaTensor", "get_spatial_ndim"] |
32 | 32 |
|
33 | 33 |
|
| 34 | +def _normalize_spatial_ndim(spatial_ndim: int, tensor_ndim: int) -> int: |
| 35 | + """Clamp spatial dims to a valid range for the current tensor shape.""" |
| 36 | + return max(1, min(int(spatial_ndim), max(int(tensor_ndim) - 1, 1))) |
| 37 | + |
| 38 | + |
def get_spatial_ndim(img: NdarrayOrTensor) -> int:
    """Return the number of spatial dimensions assuming channel-first layout.

    Uses ``MetaTensor.spatial_ndim`` when available, otherwise falls back to
    ``img.ndim - 1``.
    """
    # Plain arrays/tensors: channel-first means everything after dim 0 is spatial.
    if not isinstance(img, MetaTensor):
        return img.ndim - 1
    clamped = _normalize_spatial_ndim(img.spatial_ndim, img.ndim)
    from_shape = max(img.ndim - 1, 1)
    # For non-batched tensors, preserve explicit higher-rank shape information
    # (e.g., invalid 4D spatial inputs should still be reported as rank 4).
    if not img.is_batch and from_shape > clamped:
        return from_shape
    return clamped
43 | 54 |
|
44 | 55 |
|
@@ -175,9 +186,9 @@ def __init__( |
175 | 186 | self.affine = self.get_default_affine() |
176 | 187 | # derive spatial_ndim from affine, clamped by tensor shape |
177 | 188 | if spatial_ndim is not None: |
178 | | - self.spatial_ndim = spatial_ndim |
| 189 | + self.spatial_ndim = _normalize_spatial_ndim(spatial_ndim, self.ndim) |
179 | 190 | elif self.affine.ndim == 2: |
180 | | - self.spatial_ndim = min(self.affine.shape[-1] - 1, max(self.ndim - 1, 1)) |
| 191 | + self.spatial_ndim = _normalize_spatial_ndim(self.affine.shape[-1] - 1, self.ndim) |
181 | 192 |
|
182 | 193 | # applied_operations |
183 | 194 | if applied_operations is not None: |
@@ -243,6 +254,8 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: |
243 | 254 | # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") |
244 | 255 | if is_batch: |
245 | 256 | ret = MetaTensor._handle_batched(ret, idx, metas, func, args, kwargs) |
| 257 | + if func == torch.Tensor.__getitem__: |
| 258 | + ret.spatial_ndim = _normalize_spatial_ndim(ret.spatial_ndim, ret.ndim) |
246 | 259 | out.append(ret) |
247 | 260 | # if the input was a tuple, then return it as a tuple |
248 | 261 | return tuple(out) if isinstance(rets, tuple) else out |
@@ -492,7 +505,7 @@ def affine(self, d: NdarrayTensor) -> None: |
492 | 505 | a = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64) |
493 | 506 | self.meta[MetaKeys.AFFINE] = a |
494 | 507 | if a.ndim == 2: # non-batched: sync spatial_ndim |
495 | | - self.spatial_ndim = a.shape[-1] - 1 |
| 508 | + self.spatial_ndim = _normalize_spatial_ndim(a.shape[-1] - 1, self.ndim) |
496 | 509 |
|
497 | 510 | @property |
498 | 511 | def spatial_ndim(self) -> int: |
|
0 commit comments