@@ -39,6 +39,7 @@ from cuda.core._utils.cuda_utils import (
3939 driver,
4040 is_sequence,
4141)
42+ from cuda.core._utils.version import driver_version
4243
4344ctypedef const char * const_char_ptr
4445ctypedef void * void_ptr
@@ -181,9 +182,20 @@ cdef class Linker:
181182class LinkerOptions:
182183 """Customizable options for configuring :class:`Linker`.
183184
184- Since the linker may choose to use nvJitLink or the driver APIs as the linking backend ,
185- not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
186- or not installed , the driver APIs (cuLink ) will be used instead.
185+ Since the linker may choose either nvJitLink or the driver's ``cuLink*``
186+ APIs as the backend , not every option is applicable to both backends. The
187+ backend is decided per-:class:`Linker` instance from the installed CUDA
188+ driver major version , nvJitLink's availability and major version , the input
189+ code types , and whether link-time optimization is requested:
190+
191+ - nvJitLink is used when its major version matches the driver's.
192+ - The driver linker is used when nvJitLink is unavailable or too old
193+ (<12.3), or when its major version differs from the driver's (and no LTO
194+ step is required ).
195+ - Linking LTO IRs , or requesting ``link_time_optimization`` / ``ptx``, with
196+ nvJitLink unavailable or with mismatched nvJitLink and driver majors is
197+ unsupported and raises :class:`RuntimeError` at :class:`Linker`
198+ construction time.
187199
188200 Attributes
189201 ----------
@@ -348,39 +360,39 @@ class LinkerOptions:
348360 formatted_options.extend((bytearray(size ), size , bytearray(size ), size ))
349361 option_keys.extend(
350362 (
351- _driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER ,
352- _driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
353- _driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER ,
354- _driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
363+ driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER ,
364+ driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
365+ driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER ,
366+ driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
355367 )
356368 )
357369
358370 if self.arch is not None:
359371 arch = self .arch.split(" _" )[- 1 ].upper()
360- formatted_options.append(getattr(_driver .CUjit_target , f"CU_TARGET_COMPUTE_{arch}"))
361- option_keys.append(_driver .CUjit_option.CU_JIT_TARGET )
372+ formatted_options.append(getattr(driver .CUjit_target , f"CU_TARGET_COMPUTE_{arch}"))
373+ option_keys.append(driver .CUjit_option.CU_JIT_TARGET )
362374 if self.max_register_count is not None:
363375 formatted_options.append(self.max_register_count )
364- option_keys.append(_driver .CUjit_option.CU_JIT_MAX_REGISTERS )
376+ option_keys.append(driver .CUjit_option.CU_JIT_MAX_REGISTERS )
365377 if self.time is not None:
366378 raise ValueError("time option is not supported by the driver API")
367379 if self.verbose:
368380 formatted_options.append(1)
369- option_keys.append(_driver .CUjit_option.CU_JIT_LOG_VERBOSE )
381+ option_keys.append(driver .CUjit_option.CU_JIT_LOG_VERBOSE )
370382 if self.link_time_optimization:
371383 formatted_options.append(1)
372- option_keys.append(_driver .CUjit_option.CU_JIT_LTO )
384+ option_keys.append(driver .CUjit_option.CU_JIT_LTO )
373385 if self.ptx:
374386 raise ValueError("ptx option is not supported by the driver API")
375387 if self.optimization_level is not None:
376388 formatted_options.append(self.optimization_level )
377- option_keys.append(_driver .CUjit_option.CU_JIT_OPTIMIZATION_LEVEL )
389+ option_keys.append(driver .CUjit_option.CU_JIT_OPTIMIZATION_LEVEL )
378390 if self.debug:
379391 formatted_options.append(1)
380- option_keys.append(_driver .CUjit_option.CU_JIT_GENERATE_DEBUG_INFO )
392+ option_keys.append(driver .CUjit_option.CU_JIT_GENERATE_DEBUG_INFO )
381393 if self.lineinfo:
382394 formatted_options.append(1)
383- option_keys.append(_driver .CUjit_option.CU_JIT_GENERATE_LINE_INFO )
395+ option_keys.append(driver .CUjit_option.CU_JIT_GENERATE_LINE_INFO )
384396 if self.ftz is not None:
385397 warn("ftz option is deprecated in the driver API", DeprecationWarning , stacklevel = 3 )
386398 if self.prec_div is not None:
@@ -402,8 +414,8 @@ class LinkerOptions:
402414 if self.split_compile_extended is not None:
403415 raise ValueError("split_compile_extended option is not supported by the driver API")
404416 if self.no_cache is True:
405- formatted_options.append(_driver .CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE )
406- option_keys.append(_driver .CUjit_option.CU_JIT_CACHE_MODE )
417+ formatted_options.append(driver .CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE )
418+ option_keys.append(driver .CUjit_option.CU_JIT_CACHE_MODE )
407419
408420 return formatted_options , option_keys
409421
@@ -430,7 +442,7 @@ class LinkerOptions:
430442 backend = backend.lower()
431443 if backend != "nvjitlink":
432444 raise ValueError(f"as_bytes() only supports 'nvjitlink' backend , got '{backend}'")
433- if not _use_nvjitlink_backend :
445+ if _probe_nvjitlink() is None :
434446 raise RuntimeError("nvJitLink backend is not available")
435447 return self._prepare_nvjitlink_options(as_bytes = True )
436448
@@ -453,7 +465,32 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
453465
454466 self._options = options = check_or_create_options(LinkerOptions, options, " Linker options" )
455467
456- if _use_nvjitlink_backend:
468+ # Validate inputs up front so an invalid object (e.g. something that merely
469+ # exposes a ``code_type`` attribute ) gets the ObjectCode TypeError from
470+ # assert_type rather than a backend-dispatch RuntimeError.
471+ for code in object_codes:
472+ assert_type(code , ObjectCode )
473+
474+ # Decide the backend per-instance based on the current environment and this
475+ # Linker's inputs. See _choose_backend() for the full decision matrix.
476+ inputs_have_ltoir = any (code.code_type == " ltoir" for code in object_codes)
477+ lto_requested = bool (options.link_time_optimization) or bool (options.ptx)
478+ nvjitlink_version = _probe_nvjitlink()
479+ # Probe driver version lazily: only needed when comparing majors.
480+ # In environments where nvJitLink is installed but the driver is
481+ # absent (e.g., build containers ), cuDriverGetVersion raises CUDAError
482+ # via handle_return; treat that as "driver unknown" and fall through to
483+ # _choose_backend. Any other exception type is a real bug and should
484+ # propagate rather than silently flip the backend choice.
485+ try:
486+ driver_major = driver_version()[0 ]
487+ except CUDAError:
488+ driver_major = None
489+ backend = _choose_backend(
490+ driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested
491+ )
492+
493+ if backend == "nvjitlink":
457494 self._use_nvjitlink = True
458495 options_bytes = options._prepare_nvjitlink_options(as_bytes = True )
459496 c_num_opts = len (options_bytes)
@@ -490,7 +527,6 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
490527 self ._culink_handle = create_culink_handle(c_raw_culink)
491528
492529 for code in object_codes:
493- assert_type(code, ObjectCode)
494530 Linker_add_code_object(self , code)
495531 return 0
496532
@@ -618,9 +654,10 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
618654# =============================================================================
619655
620656# TODO: revisit this treatment for py313t builds
621- _driver = None # populated if nvJitLink cannot be used
622657_inited = False
623- _use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
658+ _nvjitlink_probed = False
659+ _nvjitlink_version = None # (major, minor) if usable; None if unavailable/too old
660+ _nvjitlink_missing_warned = False
624661
625662# Input type mappings populated by _lazy_init() with C-level enum ints.
626663_nvjitlink_input_types = None
@@ -632,12 +669,15 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
632669 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
633670
634671
635- # Note: this function is reused in the tests
636- def _decide_nvjitlink_or_driver() -> bool:
637- """Return True if falling back to the cuLink* driver APIs."""
638- global _driver , _use_nvjitlink_backend
639- if _use_nvjitlink_backend is not None:
640- return not _use_nvjitlink_backend
672+ def _probe_nvjitlink() -> tuple | None:
673+ """Return ``(major , minor )`` if nvJitLink is available and >= 12.3, else ``None``.
674+
675+ Emits a ``RuntimeWarning`` at most once when nvJitLink is unavailable or too
676+ old. The result is cached for subsequent calls.
677+ """
678+ global _nvjitlink_probed , _nvjitlink_version , _nvjitlink_missing_warned
679+ if _nvjitlink_probed:
680+ return _nvjitlink_version
641681
642682 warn_txt_common = (
643683 " the driver APIs will be used instead, which do not support"
@@ -649,46 +689,111 @@ def _decide_nvjitlink_or_driver() -> bool:
649689 " cuda.bindings.nvjitlink" ,
650690 probe_function = lambda module : module.version(), # probe triggers nvJitLink runtime load
651691 )
692+ warn_txt = None
652693 if nvjitlink_module is None:
653- warn_txt = f" cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
654- else:
655- from cuda.bindings._internal import nvjitlink
656-
657- if _nvjitlink_has_version_symbol(nvjitlink ):
658- _use_nvjitlink_backend = True
659- return False # Use nvjitlink
660694 warn_txt = (
661- f" {'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
662- f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
695+ f" cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
663696 )
697+ else:
698+ from cuda.bindings._internal import nvjitlink as inner_nvjitlink
699+
700+ if _nvjitlink_has_version_symbol(inner_nvjitlink ):
701+ _nvjitlink_version = tuple (nvjitlink_module.version())
702+ else :
703+ warn_txt = (
704+ f" {'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
705+ f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
706+ )
707+
708+ if warn_txt is not None and not _nvjitlink_missing_warned:
709+ warn(warn_txt, stacklevel = 2 , category = RuntimeWarning )
710+ _nvjitlink_missing_warned = True
711+ _nvjitlink_probed = True
712+ return _nvjitlink_version
664713
665- warn(warn_txt, stacklevel = 2 , category = RuntimeWarning )
666- _use_nvjitlink_backend = False
667- _driver = driver
668- return True
714+
715+ def _choose_backend (
716+ driver_major: int | None ,
717+ nvjitlink_version: tuple | None ,
718+ inputs_have_ltoir: bool ,
719+ lto_requested: bool ,
720+ ) -> str:
721+ """Choose the linker backend for a specific Linker invocation.
722+
723+ Parameters
724+ ----------
725+ driver_major : int or None
726+ Major version of the installed CUDA driver (from ``cuDriverGetVersion``).
727+ ``None`` when the driver cannot be queried (e.g., no driver installed ).
728+ nvjitlink_version : tuple[int , int] or None
729+ ``(major , minor )`` if nvJitLink is available and >=12.3; ``None`` otherwise.
730+ inputs_have_ltoir : bool
731+ ``True`` if any input ``ObjectCode`` has ``code_type == "ltoir"``.
732+ lto_requested : bool
733+ ``True`` if ``LinkerOptions.link_time_optimization`` or ``ptx`` is set
734+ (both force the use of nvJitLink; the driver linker cannot emit PTX and
735+ cannot do link-time optimization on LTO IR ).
736+
737+ Returns
738+ -------
739+ str
740+ ``"nvjitlink"`` or ``"driver"``.
741+
742+ Raises
743+ ------
744+ RuntimeError
745+ If the request cannot be satisfied by any backend , for example when
746+ LTO IR inputs or ``link_time_optimization`` are requested but nvJitLink
747+ is unavailable , or when driver and nvJitLink have mismatched major
748+ versions for an LTO link.
749+ """
750+ needs_nvjitlink = inputs_have_ltoir or lto_requested
751+
752+ if nvjitlink_version is None:
753+ if needs_nvjitlink:
754+ raise RuntimeError(
755+ "LTO IR input or link-time optimization was requested , but "
756+ "nvJitLink is not available (driver linker cannot perform LTO ). "
757+ "Install cuda-bindings with a compatible nvJitLink (>=12.3)."
758+ )
759+ return "driver"
760+
761+ nvjitlink_major = nvjitlink_version[0 ]
762+ # If driver version is unknown , optimistically use nvJitLink
763+ # (common in build containers with nvJitLink but no driver ).
764+ if driver_major is None or nvjitlink_major == driver_major:
765+ return "nvjitlink"
766+
767+ if needs_nvjitlink:
768+ raise RuntimeError(
769+ f"Cannot link with nvJitLink {nvjitlink_major}.x against CUDA driver "
770+ f"{driver_major}.x: LTO IR or link-time optimization requires matching "
771+ f"major versions , and the driver linker cannot perform LTO. "
772+ f"Install an nvJitLink matching the driver major version."
773+ )
774+ # Driver and nvJitLink have different major versions. nvJitLink output may
775+ # target an architecture or format that the driver cannot load , so fall back
776+ # to the driver's own linker for non-LTO linking.
777+ return "driver"
669778
670779
671780def _lazy_init():
672781 global _inited, _nvjitlink_input_types, _driver_input_types
673782 if _inited:
674783 return
675-
676- _decide_nvjitlink_or_driver()
677- if _use_nvjitlink_backend:
678- _nvjitlink_input_types = {
679- " ptx" : < int > cynvjitlink.NVJITLINK_INPUT_PTX,
680- " cubin" : < int > cynvjitlink.NVJITLINK_INPUT_CUBIN,
681- " fatbin" : < int > cynvjitlink.NVJITLINK_INPUT_FATBIN,
682- " ltoir" : < int > cynvjitlink.NVJITLINK_INPUT_LTOIR,
683- " object" : < int > cynvjitlink.NVJITLINK_INPUT_OBJECT,
684- " library" : < int > cynvjitlink.NVJITLINK_INPUT_LIBRARY,
685- }
686- else :
687- _driver_input_types = {
688- " ptx" : < int > cydriver.CU_JIT_INPUT_PTX,
689- " cubin" : < int > cydriver.CU_JIT_INPUT_CUBIN,
690- " fatbin" : < int > cydriver.CU_JIT_INPUT_FATBINARY,
691- " object" : < int > cydriver.CU_JIT_INPUT_OBJECT,
692- " library" : < int > cydriver.CU_JIT_INPUT_LIBRARY,
693- }
784+ _nvjitlink_input_types = {
785+ " ptx" : < int > cynvjitlink.NVJITLINK_INPUT_PTX,
786+ " cubin" : < int > cynvjitlink.NVJITLINK_INPUT_CUBIN,
787+ " fatbin" : < int > cynvjitlink.NVJITLINK_INPUT_FATBIN,
788+ " ltoir" : < int > cynvjitlink.NVJITLINK_INPUT_LTOIR,
789+ " object" : < int > cynvjitlink.NVJITLINK_INPUT_OBJECT,
790+ " library" : < int > cynvjitlink.NVJITLINK_INPUT_LIBRARY,
791+ }
792+ _driver_input_types = {
793+ " ptx" : < int > cydriver.CU_JIT_INPUT_PTX,
794+ " cubin" : < int > cydriver.CU_JIT_INPUT_CUBIN,
795+ " fatbin" : < int > cydriver.CU_JIT_INPUT_FATBINARY,
796+ " object" : < int > cydriver.CU_JIT_INPUT_OBJECT,
797+ " library" : < int > cydriver.CU_JIT_INPUT_LIBRARY,
798+ }
694799 _inited = True
0 commit comments