Skip to content

Commit 52bb1e7

Browse files
committed
refactor(core): per-instance Linker backend dispatch
Replace module-level "decide once" backend selection with per-Linker instance dispatch at __init__ time. Factor the decision into a pure _choose_backend() helper so it can be unit-tested without a GPU. Handle nvJitLink/driver major-version skew: fall back to the driver linker for non-LTO linking, raise RuntimeError for LTO when the backends are incompatible. Probe driver_version() lazily so environments with nvJitLink but no driver (e.g., build containers) still work; only CUDAError from handle_return is treated as "driver unknown" so other exceptions propagate. _probe_nvjitlink() is cached and warns at most once when nvJitLink is absent. Validate each input as ObjectCode before the code_type pre-scan so invalid inputs surface a TypeError instead of a backend-dispatch RuntimeError. Breaking change: options.link_time_optimization=True with nvJitLink absent now raises RuntimeError instead of silently passing CU_JIT_LTO to the driver (which was not real LTO linking). Closes #712
1 parent a18022c commit 52bb1e7

5 files changed

Lines changed: 361 additions & 79 deletions

File tree

cuda_core/cuda/core/_linker.pyx

Lines changed: 165 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ from cuda.core._utils.cuda_utils import (
3939
driver,
4040
is_sequence,
4141
)
42+
from cuda.core._utils.version import driver_version
4243

