@@ -39,6 +39,7 @@ from cuda.core._utils.cuda_utils import (
3939 driver,
4040 is_sequence,
4141)
42+ from cuda.core._utils.version import driver_version
4243
4344ctypedef const char * const_char_ptr
4445ctypedef void * void_ptr
@@ -181,9 +182,20 @@ cdef class Linker:
181182class LinkerOptions:
182183 """Customizable options for configuring :class:`Linker`.
183184
184- Since the linker may choose to use nvJitLink or the driver APIs as the linking backend ,
185- not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
186- or not installed , the driver APIs (cuLink ) will be used instead.
185+ Since the linker may choose either nvJitLink or the driver's ``cuLink*``
186+ APIs as the backend , not every option is applicable to both backends. The
187+ backend is decided per-:class:`Linker` instance from the installed CUDA
188+ driver major version , nvJitLink's availability and major version , the input
189+ code types , and whether link-time optimization is requested:
190+
191+ - nvJitLink is used when its major version matches the driver's.
192+ - The driver linker is used when nvJitLink is unavailable or too old
193+ (<12.3), or when its major version differs from the driver's (and no LTO
194+ step is required ).
195+ - Linking LTO IRs , or requesting ``link_time_optimization`` / ``ptx``, with
196+ nvJitLink unavailable or with mismatched nvJitLink and driver majors is
197+ unsupported and raises :class:`RuntimeError` at :class:`Linker`
198+ construction time.
187199
188200 Attributes
189201 ----------
@@ -348,39 +360,39 @@ class LinkerOptions:
348360 formatted_options.extend((bytearray(size ), size , bytearray(size ), size ))
349361 option_keys.extend(
350362 (
351- _driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER ,
352- _driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
353- _driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER ,
354- _driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
363+ driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER ,
364+ driver .CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES ,
365+ driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER ,
366+ driver .CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES ,
355367 )
356368 )
357369
358370 if self.arch is not None:
359371 arch = self .arch.split(" _" )[- 1 ].upper()
360- formatted_options.append(getattr(_driver .CUjit_target , f"CU_TARGET_COMPUTE_{arch}"))
361- option_keys.append(_driver .CUjit_option.CU_JIT_TARGET )
372+ formatted_options.append(getattr(driver .CUjit_target , f"CU_TARGET_COMPUTE_{arch}"))
373+ option_keys.append(driver .CUjit_option.CU_JIT_TARGET )
362374 if self.max_register_count is not None:
363375 formatted_options.append(self.max_register_count )
364- option_keys.append(_driver .CUjit_option.CU_JIT_MAX_REGISTERS )
376+ option_keys.append(driver .CUjit_option.CU_JIT_MAX_REGISTERS )
365377 if self.time is not None:
366378 raise ValueError("time option is not supported by the driver API")
367379 if self.verbose:
368380 formatted_options.append(1)
369- option_keys.append(_driver .CUjit_option.CU_JIT_LOG_VERBOSE )
381+ option_keys.append(driver .CUjit_option.CU_JIT_LOG_VERBOSE )
370382 if self.link_time_optimization:
371383 formatted_options.append(1)
372- option_keys.append(_driver .CUjit_option.CU_JIT_LTO )
384+ option_keys.append(driver .CUjit_option.CU_JIT_LTO )
373385 if self.ptx:
374386 raise ValueError("ptx option is not supported by the driver API")
375387 if self.optimization_level is not None:
376388 formatted_options.append(self.optimization_level )
377- option_keys.append(_driver .CUjit_option.CU_JIT_OPTIMIZATION_LEVEL )
389+ option_keys.append(driver .CUjit_option.CU_JIT_OPTIMIZATION_LEVEL )
378390 if self.debug:
379391 formatted_options.append(1)
380- option_keys.append(_driver .CUjit_option.CU_JIT_GENERATE_DEBUG_INFO )
392+ option_keys.append(driver .CUjit_option.CU_JIT_GENERATE_DEBUG_INFO )
381393 if self.lineinfo:
382394 formatted_options.append(1)
383- option_keys.append(_driver .CUjit_option.CU_JIT_GENERATE_LINE_INFO )
395+ option_keys.append(driver .CUjit_option.CU_JIT_GENERATE_LINE_INFO )
384396 if self.ftz is not None:
385397 warn("ftz option is deprecated in the driver API", DeprecationWarning , stacklevel = 3 )
386398 if self.prec_div is not None:
@@ -402,8 +414,8 @@ class LinkerOptions:
402414 if self.split_compile_extended is not None:
403415 raise ValueError("split_compile_extended option is not supported by the driver API")
404416 if self.no_cache is True:
405- formatted_options.append(_driver .CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE )
406- option_keys.append(_driver .CUjit_option.CU_JIT_CACHE_MODE )
417+ formatted_options.append(driver .CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE )
418+ option_keys.append(driver .CUjit_option.CU_JIT_CACHE_MODE )
407419
408420 return formatted_options , option_keys
409421
@@ -430,7 +442,7 @@ class LinkerOptions:
430442 backend = backend.lower()
431443 if backend != "nvjitlink":
432444 raise ValueError(f"as_bytes() only supports 'nvjitlink' backend , got '{backend}'")
433- if not _use_nvjitlink_backend :
445+ if _probe_nvjitlink() is None :
434446 raise RuntimeError("nvJitLink backend is not available")
435447 return self._prepare_nvjitlink_options(as_bytes = True )
436448
@@ -453,7 +465,19 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
453465
454466 self._options = options = check_or_create_options(LinkerOptions, options, " Linker options" )
455467
456- if _use_nvjitlink_backend:
468+ # Decide the backend per-instance based on the current environment and this
469+ # Linker's inputs. See _choose_backend() for the full decision matrix.
470+ inputs_have_ltoir = any (
471+ getattr (code, " code_type" , None ) == " ltoir" for code in object_codes
472+ )
473+ lto_requested = bool (options.link_time_optimization) or bool (options.ptx)
474+ nvjitlink_version = _probe_nvjitlink()
475+ driver_major = driver_version()[0 ]
476+ backend = _choose_backend(
477+ driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested
478+ )
479+
480+ if backend == "nvjitlink":
457481 self._use_nvjitlink = True
458482 options_bytes = options._prepare_nvjitlink_options(as_bytes = True )
459483 c_num_opts = len (options_bytes)
@@ -618,9 +642,10 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
618642# =============================================================================
619643
620644# TODO: revisit this treatment for py313t builds
621- _driver = None # populated if nvJitLink cannot be used
622645_inited = False
623- _use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
646+ _nvjitlink_probed = False
647+ _nvjitlink_version = None # (major, minor) if usable; None if unavailable/too old
648+ _nvjitlink_missing_warned = False
624649
625650# Input type mappings populated by _lazy_init() with C-level enum ints.
626651_nvjitlink_input_types = None
@@ -632,12 +657,15 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
632657 return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
633658
634659
635- # Note: this function is reused in the tests
636- def _decide_nvjitlink_or_driver() -> bool:
637- """Return True if falling back to the cuLink* driver APIs."""
638- global _driver , _use_nvjitlink_backend
639- if _use_nvjitlink_backend is not None:
640- return not _use_nvjitlink_backend
660+ def _probe_nvjitlink() -> tuple | None:
661+ """Return ``(major , minor )`` if nvJitLink is available and >= 12.3, else ``None``.
662+
663+ Emits a ``RuntimeWarning`` at most once when nvJitLink is unavailable or too
664+ old. The result is cached for subsequent calls.
665+ """
666+ global _nvjitlink_probed , _nvjitlink_version , _nvjitlink_missing_warned
667+ if _nvjitlink_probed:
668+ return _nvjitlink_version
641669
642670 warn_txt_common = (
643671 " the driver APIs will be used instead, which do not support"
@@ -649,46 +677,108 @@ def _decide_nvjitlink_or_driver() -> bool:
649677 " cuda.bindings.nvjitlink" ,
650678 probe_function = lambda module : module.version(), # probe triggers nvJitLink runtime load
651679 )
680+ warn_txt = None
652681 if nvjitlink_module is None:
653- warn_txt = f" cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
654- else:
655- from cuda.bindings._internal import nvjitlink
656-
657- if _nvjitlink_has_version_symbol(nvjitlink ):
658- _use_nvjitlink_backend = True
659- return False # Use nvjitlink
660682 warn_txt = (
661- f" {'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
662- f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
683+ f" cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
663684 )
685+ else:
686+ from cuda.bindings._internal import nvjitlink as inner_nvjitlink
687+
688+ if _nvjitlink_has_version_symbol(inner_nvjitlink ):
689+ _nvjitlink_version = tuple (nvjitlink_module.version())
690+ else :
691+ warn_txt = (
692+ f" {'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
693+ f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
694+ )
664695
665- warn(warn_txt, stacklevel = 2 , category = RuntimeWarning )
666- _use_nvjitlink_backend = False
667- _driver = driver
668- return True
696+ if warn_txt is not None and not _nvjitlink_missing_warned:
697+ warn(warn_txt, stacklevel = 2 , category = RuntimeWarning )
698+ _nvjitlink_missing_warned = True
699+ _nvjitlink_probed = True
700+ return _nvjitlink_version
701+
702+
703+ def _choose_backend (
704+ driver_major: int ,
705+ nvjitlink_version: tuple | None ,
706+ inputs_have_ltoir: bool ,
707+ lto_requested: bool ,
708+ ) -> str:
709+ """Choose the linker backend for a specific Linker invocation.
710+
711+ Parameters
712+ ----------
713+ driver_major : int
714+ Major version of the installed CUDA driver (from ``cuDriverGetVersion``).
715+ nvjitlink_version : tuple[int , int] or None
716+ ``(major , minor )`` if nvJitLink is available and >=12.3; ``None`` otherwise.
717+ inputs_have_ltoir : bool
718+ ``True`` if any input ``ObjectCode`` has ``code_type == "ltoir"``.
719+ lto_requested : bool
720+ ``True`` if ``LinkerOptions.link_time_optimization`` or ``ptx`` is set
721+ (both force the use of nvJitLink; the driver linker cannot emit PTX and
722+ cannot do link-time optimization on LTO IR ).
723+
724+ Returns
725+ -------
726+ str
727+ ``"nvjitlink"`` or ``"driver"``.
728+
729+ Raises
730+ ------
731+ RuntimeError
732+ If the request cannot be satisfied by any backend , for example when
733+ LTO IR inputs or ``link_time_optimization`` are requested but nvJitLink
734+ is unavailable , or when driver and nvJitLink have mismatched major
735+ versions for an LTO link.
736+ """
737+ needs_nvjitlink = inputs_have_ltoir or lto_requested
738+
739+ if nvjitlink_version is None:
740+ if needs_nvjitlink:
741+ raise RuntimeError(
742+ "LTO IR input or link-time optimization was requested , but "
743+ "nvJitLink is not available (driver linker cannot perform LTO ). "
744+ "Install cuda-bindings with a compatible nvJitLink (>=12.3)."
745+ )
746+ return "driver"
747+
748+ nvjitlink_major = nvjitlink_version[0 ]
749+ if nvjitlink_major == driver_major:
750+ return "nvjitlink"
751+
752+ if needs_nvjitlink:
753+ raise RuntimeError(
754+ f"Cannot link with nvJitLink {nvjitlink_major}.x against CUDA driver "
755+ f"{driver_major}.x: LTO IR or link-time optimization requires matching "
756+ f"major versions , and the driver linker cannot perform LTO. "
757+ f"Install an nvJitLink matching the driver major version."
758+ )
759+ # Driver and nvJitLink have different major versions. nvJitLink output may
760+ # target an architecture or format that the driver cannot load , so fall back
761+ # to the driver's own linker for non-LTO linking.
762+ return "driver"
669763
670764
671765def _lazy_init():
672766 global _inited, _nvjitlink_input_types, _driver_input_types
673767 if _inited:
674768 return
675-
676- _decide_nvjitlink_or_driver()
677- if _use_nvjitlink_backend:
678- _nvjitlink_input_types = {
679- " ptx" : < int > cynvjitlink.NVJITLINK_INPUT_PTX,
680- " cubin" : < int > cynvjitlink.NVJITLINK_INPUT_CUBIN,
681- " fatbin" : < int > cynvjitlink.NVJITLINK_INPUT_FATBIN,
682- " ltoir" : < int > cynvjitlink.NVJITLINK_INPUT_LTOIR,
683- " object" : < int > cynvjitlink.NVJITLINK_INPUT_OBJECT,
684- " library" : < int > cynvjitlink.NVJITLINK_INPUT_LIBRARY,
685- }
686- else :
687- _driver_input_types = {
688- " ptx" : < int > cydriver.CU_JIT_INPUT_PTX,
689- " cubin" : < int > cydriver.CU_JIT_INPUT_CUBIN,
690- " fatbin" : < int > cydriver.CU_JIT_INPUT_FATBINARY,
691- " object" : < int > cydriver.CU_JIT_INPUT_OBJECT,
692- " library" : < int > cydriver.CU_JIT_INPUT_LIBRARY,
693- }
769+ _nvjitlink_input_types = {
770+ " ptx" : < int > cynvjitlink.NVJITLINK_INPUT_PTX,
771+ " cubin" : < int > cynvjitlink.NVJITLINK_INPUT_CUBIN,
772+ " fatbin" : < int > cynvjitlink.NVJITLINK_INPUT_FATBIN,
773+ " ltoir" : < int > cynvjitlink.NVJITLINK_INPUT_LTOIR,
774+ " object" : < int > cynvjitlink.NVJITLINK_INPUT_OBJECT,
775+ " library" : < int > cynvjitlink.NVJITLINK_INPUT_LIBRARY,
776+ }
777+ _driver_input_types = {
778+ " ptx" : < int > cydriver.CU_JIT_INPUT_PTX,
779+ " cubin" : < int > cydriver.CU_JIT_INPUT_CUBIN,
780+ " fatbin" : < int > cydriver.CU_JIT_INPUT_FATBINARY,
781+ " object" : < int > cydriver.CU_JIT_INPUT_OBJECT,
782+ " library" : < int > cydriver.CU_JIT_INPUT_LIBRARY,
783+ }
694784 _inited = True
0 commit comments