@@ -70,81 +70,25 @@ def _has_nvrtc_pch_apis_for_tests():
7070)
7171
7272
73- _libnvvm_version = None
74- _libnvvm_version_attempted = False
75-
76- precheck_nvvm_ir = """target triple = "nvptx64-unknown-cuda"
77- target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
78-
79- define void @dummy_kernel() {{
80- entry:
81- ret void
82- }}
83-
84- !nvvm.annotations = !{{!0}}
85- !0 = !{{void ()* @dummy_kernel, !"kernel", i32 1}}
86-
87- !nvvmir.version = !{{!1}}
88- !1 = !{{i32 {major}, i32 {minor}, i32 {debug_major}, i32 {debug_minor}}}
89- """
90-
91-
92- def _get_libnvvm_version_for_tests ():
93- """
94- Detect libNVVM version by compiling dummy IR and analyzing the PTX output.
95-
96- Workaround for the lack of direct libNVVM version API (nvbugs 5312315).
97- The approach:
98- - Compile a small dummy NVVM IR to PTX
99- - Use PTX version analysis APIs if available to infer libNVVM version
100- - Cache the result for future use
101- """
102- global _libnvvm_version , _libnvvm_version_attempted
73+ def _has_check_nvvm_compiler_options ():
74+ try :
75+ import cuda .bindings .utils as utils
76+ except ModuleNotFoundError :
77+ return False
78+ return hasattr (utils , "check_nvvm_compiler_options" )
10379
104- if _libnvvm_version_attempted :
105- return _libnvvm_version
10680
107- _libnvvm_version_attempted = True
81+ has_nvvm_option_checker = pytest .mark .skipif (
82+ not _has_check_nvvm_compiler_options (),
83+ reason = "cuda.bindings.utils.check_nvvm_compiler_options not available (cuda-bindings too old?)" ,
84+ )
10885
109- try :
110- from cuda .core ._program import _get_nvvm_module
11186
112- nvvm = _get_nvvm_module ()
113-
114- try :
115- from cuda .bindings .utils import get_minimal_required_cuda_ver_from_ptx_ver , get_ptx_ver
116- except ImportError :
117- _libnvvm_version = None
118- return _libnvvm_version
119-
120- program = nvvm .create_program ()
121- try :
122- major , minor , debug_major , debug_minor = nvvm .ir_version ()
123- global precheck_nvvm_ir
124- precheck_nvvm_ir = precheck_nvvm_ir .format (
125- major = major , minor = minor , debug_major = debug_major , debug_minor = debug_minor
126- )
127- precheck_ir_bytes = precheck_nvvm_ir .encode ("utf-8" )
128- nvvm .add_module_to_program (program , precheck_ir_bytes , len (precheck_ir_bytes ), "precheck.ll" )
129-
130- options = ["-arch=compute_90" ]
131- nvvm .verify_program (program , len (options ), options )
132- nvvm .compile_program (program , len (options ), options )
133-
134- ptx_size = nvvm .get_compiled_result_size (program )
135- ptx_data = bytearray (ptx_size )
136- nvvm .get_compiled_result (program , ptx_data )
137- ptx_str = ptx_data .decode ("utf-8" )
138- ptx_version = get_ptx_ver (ptx_str )
139- cuda_version = get_minimal_required_cuda_ver_from_ptx_ver (ptx_version )
140- _libnvvm_version = cuda_version
141- return _libnvvm_version
142- finally :
143- nvvm .destroy_program (program )
87+ def _check_nvvm_arch (arch : str ) -> bool :
88+ """Check if the given NVVM arch is supported by the installed libNVVM."""
89+ from cuda .bindings .utils import check_nvvm_compiler_options
14490
145- except Exception :
146- _libnvvm_version = None
147- return _libnvvm_version
91+ return check_nvvm_compiler_options ([f"-arch={ arch } " ])
14892
14993
15094@pytest .fixture (scope = "session" )
@@ -524,10 +468,13 @@ def test_nvvm_compile_invalid_ir():
524468 ),
525469 pytest .param (
526470 ProgramOptions (name = "test_sm110_1" , arch = "sm_110" , device_code_optimize = False ),
527- marks = pytest .mark .skipif (
528- (_get_libnvvm_version_for_tests () or 0 ) < 13000 ,
529- reason = "Compute capability 110 requires libNVVM >= 13.0" ,
530- ),
471+ marks = [
472+ has_nvvm_option_checker ,
473+ pytest .mark .skipif (
474+ _has_check_nvvm_compiler_options () and not _check_nvvm_arch ("compute_110" ),
475+ reason = "Compute capability 110 not supported by installed libNVVM" ,
476+ ),
477+ ],
531478 ),
532479 pytest .param (
533480 ProgramOptions (
@@ -539,17 +486,23 @@ def test_nvvm_compile_invalid_ir():
539486 fma = True ,
540487 device_code_optimize = True ,
541488 ),
542- marks = pytest .mark .skipif (
543- (_get_libnvvm_version_for_tests () or 0 ) < 13000 ,
544- reason = "Compute capability 110 requires libNVVM >= 13.0" ,
545- ),
489+ marks = [
490+ has_nvvm_option_checker ,
491+ pytest .mark .skipif (
492+ _has_check_nvvm_compiler_options () and not _check_nvvm_arch ("compute_110" ),
493+ reason = "Compute capability 110 not supported by installed libNVVM" ,
494+ ),
495+ ],
546496 ),
547497 pytest .param (
548498 ProgramOptions (name = "test_sm110_3" , arch = "sm_110" , link_time_optimization = True ),
549- marks = pytest .mark .skipif (
550- (_get_libnvvm_version_for_tests () or 0 ) < 13000 ,
551- reason = "Compute capability 110 requires libNVVM >= 13.0" ,
552- ),
499+ marks = [
500+ has_nvvm_option_checker ,
501+ pytest .mark .skipif (
502+ _has_check_nvvm_compiler_options () and not _check_nvvm_arch ("compute_110" ),
503+ reason = "Compute capability 110 not supported by installed libNVVM" ,
504+ ),
505+ ],
553506 ),
554507 ],
555508)
@@ -729,12 +682,8 @@ def test_program_options_as_bytes_nvrtc():
729682 """Test ProgramOptions.as_bytes() for NVRTC backend"""
730683 options = ProgramOptions (arch = "sm_80" , debug = True , lineinfo = True , ftz = True )
731684 nvrtc_options = options .as_bytes ("nvrtc" )
732-
733- # Should return list of bytes
734685 assert isinstance (nvrtc_options , list )
735686 assert all (isinstance (opt , bytes ) for opt in nvrtc_options )
736-
737- # Decode to check content
738687 options_str = [opt .decode () for opt in nvrtc_options ]
739688 assert "-arch=sm_80" in options_str
740689 assert "--device-debug" in options_str
@@ -747,12 +696,8 @@ def test_program_options_as_bytes_nvvm():
747696 """Test ProgramOptions.as_bytes() for NVVM backend"""
748697 options = ProgramOptions (arch = "sm_80" , debug = True , ftz = True , device_code_optimize = True )
749698 nvvm_options = options .as_bytes ("nvvm" )
750-
751- # Should return list of bytes (same as other backends)
752699 assert isinstance (nvvm_options , list )
753700 assert all (isinstance (opt , bytes ) for opt in nvvm_options )
754-
755- # Decode to check content
756701 options_str = [opt .decode () for opt in nvvm_options ]
757702 assert "-arch=compute_80" in options_str
758703 assert "-g" in options_str
0 commit comments