Skip to content

Commit 4cea0a1

Browse files
authored
Standardize internal version checks in cuda.core (#1825)
* Cythonize _graph/_graph_builder (move from pure Python to .pyx) Move the GraphBuilder/Graph/GraphCompleteOptions/GraphDebugPrintOptions implementation out of _graph/__init__.py into _graph/_graph_builder.pyx so it is compiled by Cython. A thin __init__.py re-exports the public names so all existing import sites continue to work unchanged. Cython compatibility adjustments: - Remove `from __future__ import annotations` (unsupported by Cython) - Remove TYPE_CHECKING guard; quote annotations that reference Stream (circular import), forward-reference GraphBuilder/Graph, or use X | None union syntax - Update _graphdef.pyx lazy imports to point directly at _graph_builder No build_hooks.py changes needed — the build system auto-discovers .pyx files via glob. Ref: #1076 Made-with: Cursor * Remove _lazy_init from _graph_builder; add cached get_driver_version Replace the per-module _lazy_init / _inited / _driver_ver / _py_major_minor pattern in _graph_builder.pyx with direct calls to centralized cached functions in cuda_utils: - Add get_driver_version() with @functools.cache alongside get_binding_version - Switch get_binding_version from @functools.lru_cache to @functools.cache (cleaner for nullary functions) - Fix split() to return tuple(result) — Cython enforces return type annotations unlike pure Python - Fix _cond_with_params annotation from -> GraphBuilder to -> tuple to match actual return value Made-with: Cursor * Add CPU callbacks for stream capture (GraphBuilder.callback) Implements #1328: host callbacks during stream capture via cuLaunchHostFunc, mirroring the existing GraphDef.callback API. Extracts shared callback infrastructure (_attach_user_object, _attach_host_callback_to_graph, trampoline/destructor) into a new _graph/_utils.pyx module to avoid circular imports between _graph_builder and _graphdef. 
Made-with: Cursor * Standardize internal version checks in cuda.core Move binding and driver version queries into a dedicated cuda/core/_utils/version.{pyx,pxd} module, providing both Python (binding_version, driver_version) and Cython (cy_binding_version, cy_driver_version) entry points. All functions return version tuples ((major, minor, patch)) and are cached—Python via @functools.cache, Cython via module-level globals. Remove get_binding_version / get_driver_version from cuda_utils.pyx and update all internal call sites and tests to import from the new module. Remove version checks for CUDA < 12.0 (now the minimum) and eliminate dead code exposed by the migration: _lazy_init / _use_ex / _kernel_ctypes / _is_cukernel_get_library_supported machinery in _module.pyx, _launcher.pyx, and _launch_config.pyx. The public NVML-based system.get_driver_version API is unrelated and left unchanged. Made-with: Cursor * Fix unused imports after merge with main Remove unused imports flagged by cython-lint and ruff after resolving merge conflicts with origin/main. Made-with: Cursor * Replace _reduce_3_tuple with math.prod in _launcher.pyx Remove the now-dead _reduce_3_tuple helper from cuda_utils.pyx. Made-with: Cursor * Remove _driver_ver from _linker.pyx; use _use_nvjitlink_backend as guard Initialize _use_nvjitlink_backend to None so it can serve as its own "already decided" sentinel, eliminating the redundant _driver_ver variable and the driver_version() call that was only used to set it. Made-with: Cursor * Add return type annotations to version.pyx; fix minor arithmetic Add -> tuple[int, int, int] annotations to binding_version and driver_version. Align driver_version arithmetic with _system.pyx. Made-with: Cursor
1 parent ad7c96b commit 4cea0a1

17 files changed: +128 −328 lines changed

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ from cuda.core._graph._utils cimport _attach_host_callback_to_graph
1111
from cuda.core._resource_handles cimport as_cu
1212
from cuda.core._stream cimport Stream
1313
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
14+
from cuda.core._utils.version cimport cy_binding_version, cy_driver_version
15+
1416
from cuda.core._utils.cuda_utils import (
1517
driver,
16-
get_binding_version,
17-
get_driver_version,
1818
handle_return,
1919
)
2020

@@ -169,7 +169,7 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
169169
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
170170
raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
171171
elif (
172-
get_binding_version() >= (12, 8)
172+
cy_binding_version() >= (12, 8, 0)
173173
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
174174
):
175175
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
@@ -449,10 +449,10 @@ class GraphBuilder:
449449
The newly created conditional handle.
450450

451451
"""
452-
if get_driver_version() < 12030:
453-
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional handles")
454-
if get_binding_version() < (12, 3):
455-
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional handles")
452+
if cy_driver_version() < (12, 3, 0):
453+
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional handles")
454+
if cy_binding_version() < (12, 3, 0):
455+
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional handles")
456456
if default_value is not None:
457457
flags = driver.CU_GRAPH_COND_ASSIGN_DEFAULT
458458
else:
@@ -522,10 +522,10 @@ class GraphBuilder:
522522
The newly created conditional graph builder.
523523

524524
"""
525-
if get_driver_version() < 12030:
526-
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if")
527-
if get_binding_version() < (12, 3):
528-
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if")
525+
if cy_driver_version() < (12, 3, 0):
526+
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if")
527+
if cy_binding_version() < (12, 3, 0):
528+
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if")
529529
node_params = driver.CUgraphNodeParams()
530530
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
531531
node_params.conditional.handle = handle
@@ -553,10 +553,10 @@ class GraphBuilder:
553553
A tuple of two new graph builders, one for the if branch and one for the else branch.
554554

555555
"""
556-
if get_driver_version() < 12080:
557-
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional if-else")
558-
if get_binding_version() < (12, 8):
559-
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional if-else")
556+
if cy_driver_version() < (12, 8, 0):
557+
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional if-else")
558+
if cy_binding_version() < (12, 8, 0):
559+
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional if-else")
560560
node_params = driver.CUgraphNodeParams()
561561
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
562562
node_params.conditional.handle = handle
@@ -587,10 +587,10 @@ class GraphBuilder:
587587
A tuple of new graph builders, one for each branch.
588588

589589
"""
590-
if get_driver_version() < 12080:
591-
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional switch")
592-
if get_binding_version() < (12, 8):
593-
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional switch")
590+
if cy_driver_version() < (12, 8, 0):
591+
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional switch")
592+
if cy_binding_version() < (12, 8, 0):
593+
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional switch")
594594
node_params = driver.CUgraphNodeParams()
595595
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
596596
node_params.conditional.handle = handle
@@ -618,10 +618,10 @@ class GraphBuilder:
618618
The newly created while loop graph builder.
619619

620620
"""
621-
if get_driver_version() < 12030:
622-
raise RuntimeError(f"Driver version {get_driver_version()} does not support conditional while loop")
623-
if get_binding_version() < (12, 3):
624-
raise RuntimeError(f"Binding version {get_binding_version()} does not support conditional while loop")
621+
if cy_driver_version() < (12, 3, 0):
622+
raise RuntimeError(f"Driver version {'.'.join(map(str, cy_driver_version()))} does not support conditional while loop")
623+
if cy_binding_version() < (12, 3, 0):
624+
raise RuntimeError(f"Binding version {'.'.join(map(str, cy_binding_version()))} does not support conditional while loop")
625625
node_params = driver.CUgraphNodeParams()
626626
node_params.type = driver.CUgraphNodeType.CU_GRAPH_NODE_TYPE_CONDITIONAL
627627
node_params.conditional.handle = handle
@@ -649,12 +649,6 @@ class GraphBuilder:
649649
child_graph : :obj:`~_graph.GraphBuilder`
650650
The child graph builder. Must have finished building.
651651
"""
652-
if (get_driver_version() < 12000) or (get_binding_version() < (12, 0)):
653-
raise NotImplementedError(
654-
f"Launching child graphs is not implemented for versions older than CUDA 12."
655-
f"Found driver version is {get_driver_version()} and binding version is {get_binding_version()}"
656-
)
657-
658652
if not child_graph._building_ended:
659653
raise ValueError("Child graph has not finished building.")
660654

cuda_core/cuda/core/_graph/_graphdef.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ cdef bint _version_checked = False
9494
cdef bint _check_node_get_params():
9595
global _has_cuGraphNodeGetParams, _version_checked
9696
if not _version_checked:
97-
ver = handle_return(driver.cuDriverGetVersion())
98-
_has_cuGraphNodeGetParams = ver >= 13020
97+
from cuda.core._utils.version import driver_version
98+
_has_cuGraphNodeGetParams = driver_version() >= (13, 2, 0)
9999
_version_checked = True
100100
return _has_cuGraphNodeGetParams
101101

cuda_core/cuda/core/_launch_config.pyx

Lines changed: 0 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -4,49 +4,16 @@
44

55
from libc.string cimport memset
66

7-
from cuda.core._utils.cuda_utils cimport (
8-
HANDLE_RETURN,
9-
)
10-
11-
import threading
12-
137
from cuda.core._device import Device
148
from cuda.core._utils.cuda_utils import (
159
CUDAError,
1610
cast_to_3_tuple,
1711
driver,
18-
get_binding_version,
1912
)
2013

21-
22-
cdef bint _inited = False
23-
cdef bint _use_ex = False
24-
cdef object _lock = threading.Lock()
25-
26-
# Attribute names for identity comparison and representation
2714
_LAUNCH_CONFIG_ATTRS = ('grid', 'cluster', 'block', 'shmem_size', 'cooperative_launch')
2815

2916

30-
cdef int _lazy_init() except?-1:
31-
global _inited, _use_ex
32-
if _inited:
33-
return 0
34-
35-
cdef tuple _py_major_minor
36-
cdef int _driver_ver
37-
with _lock:
38-
if _inited:
39-
return 0
40-
41-
# binding availability depends on cuda-python version
42-
_py_major_minor = get_binding_version()
43-
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
44-
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
45-
_inited = True
46-
47-
return 0
48-
49-
5017
cdef class LaunchConfig:
5118
"""Customizable launch options.
5219
@@ -99,8 +66,6 @@ cdef class LaunchConfig:
9966
cooperative_launch : bool, optional
10067
Whether to launch as cooperative kernel (default: False)
10168
"""
102-
_lazy_init()
103-
10469
# Convert and validate grid and block dimensions
10570
self.grid = cast_to_3_tuple("LaunchConfig.grid", grid)
10671
self.block = cast_to_3_tuple("LaunchConfig.block", block)
@@ -110,10 +75,6 @@ cdef class LaunchConfig:
11075
# device compute capability or attributes.
11176
# thread block clusters are supported starting H100
11277
if cluster is not None:
113-
if not _use_ex:
114-
err, drvers = driver.cuDriverGetVersion()
115-
drvers_fmt = f" (got driver version {drvers})" if err == driver.CUresult.CUDA_SUCCESS else ""
116-
raise CUDAError(f"thread block clusters require cuda.bindings & driver 11.8+{drvers_fmt}")
11778
cc = Device().compute_capability
11879
if cc < (9, 0):
11980
raise CUDAError(
@@ -153,7 +114,6 @@ cdef class LaunchConfig:
153114
return hash(self._identity())
154115

155116
cdef cydriver.CUlaunchConfig _to_native_launch_config(self):
156-
_lazy_init()
157117
cdef cydriver.CUlaunchConfig drv_cfg
158118
cdef cydriver.CUlaunchAttribute attr
159119
memset(&drv_cfg, 0, sizeof(drv_cfg))
@@ -201,8 +161,6 @@ cpdef object _to_native_launch_config(LaunchConfig config):
201161
driver.CUlaunchConfig
202162
Native CUDA driver launch configuration
203163
"""
204-
_lazy_init()
205-
206164
cdef object drv_cfg = driver.CUlaunchConfig()
207165
cdef list attrs
208166
cdef object attr

cuda_core/cuda/core/_launcher.pyx

Lines changed: 9 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,39 +15,9 @@ from cuda.core._utils.cuda_utils cimport (
1515
check_or_create_options,
1616
HANDLE_RETURN,
1717
)
18-
19-
import threading
20-
2118
from cuda.core._module import Kernel
2219
from cuda.core._stream import Stream
23-
from cuda.core._utils.cuda_utils import (
24-
_reduce_3_tuple,
25-
get_binding_version,
26-
)
27-
28-
29-
cdef bint _inited = False
30-
cdef bint _use_ex = False
31-
cdef object _lock = threading.Lock()
32-
33-
34-
cdef int _lazy_init() except?-1:
35-
global _inited, _use_ex
36-
if _inited:
37-
return 0
38-
39-
cdef int _driver_ver
40-
with _lock:
41-
if _inited:
42-
return 0
43-
44-
# binding availability depends on cuda-python version
45-
_py_major_minor = get_binding_version()
46-
HANDLE_RETURN(cydriver.cuDriverGetVersion(&_driver_ver))
47-
_use_ex = (_driver_ver >= 11080) and (_py_major_minor >= (11, 8))
48-
_inited = True
49-
50-
return 0
20+
from math import prod
5121

5222

5323
def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kernel: Kernel, *kernel_args):
@@ -70,49 +40,31 @@ def launch(stream: Stream | GraphBuilder | IsStreamT, config: LaunchConfig, kern
7040
7141
"""
7242
cdef Stream s = Stream_accept(stream, allow_stream_protocol=True)
73-
_lazy_init()
7443
cdef LaunchConfig conf = check_or_create_options(LaunchConfig, config, "launch config")
7544

7645
# TODO: can we ensure kernel_args is valid/safe to use here?
7746
# TODO: merge with HelperKernelParams?
7847
cdef ParamHolder ker_args = ParamHolder(kernel_args)
7948
cdef void** args_ptr = <void**><uintptr_t>(ker_args.ptr)
8049

81-
# Note: We now use CUkernel handles exclusively (CUDA 12+), but they can be cast to
82-
# CUfunction for use with cuLaunchKernel, as both handle types are interchangeable
83-
# for kernel launch purposes.
8450
cdef Kernel ker = <Kernel>kernel
8551
cdef cydriver.CUfunction func_handle = <cydriver.CUfunction>as_cu(ker._h_kernel)
8652

87-
# Note: CUkernel can still be launched via cuLaunchKernel (not just cuLaunchKernelEx).
88-
# We check both binding & driver versions here mainly to see if the "Ex" API is
89-
# available and if so we use it, as it's more feature rich.
90-
if _use_ex:
91-
drv_cfg = conf._to_native_launch_config()
92-
drv_cfg.hStream = as_cu(s._h_stream)
93-
if conf.cooperative_launch:
94-
_check_cooperative_launch(kernel, conf, s)
95-
with nogil:
96-
HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL))
97-
else:
98-
# TODO: check if config has any unsupported attrs
99-
HANDLE_RETURN(
100-
cydriver.cuLaunchKernel(
101-
func_handle,
102-
conf.grid[0], conf.grid[1], conf.grid[2],
103-
conf.block[0], conf.block[1], conf.block[2],
104-
conf.shmem_size, as_cu(s._h_stream), args_ptr, NULL
105-
)
106-
)
53+
drv_cfg = conf._to_native_launch_config()
54+
drv_cfg.hStream = as_cu(s._h_stream)
55+
if conf.cooperative_launch:
56+
_check_cooperative_launch(kernel, conf, s)
57+
with nogil:
58+
HANDLE_RETURN(cydriver.cuLaunchKernelEx(&drv_cfg, func_handle, args_ptr, NULL))
10759

