Skip to content

Commit 545184b

Browse files
committed
address coderabbit
1 parent 29cf6ac commit 545184b

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

monai/transforms/spatial/functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,10 @@ def spatial_resample(
9999
src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)
100100
img = convert_to_tensor(data=img, track_meta=get_track_meta())
101101
# ensure spatial rank is <= 3
102-
spatial_rank = min(get_spatial_ndim(img), 3)
102+
max_rank = max(int(img.ndim) - 1, 1)
103+
spatial_rank = min(get_spatial_ndim(img), max_rank, 3)
103104
if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:
104-
spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size
105+
spatial_rank = min(len(ensure_tuple(spatial_size)), max_rank, 3) # infer spatial rank based on spatial_size
105106
src_affine = to_affine_nd(spatial_rank, src_affine).to(torch.float64)
106107
dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine
107108
dst_affine = convert_to_dst_type(dst_affine, src_affine)[0]

tests/data/meta_tensor/test_spatial_ndim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
PRESERVATION_CASES = [
3838
("reshape", lambda t: t.reshape(1, 100), 2),
3939
("unsqueeze", lambda t: t.unsqueeze(0), 2),
40-
("squeeze", lambda t: MetaTensor(torch.randn(1, 1, 10, 10), affine=torch.eye(3)).squeeze(1), 2),
40+
("squeeze", lambda t: t.unsqueeze(1).squeeze(1), 2),
4141
("clone", lambda t: t.clone(), 2),
4242
("deepcopy", lambda t: deepcopy(t), 2),
4343
]

tests/transforms/utility/test_splitdim.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ def test_spatial_ndim_decremented(self):
5656
self.assertEqual(arr.spatial_ndim, 3)
5757
out = SplitDim(dim=1, keepdim=False)(arr)
5858
for item in out:
59-
if isinstance(item, MetaTensor):
60-
self.assertEqual(item.spatial_ndim, 2)
59+
self.assertIsInstance(item, MetaTensor)
60+
self.assertEqual(item.spatial_ndim, 2)
6161

6262
def test_spatial_ndim_negative_dim(self):
6363
"""spatial_ndim decremented for keepdim=False with negative dim."""
@@ -67,8 +67,8 @@ def test_spatial_ndim_negative_dim(self):
6767
self.assertEqual(arr.spatial_ndim, 3)
6868
out = SplitDim(dim=-1, keepdim=False)(arr)
6969
for item in out:
70-
if isinstance(item, MetaTensor):
71-
self.assertEqual(item.spatial_ndim, 2)
70+
self.assertIsInstance(item, MetaTensor)
71+
self.assertEqual(item.spatial_ndim, 2)
7272

7373
def test_spatial_ndim_channel_dim_no_decrement(self):
7474
"""spatial_ndim not decremented for keepdim=False on channel dim (dim=0)."""
@@ -78,8 +78,8 @@ def test_spatial_ndim_channel_dim_no_decrement(self):
7878
self.assertEqual(arr.spatial_ndim, 2)
7979
out = SplitDim(dim=0, keepdim=False)(arr)
8080
for item in out:
81-
if isinstance(item, MetaTensor):
82-
self.assertEqual(item.spatial_ndim, 2)
81+
self.assertIsInstance(item, MetaTensor)
82+
self.assertEqual(item.spatial_ndim, 2)
8383

8484

8585
if __name__ == "__main__":

0 commit comments

Comments
 (0)