Skip to content

Commit 613be11

Browse files
authored
Add CPU callbacks for stream capture & Cythonize GraphBuilder (#1814)
* Cythonize _graph/_graph_builder (move from pure Python to .pyx) Move the GraphBuilder/Graph/GraphCompleteOptions/GraphDebugPrintOptions implementation out of _graph/__init__.py into _graph/_graph_builder.pyx so it is compiled by Cython. A thin __init__.py re-exports the public names so all existing import sites continue to work unchanged. Cython compatibility adjustments: - Remove `from __future__ import annotations` (unsupported by Cython) - Remove TYPE_CHECKING guard; quote annotations that reference Stream (circular import), forward-reference GraphBuilder/Graph, or use X | None union syntax - Update _graphdef.pyx lazy imports to point directly at _graph_builder No build_hooks.py changes needed — the build system auto-discovers .pyx files via glob. Ref: #1076 Made-with: Cursor * Remove _lazy_init from _graph_builder; add cached get_driver_version Replace the per-module _lazy_init / _inited / _driver_ver / _py_major_minor pattern in _graph_builder.pyx with direct calls to centralized cached functions in cuda_utils: - Add get_driver_version() with @functools.cache alongside get_binding_version - Switch get_binding_version from @functools.lru_cache to @functools.cache (cleaner for nullary functions) - Fix split() to return tuple(result) — Cython enforces return type annotations unlike pure Python - Fix _cond_with_params annotation from -> GraphBuilder to -> tuple to match actual return value Made-with: Cursor * Add CPU callbacks for stream capture (GraphBuilder.callback) Implements #1328: host callbacks during stream capture via cuLaunchHostFunc, mirroring the existing GraphDef.callback API. Extracts shared callback infrastructure (_attach_user_object, _attach_host_callback_to_graph, trampoline/destructor) into a new _graph/_utils.pyx module to avoid circular imports between _graph_builder and _graphdef. Made-with: Cursor
1 parent fc5babd commit 613be11

File tree

7 files changed

+1026
-871
lines changed

7 files changed

+1026
-871
lines changed

cuda_core/cuda/core/_graph/__init__.py

Lines changed: 14 additions & 791 deletions
Large diffs are not rendered by default.

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 831 additions & 0 deletions
Large diffs are not rendered by default.

cuda_core/cuda/core/_graph/_graphdef.pyx

Lines changed: 12 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ GraphNode hierarchy:
3030

3131
from __future__ import annotations
3232

33-
from cpython.ref cimport Py_INCREF
34-
3533
from libc.stddef cimport size_t
3634
from libc.stdint cimport uintptr_t
3735
from libc.stdlib cimport malloc, free
@@ -102,16 +100,11 @@ cdef bint _check_node_get_params():
102100
return _has_cuGraphNodeGetParams
103101

104102

105-
cdef extern from "Python.h":
106-
void _py_decref "Py_DECREF" (void*)
107-
108-
109-
cdef void _py_host_trampoline(void* data) noexcept with gil:
110-
(<object>data)()
111-
112-
113-
cdef void _py_host_destructor(void* data) noexcept with gil:
114-
_py_decref(data)
103+
from cuda.core._graph._utils cimport (
104+
_attach_host_callback_to_graph,
105+
_attach_user_object,
106+
_is_py_host_trampoline,
107+
)
115108

116109

117110
cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil:
@@ -124,30 +117,6 @@ cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil:
124117
del p
125118

126119

127-
cdef void _attach_user_object(
128-
cydriver.CUgraph graph, void* ptr,
129-
cydriver.CUhostFn destroy) except *:
130-
"""Create a CUDA user object and transfer ownership to the graph.
131-
132-
On success the graph owns the resource (via MOVE semantics).
133-
On failure the destroy callback is invoked to clean up ptr,
134-
then a CUDAError is raised — callers need no try/except.
135-
"""
136-
cdef cydriver.CUuserObject user_obj = NULL
137-
cdef cydriver.CUresult ret
138-
with nogil:
139-
ret = cydriver.cuUserObjectCreate(
140-
&user_obj, ptr, destroy, 1,
141-
cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC)
142-
if ret == cydriver.CUDA_SUCCESS:
143-
ret = cydriver.cuGraphRetainUserObject(
144-
graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE)
145-
if ret != cydriver.CUDA_SUCCESS:
146-
cydriver.cuUserObjectRelease(user_obj, 1)
147-
if ret != cydriver.CUDA_SUCCESS:
148-
if user_obj == NULL:
149-
destroy(ptr)
150-
HANDLE_RETURN(ret)
151120

152121

153122
cdef class Condition:
@@ -470,7 +439,7 @@ cdef class GraphDef:
470439
Graph
471440
An executable graph that can be launched on a stream.
472441
"""
473-
from cuda.core._graph import _instantiate_graph
442+
from cuda.core._graph._graph_builder import _instantiate_graph
474443

475444
return _instantiate_graph(
476445
driver.CUgraph(as_intptr(self._h_graph)), options)
@@ -485,7 +454,7 @@ cdef class GraphDef:
485454
options : GraphDebugPrintOptions, optional
486455
Customizable options for the debug print.
487456
"""
488-
from cuda.core._graph import GraphDebugPrintOptions
457+
from cuda.core._graph._graph_builder import GraphDebugPrintOptions
489458

490459
cdef unsigned int flags = 0
491460
if options is not None:
@@ -1270,56 +1239,20 @@ cdef class GraphNode:
12701239
cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
12711240
cdef cydriver.CUgraphNode* deps = NULL
12721241
cdef size_t num_deps = 0
1273-
cdef void* c_user_data = NULL
1274-
cdef object callable_obj = None
1275-
cdef void* fn_pyobj = NULL
12761242

12771243
if pred_node != NULL:
12781244
deps = &pred_node
12791245
num_deps = 1
12801246

1281-
if isinstance(fn, ct._CFuncPtr):
1282-
Py_INCREF(fn)
1283-
fn_pyobj = <void*>fn
1284-
_attach_user_object(
1285-
as_cu(h_graph), fn_pyobj,
1286-
<cydriver.CUhostFn>_py_host_destructor)
1287-
node_params.fn = <cydriver.CUhostFn><uintptr_t>ct.cast(
1288-
fn, ct.c_void_p).value
1289-
1290-
if user_data is not None:
1291-
if isinstance(user_data, int):
1292-
c_user_data = <void*><uintptr_t>user_data
1293-
else:
1294-
buf = bytes(user_data)
1295-
c_user_data = malloc(len(buf))
1296-
if c_user_data == NULL:
1297-
raise MemoryError(
1298-
"failed to allocate user_data buffer")
1299-
c_memcpy(c_user_data, <const char*>buf, len(buf))
1300-
_attach_user_object(
1301-
as_cu(h_graph), c_user_data,
1302-
<cydriver.CUhostFn>free)
1303-
1304-
node_params.userData = c_user_data
1305-
else:
1306-
if user_data is not None:
1307-
raise ValueError(
1308-
"user_data is only supported with ctypes "
1309-
"function pointers")
1310-
callable_obj = fn
1311-
Py_INCREF(fn)
1312-
fn_pyobj = <void*>fn
1313-
node_params.fn = <cydriver.CUhostFn>_py_host_trampoline
1314-
node_params.userData = fn_pyobj
1315-
_attach_user_object(
1316-
as_cu(h_graph), fn_pyobj,
1317-
<cydriver.CUhostFn>_py_host_destructor)
1247+
_attach_host_callback_to_graph(
1248+
as_cu(h_graph), fn, user_data,
1249+
&node_params.fn, &node_params.userData)
13181250

13191251
with nogil:
13201252
HANDLE_RETURN(cydriver.cuGraphAddHostNode(
13211253
&new_node, as_cu(h_graph), deps, num_deps, &node_params))
13221254

1255+
cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
13231256
self._succ_cache = None
13241257
return HostCallbackNode._create_with_params(
13251258
create_graph_node_handle(new_node, h_graph), callable_obj,
@@ -1947,7 +1880,7 @@ cdef class HostCallbackNode(GraphNode):
19471880
HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, &params))
19481881

