1414from __future__ import annotations
1515
1616import unittest
17+ from unittest .mock import MagicMock , patch
1718
1819import torch
1920
@@ -28,23 +29,16 @@ def test_cpu_device_always_supported(self):
2829 device = torch .device ("cpu" )
2930 self .assertFalse (_compiled_unsupported (device ))
3031
31- def test_non_cuda_device_always_supported (self ):
32- """Non-CUDA devices should always be supported."""
33- device = torch .device ("cpu" )
34- self .assertFalse (_compiled_unsupported (device ))
35-
3632 @unittest .skipIf (not torch .cuda .is_available (), reason = "CUDA not available" )
3733 def test_cuda_device_detection (self ):
3834 """Verify CUDA compute capability detection."""
39- device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
40- if device .type == "cuda" :
41- cc_major = torch .cuda .get_device_properties (device ).major
42- unsupported = _compiled_unsupported (device )
43- # Device is unsupported if cc_major >= 12
44- if cc_major >= 12 :
45- self .assertTrue (unsupported )
46- else :
47- self .assertFalse (unsupported )
35+ device = torch .device ("cuda:0" )
36+ cc_major = torch .cuda .get_device_properties (device ).major
37+ unsupported = _compiled_unsupported (device )
38+ if cc_major >= 12 :
39+ self .assertTrue (unsupported )
40+ else :
41+ self .assertFalse (unsupported )
4842
4943 def test_compiled_unsupported_return_type (self ):
5044 """Verify return type is bool."""
@@ -56,19 +50,24 @@ def test_compiled_unsupported_return_type(self):
5650class TestResampleFallback (unittest .TestCase ):
5751 """Test Resample fallback behavior on unsupported devices."""
5852
59- @unittest .skipIf (not torch .cuda .is_available (), reason = "CUDA not available" )
6053 def test_resample_compilation_flag_respected (self ):
61- """Verify Resample respects _compiled_unsupported check."""
62- # This would require internal inspection or output verification
63- # Could test with mock device properties or actual Blackwell GPU
54+ """Verify _compiled_unsupported identifies Blackwell (cc>=12) and supported (cc<12) devices."""
55+ mock_props = MagicMock ()
56+ cuda_device = torch .device ("cuda:0" )
57+
58+ mock_props .major = 12 # Blackwell – unsupported
59+ with patch ("torch.cuda.get_device_properties" , return_value = mock_props ):
60+ self .assertTrue (_compiled_unsupported (cuda_device ))
61+
62+ mock_props .major = 9 # Hopper – supported
63+ with patch ("torch.cuda.get_device_properties" , return_value = mock_props ):
64+ self .assertFalse (_compiled_unsupported (cuda_device ))
6465
6566 def test_compiled_unsupported_logic (self ):
6667 """Test that unsupported devices are correctly detected."""
67- # CPU should be supported
6868 cpu_device = torch .device ("cpu" )
6969 self .assertFalse (_compiled_unsupported (cpu_device ))
7070
71- # Verify logic: return True if CUDA and cc_major >= 12
7271 cuda_device = torch .device ("cuda:0" if torch .cuda .is_available () else "cpu" )
7372 if cuda_device .type == "cuda" :
7473 cc_major = torch .cuda .get_device_properties (cuda_device ).major
@@ -79,5 +78,3 @@ def test_compiled_unsupported_logic(self):
7978
8079if __name__ == "__main__" :
8180 unittest .main ()
82- if __name__ == "__main__" :
83- unittest .main ()
0 commit comments