4344
ctypedef const char* const_char_ptr
4445
ctypedef void* void_ptr
@@ -181,9 +182,20 @@ cdef class Linker:
181182
class LinkerOptions:
182183
"""Customizable options for configuring :class:`Linker`.
183184

184-
Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
185-
not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
186-
or not installed, the driver APIs (cuLink) will be used instead.
185+
Since the linker may choose either nvJitLink or the driver's ``cuLink*``
186+
APIs as the backend, not every option is applicable to both backends. The
187+
backend is decided per-:class:`Linker` instance from the installed CUDA
188+
driver major version, nvJitLink's availability and major version, the input
189+
code types, and whether link-time optimization is requested:
190+
191+
- nvJitLink is used when its major version matches the driver's.
192+
- The driver linker is used when nvJitLink is unavailable or too old
193+
(<12.3), or when its major version differs from the driver's (and no LTO
194+
step is required).
195+
- Linking LTO IRs, or requesting ``link_time_optimization`` / ``ptx``, with
196+
nvJitLink unavailable or with mismatched nvJitLink and driver majors is
197+
unsupported and raises :class:`RuntimeError` at :class:`Linker`
198+
construction time.
187199

188200
Attributes
189201
----------
@@ -348,39 +360,39 @@ class LinkerOptions:
348360
formatted_options.extend((bytearray(size), size, bytearray(size), size))
349361
option_keys.extend(
350362
(
351-
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
352-
_driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
353-
_driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
354-
_driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
363+
driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
364+
driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
365+
driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
366+
driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
355367
)
356368
)
357369

358370
if self.arch is not None:
359371
arch = self.arch.split("_")[-1].upper()
360-
formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
361-
option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
372+
formatted_options.append(getattr(driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
373+
option_keys.append(driver.CUjit_option.CU_JIT_TARGET)
362374
if self.max_register_count is not None:
363375
formatted_options.append(self.max_register_count)
364-
option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
376+
option_keys.append(driver.CUjit_option.CU_JIT_MAX_REGISTERS)
365377
if self.time is not None:
366378
raise ValueError("time option is not supported by the driver API")
367379
if self.verbose:
368380
formatted_options.append(1)
369-
option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
381+
option_keys.append(driver.CUjit_option.CU_JIT_LOG_VERBOSE)
370382
if self.link_time_optimization:
371383
formatted_options.append(1)
372-
option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
384+
option_keys.append(driver.CUjit_option.CU_JIT_LTO)
373385
if self.ptx:
374386
raise ValueError("ptx option is not supported by the driver API")
375387
if self.optimization_level is not None:
376388
formatted_options.append(self.optimization_level)
377-
option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
389+
option_keys.append(driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
378390
if self.debug:
379391
formatted_options.append(1)
380-
option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
392+
option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
381393
if self.lineinfo:
382394
formatted_options.append(1)
383-
option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
395+
option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
384396
if self.ftz is not None:
385397
warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
386398
if self.prec_div is not None:
@@ -402,8 +414,8 @@ class LinkerOptions:
402414
if self.split_compile_extended is not None:
403415
raise ValueError("split_compile_extended option is not supported by the driver API")
404416
if self.no_cache is True:
405-
formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
406-
option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
417+
formatted_options.append(driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
418+
option_keys.append(driver.CUjit_option.CU_JIT_CACHE_MODE)
407419

408420
return formatted_options, option_keys
409421

@@ -430,7 +442,7 @@ class LinkerOptions:
430442
backend = backend.lower()
431443
if backend != "nvjitlink":
432444
raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
433-
if not _use_nvjitlink_backend:
445+
if _probe_nvjitlink() is None:
434446
raise RuntimeError("nvJitLink backend is not available")
435447
return self._prepare_nvjitlink_options(as_bytes=True)
436448

@@ -453,7 +465,32 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
453465

454466
self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
455467

456-
if _use_nvjitlink_backend:
468+
# Validate inputs up front so an invalid object (e.g. something that merely
469+
# exposes a ``code_type`` attribute) gets the ObjectCode TypeError from
470+
# assert_type rather than a backend-dispatch RuntimeError.
471+
for code in object_codes:
472+
assert_type(code, ObjectCode)
473+
474+
# Decide the backend per-instance based on the current environment and this
475+
# Linker's inputs. See _choose_backend() for the full decision matrix.
476+
inputs_have_ltoir = any(code.code_type == "ltoir" for code in object_codes)
477+
lto_requested = bool(options.link_time_optimization) or bool(options.ptx)
478+
nvjitlink_version = _probe_nvjitlink()
479+
# Probe the driver version here, at Linker construction, rather than at import time; it is only consulted when nvJitLink is usable and majors must be compared.
480+
# In environments where nvJitLink is installed but the driver is
481+
# absent (e.g., build containers), cuDriverGetVersion raises CUDAError
482+
# via handle_return; treat that as "driver unknown" and fall through to
483+
# _choose_backend. Any other exception type is a real bug and should
484+
# propagate rather than silently flip the backend choice.
485+
try:
486+
driver_major = driver_version()[0]
487+
except CUDAError:
488+
driver_major = None
489+
backend = _choose_backend(
490+
driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested
491+
)
492+
493+
if backend == "nvjitlink":
457494
self._use_nvjitlink = True
458495
options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
459496
c_num_opts = len(options_bytes)
@@ -490,7 +527,6 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
490527
self._culink_handle = create_culink_handle(c_raw_culink)
491528

492529
for code in object_codes:
493-
assert_type(code, ObjectCode)
494530
Linker_add_code_object(self, code)
495531
return 0
496532

@@ -618,9 +654,10 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
618654
# =============================================================================
619655

620656
# TODO: revisit this treatment for py313t builds
621-
_driver = None # populated if nvJitLink cannot be used
622657
_inited = False
623-
_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
658+
_nvjitlink_probed = False
659+
_nvjitlink_version = None # (major, minor) if usable; None if unavailable/too old
660+
_nvjitlink_missing_warned = False
624661

625662
# Input type mappings populated by _lazy_init() with C-level enum ints.
626663
_nvjitlink_input_types = None
@@ -632,12 +669,15 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
632669
return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
633670

634671

635-
# Note: this function is reused in the tests
636-
def _decide_nvjitlink_or_driver() -> bool:
637-
"""Return True if falling back to the cuLink* driver APIs."""
638-
global _driver, _use_nvjitlink_backend
639-
if _use_nvjitlink_backend is not None:
640-
return not _use_nvjitlink_backend
672+
def _probe_nvjitlink() -> tuple | None:
673+
"""Return ``(major, minor)`` if nvJitLink is available and >= 12.3, else ``None``.
674+
675+
Emits a ``RuntimeWarning`` at most once when nvJitLink is unavailable or too
676+
old. The result is cached for subsequent calls.
677+
"""
678+
global _nvjitlink_probed, _nvjitlink_version, _nvjitlink_missing_warned
679+
if _nvjitlink_probed:
680+
return _nvjitlink_version
641681

642682
warn_txt_common = (
643683
"the driver APIs will be used instead, which do not support"
@@ -649,46 +689,111 @@ def _decide_nvjitlink_or_driver() -> bool:
649689
"cuda.bindings.nvjitlink",
650690
probe_function=lambda module: module.version(), # probe triggers nvJitLink runtime load
651691
)
692+
warn_txt = None
652693
if nvjitlink_module is None:
653-
warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
654-
else:
655-
from cuda.bindings._internal import nvjitlink
656-
657-
if _nvjitlink_has_version_symbol(nvjitlink):
658-
_use_nvjitlink_backend = True
659-
return False # Use nvjitlink
660694
warn_txt = (
661-
f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
662-
f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
695+
f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
663696
)
697+
else:
698+
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
699+
700+
if _nvjitlink_has_version_symbol(inner_nvjitlink):
701+
_nvjitlink_version = tuple(nvjitlink_module.version())
702+
else:
703+
warn_txt = (
704+
f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
705+
f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
706+
)
707+
708+
if warn_txt is not None and not _nvjitlink_missing_warned:
709+
warn(warn_txt, stacklevel=2, category=RuntimeWarning)
710+
_nvjitlink_missing_warned = True
711+
_nvjitlink_probed = True
712+
return _nvjitlink_version
664713

665-
warn(warn_txt, stacklevel=2, category=RuntimeWarning)
666-
_use_nvjitlink_backend = False
667-
_driver = driver
668-
return True
714+
715+
def _choose_backend(
716+
driver_major: int | None,
717+
nvjitlink_version: tuple | None,
718+
inputs_have_ltoir: bool,
719+
lto_requested: bool,
720+
) -> str:
721+
"""Choose the linker backend for a specific Linker invocation.
722+
723+
Parameters
724+
----------
725+
driver_major : int or None
726+
Major version of the installed CUDA driver (from ``cuDriverGetVersion``).
727+
``None`` when the driver cannot be queried (e.g., no driver installed).
728+
nvjitlink_version : tuple[int, int] or None
729+
``(major, minor)`` if nvJitLink is available and >=12.3; ``None`` otherwise.
730+
inputs_have_ltoir : bool
731+
``True`` if any input ``ObjectCode`` has ``code_type == "ltoir"``.
732+
lto_requested : bool
733+
``True`` if ``LinkerOptions.link_time_optimization`` or ``ptx`` is set
734+
(both force the use of nvJitLink; the driver linker cannot emit PTX and
735+
cannot do link-time optimization on LTO IR).
736+
737+
Returns
738+
-------
739+
str
740+
``"nvjitlink"`` or ``"driver"``.
741+
742+
Raises
743+
------
744+
RuntimeError
745+
If the request cannot be satisfied by any backend, for example when
746+
LTO IR inputs or ``link_time_optimization`` are requested but nvJitLink
747+
is unavailable, or when driver and nvJitLink have mismatched major
748+
versions for an LTO link.
749+
"""
750+
needs_nvjitlink = inputs_have_ltoir or lto_requested
751+
752+
if nvjitlink_version is None:
753+
if needs_nvjitlink:
754+
raise RuntimeError(
755+
"LTO IR input or link-time optimization was requested, but "
756+
"nvJitLink is not available (driver linker cannot perform LTO). "
757+
"Install cuda-bindings with a compatible nvJitLink (>=12.3)."
758+
)
759+
return "driver"
760+
761+
nvjitlink_major = nvjitlink_version[0]
762+
# If driver version is unknown, optimistically use nvJitLink
763+
# (common in build containers with nvJitLink but no driver).
764+
if driver_major is None or nvjitlink_major == driver_major:
765+
return "nvjitlink"
766+
767+
if needs_nvjitlink:
768+
raise RuntimeError(
769+
f"Cannot link with nvJitLink {nvjitlink_major}.x against CUDA driver "
770+
f"{driver_major}.x: LTO IR or link-time optimization requires matching "
771+
f"major versions, and the driver linker cannot perform LTO. "
772+
f"Install an nvJitLink matching the driver major version."
773+
)
774+
# Driver and nvJitLink have different major versions. nvJitLink output may
775+
# target an architecture or format that the driver cannot load, so fall back
776+
# to the driver's own linker for non-LTO linking.
777+
return "driver"
669778

670779

671780
def _lazy_init():
672781
global _inited, _nvjitlink_input_types, _driver_input_types
673782
if _inited:
674783
return
675-
676-
_decide_nvjitlink_or_driver()
677-
if _use_nvjitlink_backend:
678-
_nvjitlink_input_types = {
679-
"ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
680-
"cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
681-
"fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
682-
"ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
683-
"object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
684-
"library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
685-
}
686-
else:
687-
_driver_input_types = {
688-
"ptx": <int>cydriver.CU_JIT_INPUT_PTX,
689-
"cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
690-
"fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
691-
"object": <int>cydriver.CU_JIT_INPUT_OBJECT,
692-
"library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
693-
}
784+
_nvjitlink_input_types = {
785+
"ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
786+
"cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
787+
"fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
788+
"ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
789+
"object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
790+
"library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
791+
}
792+
_driver_input_types = {
793+
"ptx": <int>cydriver.CU_JIT_INPUT_PTX,
794+
"cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
795+
"fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
796+
"object": <int>cydriver.CU_JIT_INPUT_OBJECT,
797+
"library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
798+
}
694799
_inited = True

0 commit comments

Comments
 (0)