10860

10961
cdef _check_cooperative_launch(kernel: Kernel, config: LaunchConfig, stream: Stream):
11062
dev = stream.device
11163
num_sm = dev.properties.multiprocessor_count
11264
max_grid_size = (
113-
kernel.occupancy.max_active_blocks_per_multiprocessor(_reduce_3_tuple(config.block), config.shmem_size) * num_sm
65+
kernel.occupancy.max_active_blocks_per_multiprocessor(prod(config.block), config.shmem_size) * num_sm
11466
)
115-
if _reduce_3_tuple(config.grid) > max_grid_size:
67+
if prod(config.grid) > max_grid_size:
11668
# For now let's try not to be smart and adjust the grid size behind users' back.
11769
# We explicitly ask users to adjust.
11870
x, y, z = config.grid

cuda_core/cuda/core/_linker.pyx

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ from cuda.core._utils.cuda_utils import (
3737
CUDAError,
3838
check_or_create_options,
3939
driver,
40-
handle_return,
4140
is_sequence,
4241
)
4342

@@ -620,9 +619,8 @@ cdef inline void Linker_annotate_error_log(Linker self, object e):
620619

621620
# TODO: revisit this treatment for py313t builds
622621
_driver = None # populated if nvJitLink cannot be used
623-
_driver_ver = None
624622
_inited = False
625-
_use_nvjitlink_backend = False # set by _decide_nvjitlink_or_driver()
623+
_use_nvjitlink_backend = None # set by _decide_nvjitlink_or_driver()
626624