19491882
cdef object callable_obj = None
1950-
if params.fn == <cydriver.CUhostFn>_py_host_trampoline:
1883+
if _is_py_host_trampoline(params.fn):
19511884
callable_obj = <object>params.userData
19521885

19531886
return HostCallbackNode._create_with_params(
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.bindings cimport cydriver
6+
7+
8+
cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil
9+
10+
cdef void _attach_user_object(
11+
cydriver.CUgraph graph, void* ptr,
12+
cydriver.CUhostFn destroy) except *
13+
14+
cdef void _attach_host_callback_to_graph(
15+
cydriver.CUgraph graph, object fn, object user_data,
16+
cydriver.CUhostFn* out_fn, void** out_user_data) except *
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cpython.ref cimport Py_INCREF
6+
7+
from libc.stdint cimport uintptr_t
8+
from libc.stdlib cimport malloc, free
9+
from libc.string cimport memcpy as c_memcpy
10+
11+
from cuda.bindings cimport cydriver
12+
13+
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
14+
15+
16+
cdef extern from "Python.h":
17+
void _py_decref "Py_DECREF" (void*)
18+
19+
20+
cdef void _py_host_trampoline(void* data) noexcept with gil:
21+
(<object>data)()
22+
23+
24+
cdef void _py_host_destructor(void* data) noexcept with gil:
25+
_py_decref(data)
26+
27+
28+
cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil:
29+
return fn == <cydriver.CUhostFn>_py_host_trampoline
30+
31+
32+
cdef void _attach_user_object(
33+
cydriver.CUgraph graph, void* ptr,
34+
cydriver.CUhostFn destroy) except *:
35+
"""Create a CUDA user object and transfer ownership to the graph.
36+
37+
On success the graph owns the resource (via MOVE semantics).
38+
On failure the destroy callback is invoked to clean up ptr,
39+
then a CUDAError is raised — callers need no try/except.
40+
"""
41+
cdef cydriver.CUuserObject user_obj = NULL
42+
cdef cydriver.CUresult ret
43+
with nogil:
44+
ret = cydriver.cuUserObjectCreate(
45+
&user_obj, ptr, destroy, 1,
46+
cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC)
47+
if ret == cydriver.CUDA_SUCCESS:
48+
ret = cydriver.cuGraphRetainUserObject(
49+
graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE)
50+
if ret != cydriver.CUDA_SUCCESS:
51+
cydriver.cuUserObjectRelease(user_obj, 1)
52+
if ret != cydriver.CUDA_SUCCESS:
53+
if user_obj == NULL:
54+
destroy(ptr)
55+
HANDLE_RETURN(ret)
56+
57+
58+
cdef void _attach_host_callback_to_graph(
59+
cydriver.CUgraph graph, object fn, object user_data,
60+
cydriver.CUhostFn* out_fn, void** out_user_data) except *:
61+
"""Resolve a Python callable or ctypes CFuncPtr into a C callback pair.
62+
63+
Handles Py_INCREF, user-object attachment for lifetime management,
64+
and user_data copying. On return, *out_fn and *out_user_data are
65+
ready to pass to cuGraphAddHostNode or cuLaunchHostFunc.
66+
"""
67+
import ctypes as ct
68+
69+
cdef void* fn_pyobj = NULL
70+
71+
if isinstance(fn, ct._CFuncPtr):
72+
Py_INCREF(fn)
73+
fn_pyobj = <void*>fn
74+
_attach_user_object(
75+
graph, fn_pyobj,
76+
<cydriver.CUhostFn>_py_host_destructor)
77+
out_fn[0] = <cydriver.CUhostFn><uintptr_t>ct.cast(
78+
fn, ct.c_void_p).value
79+
80+
if user_data is not None:
81+
if isinstance(user_data, int):
82+
out_user_data[0] = <void*><uintptr_t>user_data
83+
else:
84+
buf = bytes(user_data)
85+
out_user_data[0] = malloc(len(buf))
86+
if out_user_data[0] == NULL:
87+
raise MemoryError(
88+
"failed to allocate user_data buffer")
89+
c_memcpy(out_user_data[0], <const char*>buf, len(buf))
90+
_attach_user_object(
91+
graph, out_user_data[0],
92+
<cydriver.CUhostFn>free)
93+
else:
94+
out_user_data[0] = NULL
95+
else:
96+
if user_data is not None:
97+
raise ValueError(
98+
"user_data is only supported with ctypes "
99+
"function pointers")
100+
Py_INCREF(fn)
101+
fn_pyobj = <void*>fn
102+
out_fn[0] = <cydriver.CUhostFn>_py_host_trampoline
103+
out_user_data[0] = fn_pyobj
104+
_attach_user_object(
105+
graph, fn_pyobj,
106+
<cydriver.CUhostFn>_py_host_destructor)

cuda_core/cuda/core/_utils/cuda_utils.pyx

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,14 +298,18 @@ def is_nested_sequence(obj):
298298
return is_sequence(obj) and any(is_sequence(elem) for elem in obj)
299299

300300

301-
@functools.lru_cache
301+
@functools.cache
302302
def get_binding_version():
303303
try:
304304
major_minor = importlib.metadata.version("cuda-bindings").split(".")[:2]
305305
except importlib.metadata.PackageNotFoundError:
306306
major_minor = importlib.metadata.version("cuda-python").split(".")[:2]
307307
return tuple(int(v) for v in major_minor)
308308

309+
@functools.cache
310+
def get_driver_version():
311+
return handle_return(driver.cuDriverGetVersion())
312+
309313

310314
class Transaction:
311315
"""

cuda_core/tests/graph/test_basic.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,45 @@ def test_graph_capture_errors(init_cuda):
163163
with pytest.raises(RuntimeError, match="^Graph has not finished building."):
164164
gb.complete()
165165
gb.end_building().complete()
166+
167+
168+
def test_graph_capture_callback_python(init_cuda):
169+
results = []
170+
171+
def my_callback():
172+
results.append(42)
173+
174+
launch_stream = Device().create_stream()
175+
gb = launch_stream.create_graph_builder().begin_building()
176+
177+
with pytest.raises(ValueError, match="user_data is only supported"):
178+
gb.callback(my_callback, user_data=b"hello")
179+
180+
gb.callback(my_callback)
181+
graph = gb.end_building().complete()
182+
183+
graph.launch(launch_stream)
184+
launch_stream.sync()
185+
186+
assert results == [42]
187+
188+
189+
def test_graph_capture_callback_ctypes(init_cuda):
190+
import ctypes
191+
192+
CALLBACK = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
193+
result = [0]
194+
195+
@CALLBACK
196+
def read_byte(data):
197+
result[0] = ctypes.cast(data, ctypes.POINTER(ctypes.c_uint8))[0]
198+
199+
launch_stream = Device().create_stream()
200+
gb = launch_stream.create_graph_builder().begin_building()
201+
gb.callback(read_byte, user_data=bytes([0xAB]))
202+
graph = gb.end_building().complete()
203+
204+
graph.launch(launch_stream)
205+
launch_stream.sync()
206+
207+
assert result[0] == 0xAB

0 commit comments

Comments
 (0)