Skip to content

Commit 4724769

Browse files
committed
Factor out _nvjitlink_has_version_symbol() for clarity and testability
This aids unit testing by allowing localized stubbing of the version-symbol check, without needing to patch the full inner nvjitlink module.
1 parent 7e9fce4 commit 4724769

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

cuda_core/cuda/core/experimental/_linker.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
_nvjitlink_input_types = None # populated if nvJitLink cannot be used
3030

3131

32+
def _nvjitlink_has_version_symbol(inner_nvjitlink) -> bool:
33+
# This condition is equivalent to testing for version >= 12.3
34+
return bool(inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion"))
35+
36+
3237
# Note: this function is reused in the tests
3338
def _decide_nvjitlink_or_driver() -> bool:
3439
"""Returns True if falling back to the cuLink* driver APIs."""
@@ -53,7 +58,7 @@ def _decide_nvjitlink_or_driver() -> bool:
5358
from cuda.bindings._internal import nvjitlink as inner_nvjitlink
5459

5560
try:
56-
if inner_nvjitlink._inspect_function_pointer("__nvJitLinkVersion"):
61+
if _nvjitlink_has_version_symbol(inner_nvjitlink):
5762
return False # Use nvjitlink
5863
except RuntimeError:
5964
warn_detail = "not available"

0 commit comments

Comments
 (0)