627625
# Input type mappings populated by _lazy_init() with C-level enum ints.
628626
_nvjitlink_input_types = None
@@ -637,13 +635,10 @@ def _nvjitlink_has_version_symbol(nvjitlink) -> bool:
637635
# Note: this function is reused in the tests
638636
def _decide_nvjitlink_or_driver() -> bool:
639637
"""Return True if falling back to the cuLink* driver APIs."""
640-
global _driver_ver, _driver, _use_nvjitlink_backend
641-
if _driver_ver is not None:
638+
global _driver, _use_nvjitlink_backend
639+
if _use_nvjitlink_backend is not None:
642640
return not _use_nvjitlink_backend
643641

644-
_driver_ver = handle_return(driver.cuDriverGetVersion())
645-
_driver_ver = (_driver_ver // 1000, (_driver_ver % 1000) // 10)
646-
647642
warn_txt_common = (
648643
"the driver APIs will be used instead, which do not support"
649644
" minor version compatibility or linking LTO IRs."
@@ -668,6 +663,7 @@ def _decide_nvjitlink_or_driver() -> bool:
668663
)
669664

670665
warn(warn_txt, stacklevel=2, category=RuntimeWarning)
666+
_use_nvjitlink_backend = False
671667
_driver = driver
672668
return True
673669

cuda_core/cuda/core/_memory/_virtual_memory_resource.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
Transaction,
1717
check_or_create_options,
1818
driver,
19-
get_binding_version,
2019
)
2120
from cuda.core._utils.cuda_utils import (
2221
_check_driver_error as raise_if_driver_error,
2322
)
23+
from cuda.core._utils.version import binding_version
2424

2525
__all__ = ["VirtualMemoryResource", "VirtualMemoryResourceOptions"]
2626

@@ -99,8 +99,7 @@ class VirtualMemoryResourceOptions:
9999
_t = driver.CUmemAllocationType
100100
# CUDA 13+ exposes MANAGED in CUmemAllocationType; older 12.x does not
101101
_allocation_type = {"pinned": _t.CU_MEM_ALLOCATION_TYPE_PINNED} # noqa: RUF012
102-
ver_major, ver_minor = get_binding_version()
103-
if ver_major >= 13:
102+
if binding_version() >= (13, 0, 0):
104103
_allocation_type["managed"] = _t.CU_MEM_ALLOCATION_TYPE_MANAGED
105104

106105
@staticmethod

Comments (0)

0 commit comments