|
27 | 27 | from monai.config.type_definitions import NdarrayOrTensor |
28 | 28 | from monai.data.box_utils import BoxMode, StandardMode |
29 | 29 | from monai.data.meta_obj import get_track_meta, set_track_meta |
30 | | -from monai.data.meta_tensor import MetaTensor |
| 30 | +from monai.data.meta_tensor import MetaTensor, get_spatial_ndim |
31 | 31 | from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine |
32 | 32 | from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull |
33 | 33 | from monai.networks.utils import meshgrid_ij |
@@ -848,12 +848,14 @@ def __call__( |
848 | 848 | anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing |
849 | 849 | anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma |
850 | 850 |
|
851 | | - input_ndim = img.ndim - 1 # spatial ndim |
| 851 | + input_ndim = get_spatial_ndim(img) |
852 | 852 | if self.size_mode == "all": |
853 | 853 | output_ndim = len(ensure_tuple(self.spatial_size)) |
854 | 854 | if output_ndim > input_ndim: |
855 | 855 | input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) |
856 | 856 | img = img.reshape(input_shape) |
| 857 | + if isinstance(img, MetaTensor): |
| 858 | + img.spatial_ndim = output_ndim |
857 | 859 | elif output_ndim < input_ndim: |
858 | 860 | raise ValueError( |
859 | 861 | "len(spatial_size) must be greater or equal to img spatial dimensions, " |
@@ -1034,7 +1036,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: |
1034 | 1036 | out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] |
1035 | 1037 | if isinstance(out, MetaTensor): |
1036 | 1038 | affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False) |
1037 | | - mat = to_affine_nd(len(affine) - 1, transform_t) |
| 1039 | + mat = to_affine_nd(out.spatial_ndim, transform_t) |
1038 | 1040 | out.affine @= convert_to_dst_type(mat, affine)[0] |
1039 | 1041 | return out |
1040 | 1042 |
|
@@ -1131,7 +1133,7 @@ def __call__( |
1131 | 1133 | during initialization for this call. Defaults to None. |
1132 | 1134 | """ |
1133 | 1135 | img = convert_to_tensor(img, track_meta=get_track_meta()) |
1134 | | - _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim |
| 1136 | + _zoom = ensure_tuple_rep(self.zoom, get_spatial_ndim(img)) |
1135 | 1137 | _mode = self.mode if mode is None else mode |
1136 | 1138 | _padding_mode = padding_mode or self.padding_mode |
1137 | 1139 | _align_corners = self.align_corners if align_corners is None else align_corners |
@@ -1519,7 +1521,7 @@ def randomize(self, data: NdarrayOrTensor) -> None: |
1519 | 1521 | super().randomize(None) |
1520 | 1522 | if not self._do_transform: |
1521 | 1523 | return None |
1522 | | - self._axis = self.R.randint(data.ndim - 1) |
| 1524 | + self._axis = self.R.randint(get_spatial_ndim(data)) |
1523 | 1525 |
|
1524 | 1526 | def __call__(self, img: torch.Tensor, randomize: bool = True, lazy: bool | None = None) -> torch.Tensor: |
1525 | 1527 | """ |
@@ -1629,13 +1631,14 @@ def randomize(self, img: NdarrayOrTensor) -> None: |
1629 | 1631 | super().randomize(None) |
1630 | 1632 | if not self._do_transform: |
1631 | 1633 | return None |
| 1634 | + _sp = get_spatial_ndim(img) |
1632 | 1635 | self._zoom = [self.R.uniform(l, h) for l, h in zip(self.min_zoom, self.max_zoom)] |
1633 | 1636 | if len(self._zoom) == 1: |
1634 | 1637 | # to keep the spatial shape ratio, use same random zoom factor for all dims |
1635 | | - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 1) |
1636 | | - elif len(self._zoom) == 2 and img.ndim > 3: |
| 1638 | + self._zoom = ensure_tuple_rep(self._zoom[0], _sp) |
| 1639 | + elif len(self._zoom) == 2 and _sp > 2: |
1637 | 1640 | # if 2 zoom factors provided for 3D data, use the first factor for H and W dims, second factor for D dim |
1638 | | - self._zoom = ensure_tuple_rep(self._zoom[0], img.ndim - 2) + ensure_tuple(self._zoom[-1]) |
| 1641 | + self._zoom = ensure_tuple_rep(self._zoom[0], _sp - 1) + ensure_tuple(self._zoom[-1]) |
1639 | 1642 |
|
1640 | 1643 | def __call__( |
1641 | 1644 | self, |
@@ -2350,7 +2353,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: |
2350 | 2353 | out.meta = data.meta # type: ignore |
2351 | 2354 | affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] |
2352 | 2355 | xform, *_ = convert_to_dst_type( |
2353 | | - Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine |
| 2356 | + Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine |
2354 | 2357 | ) |
2355 | 2358 | out.affine @= xform |
2356 | 2359 | return out |
@@ -2619,7 +2622,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: |
2619 | 2622 | out.meta = data.meta # type: ignore |
2620 | 2623 | affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0] |
2621 | 2624 | xform, *_ = convert_to_dst_type( |
2622 | | - Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine |
| 2625 | + Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine |
2623 | 2626 | ) |
2624 | 2627 | out.affine @= xform |
2625 | 2628 | return out |
@@ -3032,7 +3035,7 @@ def __call__( |
3032 | 3035 | raise ValueError("the spatial size of `img` does not match with the length of `distort_steps`") |
3033 | 3036 |
|
3034 | 3037 | all_ranges = [] |
3035 | | - num_cells = ensure_tuple_rep(self.num_cells, len(img.shape) - 1) |
| 3038 | + num_cells = ensure_tuple_rep(self.num_cells, get_spatial_ndim(img)) |
3036 | 3039 | if isinstance(img, MetaTensor) and img.pending_operations: |
3037 | 3040 | warnings.warn("MetaTensor img has pending operations, transform may return incorrect results.") |
3038 | 3041 | for dim_idx, dim_size in enumerate(img.shape[1:]): |
|
0 commit comments