Skip to content

Commit 2b5b367

Browse files
committed
fix(warp): fix mode type mismatch in Warp fallback path on Blackwell GPUs
When USE_COMPILED=True, Warp.__init__ stores integer interpolation/padding modes (e.g. 1, 0, 7) for grid_pull. However, when a Blackwell GPU triggers the runtime fallback in forward(), F.grid_sample is called with those integers instead of the required strings ("bilinear", "zeros", etc.), causing a crash. Fix by always storing _interp_mode_native and _padding_mode_native as string attributes in __init__, and using them exclusively in the F.grid_sample call. Also clean up test_spatial_gpu_support.py: - Remove duplicate test_non_cuda_device_always_supported (identical to cpu test) - Implement test_resample_compilation_flag_respected using unittest.mock so it runs without a physical Blackwell device - Remove duplicate if __name__ == "__main__" block https://claude.ai/code/session_015SGxtVTnVKUHk7hWHLSRZK
1 parent 057ff4d commit 2b5b367

File tree

2 files changed

+34
-25
lines changed

2 files changed

+34
-25
lines changed

monai/networks/blocks/warp.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
6464
super().__init__()
6565
# resolves _interp_mode for different methods
6666

67+
# Native string modes are always stored for the PyTorch fallback path.
68+
# When USE_COMPILED=True but the device is unsupported at runtime (e.g. Blackwell GPU),
69+
# forward() falls back to F.grid_sample which requires string, not integer, modes.
70+
self._interp_mode_native = (
71+
GridSampleMode(mode).value if mode in (m.value for m in GridSampleMode) else GridSampleMode.BILINEAR.value
72+
)
73+
self._padding_mode_native = (
74+
GridSamplePadMode(padding_mode).value
75+
if padding_mode in (p.value for p in GridSamplePadMode)
76+
else GridSamplePadMode.BORDER.value
77+
)
78+
6779
if USE_COMPILED:
6880
if mode in (inter.value for inter in GridSampleMode):
6981
mode = GridSampleMode(mode)
@@ -78,7 +90,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
7890
self._interp_mode = mode
7991
else:
8092
warnings.warn("monai.networks.blocks.Warp: Using PyTorch native grid_sample.")
81-
self._interp_mode = GridSampleMode(mode).value
93+
self._interp_mode = self._interp_mode_native
8294

8395
# resolves _padding_mode for different methods
8496
if USE_COMPILED:
@@ -94,7 +106,7 @@ def __init__(self, mode=GridSampleMode.BILINEAR.value, padding_mode=GridSamplePa
94106
padding_mode = 0 # default to nearest
95107
self._padding_mode = padding_mode
96108
else:
97-
self._padding_mode = GridSamplePadMode(padding_mode).value
109+
self._padding_mode = self._padding_mode_native
98110

99111
self.ref_grid = None
100112
self.jitter = jitter
@@ -147,7 +159,7 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
147159
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
148160
grid = grid[..., index_ordering] # z, y, x -> x, y, z
149161
return F.grid_sample(
150-
image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
162+
image, grid, mode=self._interp_mode_native, padding_mode=self._padding_mode_native, align_corners=True
151163
)
152164

153165
# using csrc resampling

tests/transforms/test_spatial_gpu_support.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import annotations
1515

1616
import unittest
17+
from unittest.mock import MagicMock, patch
1718

1819
import torch
1920

@@ -28,23 +29,16 @@ def test_cpu_device_always_supported(self):
2829
device = torch.device("cpu")
2930
self.assertFalse(_compiled_unsupported(device))
3031

31-
def test_non_cuda_device_always_supported(self):
32-
"""Non-CUDA devices should always be supported."""
33-
device = torch.device("cpu")
34-
self.assertFalse(_compiled_unsupported(device))
35-
3632
@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
3733
def test_cuda_device_detection(self):
3834
"""Verify CUDA compute capability detection."""
39-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
40-
if device.type == "cuda":
41-
cc_major = torch.cuda.get_device_properties(device).major
42-
unsupported = _compiled_unsupported(device)
43-
# Device is unsupported if cc_major >= 12
44-
if cc_major >= 12:
45-
self.assertTrue(unsupported)
46-
else:
47-
self.assertFalse(unsupported)
35+
device = torch.device("cuda:0")
36+
cc_major = torch.cuda.get_device_properties(device).major
37+
unsupported = _compiled_unsupported(device)
38+
if cc_major >= 12:
39+
self.assertTrue(unsupported)
40+
else:
41+
self.assertFalse(unsupported)
4842

4943
def test_compiled_unsupported_return_type(self):
5044
"""Verify return type is bool."""
@@ -56,19 +50,24 @@ def test_compiled_unsupported_return_type(self):
5650
class TestResampleFallback(unittest.TestCase):
5751
"""Test Resample fallback behavior on unsupported devices."""
5852

59-
@unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
6053
def test_resample_compilation_flag_respected(self):
61-
"""Verify Resample respects _compiled_unsupported check."""
62-
# This would require internal inspection or output verification
63-
# Could test with mock device properties or actual Blackwell GPU
54+
"""Verify _compiled_unsupported identifies Blackwell (cc>=12) and supported (cc<12) devices."""
55+
mock_props = MagicMock()
56+
cuda_device = torch.device("cuda:0")
57+
58+
mock_props.major = 12 # Blackwell – unsupported
59+
with patch("torch.cuda.get_device_properties", return_value=mock_props):
60+
self.assertTrue(_compiled_unsupported(cuda_device))
61+
62+
mock_props.major = 9 # Hopper – supported
63+
with patch("torch.cuda.get_device_properties", return_value=mock_props):
64+
self.assertFalse(_compiled_unsupported(cuda_device))
6465

6566
def test_compiled_unsupported_logic(self):
6667
"""Test that unsupported devices are correctly detected."""
67-
# CPU should be supported
6868
cpu_device = torch.device("cpu")
6969
self.assertFalse(_compiled_unsupported(cpu_device))
7070

71-
# Verify logic: return True if CUDA and cc_major >= 12
7271
cuda_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
7372
if cuda_device.type == "cuda":
7473
cc_major = torch.cuda.get_device_properties(cuda_device).major
@@ -79,5 +78,3 @@ def test_compiled_unsupported_logic(self):
7978

8079
if __name__ == "__main__":
8180
unittest.main()
82-
if __name__ == "__main__":
83-
unittest.main()

0 commit comments

Comments (0)