
Commit 8674e36

cpcloud and claude committed
refactor(core): per-instance Linker backend dispatch (#712)
Replace the module-level "decide once, use everywhere" nvJitLink-vs-driver choice with a per-Linker-instance decision that considers the CUDA driver major version, nvJitLink's availability and major version, the input code types, and whether link-time optimization is requested.

The dispatch is factored into a pure helper `_choose_backend()` that is fully unit-testable without a GPU. Its decision matrix:

- no nvJitLink, no LTO -> driver
- matching majors -> nvJitLink
- cross-major, no LTO -> driver (nvJitLink output may not be loadable)
- LTO + no nvJitLink -> RuntimeError
- LTO + cross-major -> RuntimeError

This resolves the cross-major-driver scenario described in #712, where an nvJitLink 12.x may produce a CUBIN the driver 13.x (or vice versa) cannot load. The previous code committed to nvJitLink unconditionally when it was importable.

Tests:

- `tests/test_linker_dispatch.py` parametrizes the entire matrix against `_choose_backend()` with mocked versions (no GPU, no driver required).
- `tests/test_linker.py::TestLinkerDispatch` drives the same decision through the real `Linker` constructor via monkeypatched version probes.
- `tests/test_optional_dependency_imports.py` is updated to exercise the new `_probe_nvjitlink()` helper in place of the removed `_decide_nvjitlink_or_driver()`.
- `tests/test_program.py` and `tests/test_linker.py` use a small local helper to compute the effective backend for the current environment.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 67be918 commit 8674e36
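
The decision matrix in the commit message can be checked without a GPU. The following is a minimal sketch, not the project's test suite: `choose_backend` is a standalone stand-in that mirrors the branching of `_choose_backend()` from this commit with shortened error messages, and the version numbers in `MATRIX` are illustrative assumptions.

```python
def choose_backend(driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested):
    """Stand-in for the commit's _choose_backend(); error messages are shortened."""
    needs_nvjitlink = inputs_have_ltoir or lto_requested
    if nvjitlink_version is None:
        if needs_nvjitlink:
            raise RuntimeError("LTO requested but nvJitLink is not available")
        return "driver"
    if nvjitlink_version[0] == driver_major:
        return "nvjitlink"
    if needs_nvjitlink:
        raise RuntimeError("LTO requires matching nvJitLink and driver majors")
    # Cross-major without LTO: fall back to the driver's own linker.
    return "driver"


# (driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested) -> expected
MATRIX = [
    ((12, None, False, False), "driver"),         # no nvJitLink, no LTO
    ((12, (12, 9), False, False), "nvjitlink"),   # matching majors
    ((13, (13, 0), True, True), "nvjitlink"),     # matching majors with LTO
    ((13, (12, 9), False, False), "driver"),      # cross-major, no LTO
    ((12, None, True, False), RuntimeError),      # LTO + no nvJitLink
    ((13, (12, 9), False, True), RuntimeError),   # LTO + cross-major
]


def outcome(args):
    """Return the chosen backend, or the RuntimeError class if dispatch refuses."""
    try:
        return choose_backend(*args)
    except RuntimeError:
        return RuntimeError


results = [outcome(args) for args, _ in MATRIX]
for (args, expected), got in zip(MATRIX, results):
    print(args, "->", got, "(ok)" if got == expected else "(MISMATCH)")
```

Because the helper takes plain values rather than probing the environment itself, each row of the matrix becomes one table entry, which is what makes the real helper straightforward to parametrize.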

5 files changed: +337 -77 lines changed


cuda_core/cuda/core/_linker.pyx

Lines changed: 149 additions & 59 deletions
@@ -39,6 +39,7 @@ from cuda.core._utils.cuda_utils import (
     driver,
     is_sequence,
 )
+from cuda.core._utils.version import driver_version
 
 ctypedef const char* const_char_ptr
 ctypedef void* void_ptr
@@ -181,9 +182,20 @@ cdef class Linker:
 class LinkerOptions:
     """Customizable options for configuring :class:`Linker`.
 
-    Since the linker may choose to use nvJitLink or the driver APIs as the linking backend,
-    not all options are applicable. When the system's installed nvJitLink is too old (<12.3),
-    or not installed, the driver APIs (cuLink) will be used instead.
+    Since the linker may choose either nvJitLink or the driver's ``cuLink*``
+    APIs as the backend, not every option is applicable to both backends. The
+    backend is decided per-:class:`Linker` instance from the installed CUDA
+    driver major version, nvJitLink's availability and major version, the input
+    code types, and whether link-time optimization is requested:
+
+    - nvJitLink is used when its major version matches the driver's.
+    - The driver linker is used when nvJitLink is unavailable or too old
+      (<12.3), or when its major version differs from the driver's (and no LTO
+      step is required).
+    - Linking LTO IRs, or requesting ``link_time_optimization`` / ``ptx``, with
+      nvJitLink unavailable or with mismatched nvJitLink and driver majors is
+      unsupported and raises :class:`RuntimeError` at :class:`Linker`
+      construction time.
 
     Attributes
     ----------
@@ -348,39 +360,39 @@ class LinkerOptions:
         formatted_options.extend((bytearray(size), size, bytearray(size), size))
         option_keys.extend(
             (
-                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
-                _driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
-                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
-                _driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
+                driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER,
+                driver.CUjit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES,
+                driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER,
+                driver.CUjit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES,
             )
         )
 
         if self.arch is not None:
             arch = self.arch.split("_")[-1].upper()
-            formatted_options.append(getattr(_driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
-            option_keys.append(_driver.CUjit_option.CU_JIT_TARGET)
+            formatted_options.append(getattr(driver.CUjit_target, f"CU_TARGET_COMPUTE_{arch}"))
+            option_keys.append(driver.CUjit_option.CU_JIT_TARGET)
         if self.max_register_count is not None:
             formatted_options.append(self.max_register_count)
-            option_keys.append(_driver.CUjit_option.CU_JIT_MAX_REGISTERS)
+            option_keys.append(driver.CUjit_option.CU_JIT_MAX_REGISTERS)
         if self.time is not None:
             raise ValueError("time option is not supported by the driver API")
         if self.verbose:
             formatted_options.append(1)
-            option_keys.append(_driver.CUjit_option.CU_JIT_LOG_VERBOSE)
+            option_keys.append(driver.CUjit_option.CU_JIT_LOG_VERBOSE)
         if self.link_time_optimization:
             formatted_options.append(1)
-            option_keys.append(_driver.CUjit_option.CU_JIT_LTO)
+            option_keys.append(driver.CUjit_option.CU_JIT_LTO)
         if self.ptx:
             raise ValueError("ptx option is not supported by the driver API")
         if self.optimization_level is not None:
             formatted_options.append(self.optimization_level)
-            option_keys.append(_driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
+            option_keys.append(driver.CUjit_option.CU_JIT_OPTIMIZATION_LEVEL)
         if self.debug:
             formatted_options.append(1)
-            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
+            option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_DEBUG_INFO)
         if self.lineinfo:
             formatted_options.append(1)
-            option_keys.append(_driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
+            option_keys.append(driver.CUjit_option.CU_JIT_GENERATE_LINE_INFO)
         if self.ftz is not None:
             warn("ftz option is deprecated in the driver API", DeprecationWarning, stacklevel=3)
         if self.prec_div is not None:
@@ -402,8 +414,8 @@ class LinkerOptions:
         if self.split_compile_extended is not None:
             raise ValueError("split_compile_extended option is not supported by the driver API")
         if self.no_cache is True:
-            formatted_options.append(_driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
-            option_keys.append(_driver.CUjit_option.CU_JIT_CACHE_MODE)
+            formatted_options.append(driver.CUjit_cacheMode.CU_JIT_CACHE_OPTION_NONE)
+            option_keys.append(driver.CUjit_option.CU_JIT_CACHE_MODE)
 
         return formatted_options, option_keys
 
@@ -430,7 +442,7 @@ class LinkerOptions:
         backend = backend.lower()
         if backend != "nvjitlink":
             raise ValueError(f"as_bytes() only supports 'nvjitlink' backend, got '{backend}'")
-        if not _use_nvjitlink_backend:
+        if _probe_nvjitlink() is None:
             raise RuntimeError("nvJitLink backend is not available")
         return self._prepare_nvjitlink_options(as_bytes=True)
 
@@ -453,7 +465,19 @@ cdef inline int Linker_init(Linker self, tuple object_codes, object options) exc
 
     self._options = options = check_or_create_options(LinkerOptions, options, "Linker options")
 
-    if _use_nvjitlink_backend:
+    # Decide the backend per-instance based on the current environment and this
+    # Linker's inputs. See _choose_backend() for the full decision matrix.
+    inputs_have_ltoir = any(
+        getattr(code, "code_type", None) == "ltoir" for code in object_codes
+    )
+    lto_requested = bool(options.link_time_optimization) or bool(options.ptx)
+    nvjitlink_version = _probe_nvjitlink()
+    driver_major = driver_version()[0]
+    backend = _choose_backend(
+        driver_major, nvjitlink_version, inputs_have_ltoir, lto_requested
+    )
+
+    if backend == "nvjitlink":
         self._use_nvjitlink = True
         options_bytes = options._prepare_nvjitlink_options(as_bytes=True)
         c_num_opts = len(options_bytes)
@@ -618,9 +642,10 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
 # =============================================================================
 
 # TODO: revisit this treatment for py313t builds
-_driver = None  # populated if nvJitLink cannot be used
 _inited = False
-_use_nvjitlink_backend = None  # set by _decide_nvjitlink_or_driver()
+_nvjitlink_probed = False
+_nvjitlink_version = None  # (major, minor) if usable; None if unavailable/too old
+_nvjitlink_missing_warned = False
 
 # Input type mappings populated by _lazy_init() with C-level enum ints.
 _nvjitlink_input_types = None
@@ -632,12 +657,15 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
     return bool(nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
 
 
-# Note: this function is reused in the tests
-def _decide_nvjitlink_or_driver() -> bool:
-    """Return True if falling back to the cuLink* driver APIs."""
-    global _driver, _use_nvjitlink_backend
-    if _use_nvjitlink_backend is not None:
-        return not _use_nvjitlink_backend
+def _probe_nvjitlink() -> tuple | None:
+    """Return ``(major, minor)`` if nvJitLink is available and >= 12.3, else ``None``.
+
+    Emits a ``RuntimeWarning`` at most once when nvJitLink is unavailable or too
+    old. The result is cached for subsequent calls.
+    """
+    global _nvjitlink_probed, _nvjitlink_version, _nvjitlink_missing_warned
+    if _nvjitlink_probed:
+        return _nvjitlink_version
 
     warn_txt_common = (
         "the driver APIs will be used instead, which do not support"
@@ -649,46 +677,108 @@ def _decide_nvjitlink_or_driver() -> bool:
         "cuda.bindings.nvjitlink",
         probe_function=lambda module: module.version(),  # probe triggers nvJitLink runtime load
     )
+    warn_txt = None
     if nvjitlink_module is None:
-        warn_txt = f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
-    else:
-        from cuda.bindings._internal import nvjitlink
-
-        if _nvjitlink_has_version_symbol(nvjitlink):
-            _use_nvjitlink_backend = True
-            return False  # Use nvjitlink
         warn_txt = (
-            f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
-            f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
+            f"cuda.bindings.nvjitlink is not available, therefore {warn_txt_common} cuda-bindings."
         )
+    else:
+        from cuda.bindings._internal import nvjitlink as inner_nvjitlink
+
+        if _nvjitlink_has_version_symbol(inner_nvjitlink):
+            _nvjitlink_version = tuple(nvjitlink_module.version())
+        else:
+            warn_txt = (
+                f"{'nvJitLink*.dll' if sys.platform == 'win32' else 'libnvJitLink.so*'} is too old (<12.3)."
+                f" Therefore cuda.bindings.nvjitlink is not usable and {warn_txt_common} nvJitLink."
+            )
 
-    warn(warn_txt, stacklevel=2, category=RuntimeWarning)
-    _use_nvjitlink_backend = False
-    _driver = driver
-    return True
+    if warn_txt is not None and not _nvjitlink_missing_warned:
+        warn(warn_txt, stacklevel=2, category=RuntimeWarning)
+        _nvjitlink_missing_warned = True
+    _nvjitlink_probed = True
+    return _nvjitlink_version
+
+
+def _choose_backend(
+    driver_major: int,
+    nvjitlink_version: tuple | None,
+    inputs_have_ltoir: bool,
+    lto_requested: bool,
+) -> str:
+    """Choose the linker backend for a specific Linker invocation.
+
+    Parameters
+    ----------
+    driver_major : int
+        Major version of the installed CUDA driver (from ``cuDriverGetVersion``).
+    nvjitlink_version : tuple[int, int] or None
+        ``(major, minor)`` if nvJitLink is available and >=12.3; ``None`` otherwise.
+    inputs_have_ltoir : bool
+        ``True`` if any input ``ObjectCode`` has ``code_type == "ltoir"``.
+    lto_requested : bool
+        ``True`` if ``LinkerOptions.link_time_optimization`` or ``ptx`` is set
+        (both force the use of nvJitLink; the driver linker cannot emit PTX and
+        cannot do link-time optimization on LTO IR).
+
+    Returns
+    -------
+    str
+        ``"nvjitlink"`` or ``"driver"``.
+
+    Raises
+    ------
+    RuntimeError
+        If the request cannot be satisfied by any backend, for example when
+        LTO IR inputs or ``link_time_optimization`` are requested but nvJitLink
+        is unavailable, or when driver and nvJitLink have mismatched major
+        versions for an LTO link.
+    """
+    needs_nvjitlink = inputs_have_ltoir or lto_requested
+
+    if nvjitlink_version is None:
+        if needs_nvjitlink:
+            raise RuntimeError(
+                "LTO IR input or link-time optimization was requested, but "
+                "nvJitLink is not available (driver linker cannot perform LTO). "
+                "Install cuda-bindings with a compatible nvJitLink (>=12.3)."
+            )
+        return "driver"
+
+    nvjitlink_major = nvjitlink_version[0]
+    if nvjitlink_major == driver_major:
+        return "nvjitlink"
+
+    if needs_nvjitlink:
+        raise RuntimeError(
+            f"Cannot link with nvJitLink {nvjitlink_major}.x against CUDA driver "
+            f"{driver_major}.x: LTO IR or link-time optimization requires matching "
+            f"major versions, and the driver linker cannot perform LTO. "
+            f"Install an nvJitLink matching the driver major version."
+        )
+    # Driver and nvJitLink have different major versions. nvJitLink output may
+    # target an architecture or format that the driver cannot load, so fall back
+    # to the driver's own linker for non-LTO linking.
+    return "driver"
 
 
 def _lazy_init():
     global _inited, _nvjitlink_input_types, _driver_input_types
     if _inited:
         return
-
-    _decide_nvjitlink_or_driver()
-    if _use_nvjitlink_backend:
-        _nvjitlink_input_types = {
-            "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
-            "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
-            "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
-            "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
-            "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
-            "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
-        }
-    else:
-        _driver_input_types = {
-            "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
-            "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
-            "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
-            "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
-            "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
-        }
+    _nvjitlink_input_types = {
+        "ptx": <int>cynvjitlink.NVJITLINK_INPUT_PTX,
+        "cubin": <int>cynvjitlink.NVJITLINK_INPUT_CUBIN,
+        "fatbin": <int>cynvjitlink.NVJITLINK_INPUT_FATBIN,
+        "ltoir": <int>cynvjitlink.NVJITLINK_INPUT_LTOIR,
+        "object": <int>cynvjitlink.NVJITLINK_INPUT_OBJECT,
+        "library": <int>cynvjitlink.NVJITLINK_INPUT_LIBRARY,
+    }
+    _driver_input_types = {
+        "ptx": <int>cydriver.CU_JIT_INPUT_PTX,
+        "cubin": <int>cydriver.CU_JIT_INPUT_CUBIN,
+        "fatbin": <int>cydriver.CU_JIT_INPUT_FATBINARY,
+        "object": <int>cydriver.CU_JIT_INPUT_OBJECT,
+        "library": <int>cydriver.CU_JIT_INPUT_LIBRARY,
+    }
     _inited = True
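
The probe-once, warn-once behaviour of `_probe_nvjitlink()` in the diff above can be illustrated in isolation. This is a simplified sketch under stated assumptions, not the commit's code: `probe_once`, `missing_loader`, and the `ImportError`-based failure signal are hypothetical stand-ins for the real optional-import helper and version-symbol check.

```python
import warnings

_probed = False   # has the probe already run?
_version = None   # cached (major, minor), or None if the library is unusable


def probe_once(loader):
    """Run `loader` at most once; cache its result and warn once on failure.

    `loader` is a hypothetical callable returning (major, minor) or raising
    ImportError; the real helper imports cuda.bindings.nvjitlink instead.
    """
    global _probed, _version
    if _probed:
        return _version
    try:
        _version = tuple(loader())
    except ImportError:
        warnings.warn(
            "nvJitLink is not available; the driver linker will be used",
            RuntimeWarning,
            stacklevel=2,
        )
    _probed = True
    return _version


def missing_loader():
    """Simulate an environment without nvJitLink."""
    raise ImportError("no nvJitLink here")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record every warning, bypassing dedup filters
    first = probe_once(missing_loader)
    second = probe_once(missing_loader)  # served from the cache, no second warning
```

Caching only the probe result, rather than the backend choice, is what allows the backend itself to be decided per `Linker` instance while the comparatively expensive library load still happens once per process.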
