Skip to content

Commit 7fbde8a

Browse files
committed
Fix 2D inverse transform shape mismatch (4x4 vs 3x3)
Use affine matrix rank instead of spatial_ndim in Rotate, Affine, and RandAffine inverse methods to avoid RuntimeError when spatial_ndim=2 but the stored affine is 4x4. Add regression tests for all three. Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent d30a73e commit 7fbde8a

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

monai/transforms/spatial/array.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,7 +1036,7 @@ def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
10361036
out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0]
10371037
if isinstance(out, MetaTensor):
10381038
affine = convert_to_tensor(out.peek_pending_affine(), track_meta=False)
1039-
mat = to_affine_nd(out.spatial_ndim, transform_t)
1039+
mat = to_affine_nd(len(affine) - 1, transform_t)
10401040
out.affine @= convert_to_dst_type(mat, affine)[0]
10411041
return out
10421042

@@ -2353,7 +2353,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
23532353
out.meta = data.meta # type: ignore
23542354
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
23552355
xform, *_ = convert_to_dst_type(
2356-
Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine
2356+
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
23572357
)
23582358
out.affine @= xform
23592359
return out
@@ -2622,7 +2622,7 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
26222622
out.meta = data.meta # type: ignore
26232623
affine = convert_data_type(out.peek_pending_affine(), torch.Tensor)[0]
26242624
xform, *_ = convert_to_dst_type(
2625-
Affine.compute_w_affine(out.spatial_ndim, inv_affine, data.shape[1:], orig_size), affine
2625+
Affine.compute_w_affine(len(affine) - 1, inv_affine, data.shape[1:], orig_size), affine
26262626
)
26272627
out.affine @= xform
26282628
return out

tests/data/meta_tensor/test_spatial_ndim.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from copy import deepcopy
1616
from unittest import skipUnless
1717

18+
import numpy as np
1819
import torch
1920
from parameterized import parameterized
2021

2122
from monai.data import MetaTensor
2223
from monai.data.utils import collate_meta_tensor_fn, decollate_batch
23-
from monai.transforms import Resize, SqueezeDim
24+
from monai.transforms import Affine, RandAffine, Resize, Rotate, SqueezeDim
2425
from monai.transforms.utility.array import SplitDim
2526
from monai.utils import optional_import
2627

@@ -134,6 +135,37 @@ def test_einops_rearrange_then_resize(self):
134135
out = Resize(spatial_size=(32, 32, 3), mode="trilinear", align_corners=True)(x_)
135136
self.assertEqual(out.shape[-3:], (32, 32, 3))
136137

138+
def test_affine_inverse_2d_metatensor(self):
139+
"""Affine.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
140+
img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
141+
self.assertEqual(img.spatial_ndim, 2)
142+
xform = Affine(rotate_params=(np.pi / 6,), padding_mode="zeros", image_only=True)
143+
result = xform(img)
144+
inv = xform.inverse(result)
145+
self.assertEqual(inv.shape, img.shape)
146+
self.assertEqual(len(inv.applied_operations), 0)
147+
148+
def test_rotate_inverse_2d_metatensor(self):
149+
"""Rotate.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
150+
img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
151+
self.assertEqual(img.spatial_ndim, 2)
152+
xform = Rotate(angle=(np.pi / 4,), padding_mode="zeros")
153+
result = xform(img)
154+
inv = xform.inverse(result)
155+
self.assertEqual(inv.shape, img.shape)
156+
self.assertEqual(len(inv.applied_operations), 0)
157+
158+
def test_rand_affine_inverse_2d_metatensor(self):
159+
"""RandAffine.inverse on 2D data: 4x4 affine with spatial_ndim=2."""
160+
img = MetaTensor(torch.randn(1, 32, 32), affine=torch.eye(4))
161+
self.assertEqual(img.spatial_ndim, 2)
162+
xform = RandAffine(prob=1.0, rotate_range=(np.pi / 6,), padding_mode="zeros")
163+
xform.set_random_state(seed=42)
164+
result = xform(img)
165+
inv = xform.inverse(result)
166+
self.assertEqual(inv.shape, img.shape)
167+
self.assertEqual(len(inv.applied_operations), 0)
168+
137169

138170
if __name__ == "__main__":
139171
unittest.main()

0 commit comments

Comments
 (0)