Skip to content

Commit 024ede1

Browse files
committed
Remove _driver_ver from _linker.pyx; use _use_nvjitlink_backend as guard
Initialize _use_nvjitlink_backend to None so it can serve as its own "already decided" sentinel, eliminating the redundant _driver_ver variable and the driver_version() call that was only used to set it. Made-with: Cursor
1 parent 110d6de commit 024ede1

2 files changed

Lines changed: 5 additions & 15 deletions

File tree

cuda_core/cuda/core/_linker.pyx

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ from cuda.core._utils.cuda_utils import (
3939
driver,
4040
is_sequence,
4141
)
42-
from cuda.core._utils.version import driver_version
4342

4443
ctypedef const char* const_char_ptr
4544
ctypedef void* void_ptr
@@ -620,9 +619,8 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
620619

621620
# TODO: revisit this treatment for py313t builds
622621
_driver = None # populated if nvJitLink cannot be used
623-
_driver_ver = None
624622
_inited = False
625-
_use_nvjitlink_backend = False # set by _decide_nvjitlink_or_driver()
623+
_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
626624

627625
# Input type mappings populated by _lazy_init() with C-level enum ints.
628626
_nvjitlink_input_types = None
@@ -637,12 +635,10 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
637635
# Note: this function is reused in the tests
638636
def _decide_nvjitlink_or_driver() -> bool:
639637
"""Return True if falling back to the cuLink* driver APIs."""
640-
global _driver_ver, _driver, _use_nvjitlink_backend
641-
if _driver_ver is not None:
638+
global _driver, _use_nvjitlink_backend
639+
if _use_nvjitlink_backend is not None:
642640
return not _use_nvjitlink_backend
643641

644-
_driver_ver = driver_version()[:2]
645-
646642
warn_txt_common = (
647643
"the driver APIs will be used instead, which do not support"
648644
" minor version compatibility or linking LTO IRs."
@@ -667,6 +663,7 @@ def _decide_nvjitlink_or_driver() -> bool:
667663
)
668664

669665
warn(warn_txt, stacklevel=2, category=RuntimeWarning)
666+
_use_nvjitlink_backend = False
670667
_driver = driver
671668
return True
672669

cuda_core/tests/test_optional_dependency_imports.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,20 @@ def restore_optional_import_state():
1212
saved_nvvm_module = _program._nvvm_module
1313
saved_nvvm_attempted = _program._nvvm_import_attempted
1414
saved_driver = _linker._driver
15-
saved_driver_ver = _linker._driver_ver
1615
saved_inited = _linker._inited
1716
saved_use_nvjitlink = _linker._use_nvjitlink_backend
1817

1918
_program._nvvm_module = None
2019
_program._nvvm_import_attempted = False
2120
_linker._driver = None
22-
_linker._driver_ver = None
2321
_linker._inited = False
24-
_linker._use_nvjitlink_backend = False
22+
_linker._use_nvjitlink_backend = None
2523

2624
yield
2725

2826
_program._nvvm_module = saved_nvvm_module
2927
_program._nvvm_import_attempted = saved_nvvm_attempted
3028
_linker._driver = saved_driver
31-
_linker._driver_ver = saved_driver_ver
3229
_linker._inited = saved_inited
3330
_linker._use_nvjitlink_backend = saved_use_nvjitlink
3431

@@ -79,8 +76,6 @@ def fake__optional_cuda_import(modname, probe_function=None):
7976

8077

8178
def test_decide_nvjitlink_or_driver_reraises_nested_module_not_found(monkeypatch):
82-
monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0))
83-
8479
def fake__optional_cuda_import(modname, probe_function=None):
8580
assert modname == "cuda.bindings.nvjitlink"
8681
assert probe_function is not None
@@ -96,8 +91,6 @@ def fake__optional_cuda_import(modname, probe_function=None):
9691

9792

9893
def test_decide_nvjitlink_or_driver_falls_back_when_module_missing(monkeypatch):
99-
monkeypatch.setattr(_linker, "driver_version", lambda: (13, 0, 0))
100-
10194
def fake__optional_cuda_import(modname, probe_function=None):
10295
assert modname == "cuda.bindings.nvjitlink"
10396
assert probe_function is not None

0 commit comments

Comments
 (0)