Skip to content

Commit edbc361

Browse files
committed
Add CPU callbacks for stream capture (GraphBuilder.callback)
Implements #1328: host callbacks during stream capture via cuLaunchHostFunc, mirroring the existing GraphDef.callback API. Extracts shared callback infrastructure (_attach_user_object, _attach_host_callback_to_graph, trampoline/destructor) into a new _graph/_utils.pyx module to avoid circular imports between _graph_builder and _graphdef. Made-with: Cursor
1 parent 9b19dbf commit edbc361

5 files changed

Lines changed: 230 additions & 77 deletions

File tree

cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import weakref
66
from dataclasses import dataclass
77

8+
from cuda.bindings cimport cydriver
9+
10+
from cuda.core._graph._utils cimport _attach_host_callback_to_graph
11+
from cuda.core._resource_handles cimport as_cu
812
from cuda.core._stream cimport Stream
13+
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
914
from cuda.core._utils.cuda_utils import (
1015
driver,
1116
get_binding_version,
@@ -682,6 +687,57 @@ class GraphBuilder:
682687
)
683688
)
684689

690+
def callback(self, fn, *, user_data=None):
    """Enqueue a host (CPU) callback on the stream currently being captured.

    The callback becomes part of the graph under construction and is
    invoked on the host each time graph execution reaches this point.
    ``fn`` may be one of:

    - **A Python callable** taking no arguments. It is invoked through
      a trampoline that acquires the GIL first. Bind any state with a
      closure or ``functools.partial``.
    - **A ``ctypes.CFUNCTYPE`` instance** whose C signature accepts a
      single ``void*``; that argument is ``user_data``. The helper that
      wires up the callback takes a reference to the wrapper and hands
      it to the graph as a user object.

    .. warning::

        Callbacks must not call CUDA API functions. Doing so may
        deadlock or corrupt driver state.

    Parameters
    ----------
    fn : callable or ctypes function pointer
        The function to run on the host.
    user_data : int or bytes-like, optional
        Accepted only together with a ctypes function pointer. An
        ``int`` is forwarded verbatim as a raw pointer and the caller
        manages whatever it points to. Any other value is converted
        with ``bytes(...)``, copied, and the copy's lifetime is tied to
        the graph.

    Raises
    ------
    RuntimeError
        If the builder's stream is not actively capturing.
    ValueError
        If ``user_data`` is given with a plain Python callable.
    """
    cdef Stream capture_stream = <Stream>self._mnff.stream
    cdef cydriver.CUstream raw_stream = as_cu(capture_stream._h_stream)
    cdef cydriver.CUstreamCaptureStatus status
    cdef cydriver.CUgraph capturing_graph = NULL
    cdef cydriver.CUhostFn host_fn
    cdef void* host_arg = NULL

    # Ask the driver which graph (if any) the stream is capturing into.
    # The query's arity depends on the CUDA major version built against.
    with nogil:
        IF CUDA_CORE_BUILD_MAJOR >= 13:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                raw_stream, &status, NULL, &capturing_graph,
                NULL, NULL, NULL))
        ELSE:
            HANDLE_RETURN(cydriver.cuStreamGetCaptureInfo(
                raw_stream, &status, NULL, &capturing_graph,
                NULL, NULL))

    if status != cydriver.CU_STREAM_CAPTURE_STATUS_ACTIVE:
        raise RuntimeError("Cannot add callback when graph is not being built")

    # Resolve ``fn``/``user_data`` into a C function/argument pair and
    # tie the lifetime of anything they reference to the capturing graph.
    _attach_host_callback_to_graph(
        capturing_graph, fn, user_data, &host_fn, &host_arg)

    # Launching on a capturing stream records the host function in the
    # graph rather than running it now.
    with nogil:
        HANDLE_RETURN(cydriver.cuLaunchHostFunc(raw_stream, host_fn, host_arg))
740+
685741

686742
class Graph:
687743
"""Represents an executable graph.

cuda_core/cuda/core/_graph/_graphdef.pyx

Lines changed: 10 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ GraphNode hierarchy:
3030

3131
from __future__ import annotations
3232

33-
from cpython.ref cimport Py_INCREF
34-
3533
from libc.stddef cimport size_t
3634
from libc.stdint cimport uintptr_t
3735
from libc.stdlib cimport malloc, free
@@ -102,16 +100,11 @@ cdef bint _check_node_get_params():
102100
return _has_cuGraphNodeGetParams
103101

104102

105-
cdef extern from "Python.h":
106-
void _py_decref "Py_DECREF" (void*)
107-
108-
109-
cdef void _py_host_trampoline(void* data) noexcept with gil:
110-
(<object>data)()
111-
112-
113-
cdef void _py_host_destructor(void* data) noexcept with gil:
114-
_py_decref(data)
103+
from cuda.core._graph._utils cimport (
104+
_attach_host_callback_to_graph,
105+
_attach_user_object,
106+
_is_py_host_trampoline,
107+
)
115108

116109

117110
cdef void _destroy_event_handle_copy(void* ptr) noexcept nogil:
@@ -124,30 +117,6 @@ cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil:
124117
del p
125118

126119

127-
cdef void _attach_user_object(
128-
cydriver.CUgraph graph, void* ptr,
129-
cydriver.CUhostFn destroy) except *:
130-
"""Create a CUDA user object and transfer ownership to the graph.
131-
132-
On success the graph owns the resource (via MOVE semantics).
133-
On failure the destroy callback is invoked to clean up ptr,
134-
then a CUDAError is raised — callers need no try/except.
135-
"""
136-
cdef cydriver.CUuserObject user_obj = NULL
137-
cdef cydriver.CUresult ret
138-
with nogil:
139-
ret = cydriver.cuUserObjectCreate(
140-
&user_obj, ptr, destroy, 1,
141-
cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC)
142-
if ret == cydriver.CUDA_SUCCESS:
143-
ret = cydriver.cuGraphRetainUserObject(
144-
graph, user_obj, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE)
145-
if ret != cydriver.CUDA_SUCCESS:
146-
cydriver.cuUserObjectRelease(user_obj, 1)
147-
if ret != cydriver.CUDA_SUCCESS:
148-
if user_obj == NULL:
149-
destroy(ptr)
150-
HANDLE_RETURN(ret)
151120

152121

153122
cdef class Condition:
@@ -1270,56 +1239,20 @@ cdef class GraphNode:
12701239
cdef cydriver.CUgraphNode pred_node = as_cu(self._h_node)
12711240
cdef cydriver.CUgraphNode* deps = NULL
12721241
cdef size_t num_deps = 0
1273-
cdef void* c_user_data = NULL
1274-
cdef object callable_obj = None
1275-
cdef void* fn_pyobj = NULL
12761242

12771243
if pred_node != NULL:
12781244
deps = &pred_node
12791245
num_deps = 1
12801246

1281-
if isinstance(fn, ct._CFuncPtr):
1282-
Py_INCREF(fn)
1283-
fn_pyobj = <void*>fn
1284-
_attach_user_object(
1285-
as_cu(h_graph), fn_pyobj,
1286-
<cydriver.CUhostFn>_py_host_destructor)
1287-
node_params.fn = <cydriver.CUhostFn><uintptr_t>ct.cast(
1288-
fn, ct.c_void_p).value
1289-
1290-
if user_data is not None:
1291-
if isinstance(user_data, int):
1292-
c_user_data = <void*><uintptr_t>user_data
1293-
else:
1294-
buf = bytes(user_data)
1295-
c_user_data = malloc(len(buf))
1296-
if c_user_data == NULL:
1297-
raise MemoryError(
1298-
"failed to allocate user_data buffer")
1299-
c_memcpy(c_user_data, <const char*>buf, len(buf))
1300-
_attach_user_object(
1301-
as_cu(h_graph), c_user_data,
1302-
<cydriver.CUhostFn>free)
1303-
1304-
node_params.userData = c_user_data
1305-
else:
1306-
if user_data is not None:
1307-
raise ValueError(
1308-
"user_data is only supported with ctypes "
1309-
"function pointers")
1310-
callable_obj = fn
1311-
Py_INCREF(fn)
1312-
fn_pyobj = <void*>fn
1313-
node_params.fn = <cydriver.CUhostFn>_py_host_trampoline
1314-
node_params.userData = fn_pyobj
1315-
_attach_user_object(
1316-
as_cu(h_graph), fn_pyobj,
1317-
<cydriver.CUhostFn>_py_host_destructor)
1247+
_attach_host_callback_to_graph(
1248+
as_cu(h_graph), fn, user_data,
1249+
&node_params.fn, &node_params.userData)
13181250

13191251
with nogil:
13201252
HANDLE_RETURN(cydriver.cuGraphAddHostNode(
13211253
&new_node, as_cu(h_graph), deps, num_deps, &node_params))
13221254

1255+
cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
13231256
self._succ_cache = None
13241257
return HostCallbackNode._create_with_params(
13251258
create_graph_node_handle(new_node, h_graph), callable_obj,
@@ -1947,7 +1880,7 @@ cdef class HostCallbackNode(GraphNode):
19471880
HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, &params))
19481881

19491882
cdef object callable_obj = None
1950-
if params.fn == <cydriver.CUhostFn>_py_host_trampoline:
1883+
if _is_py_host_trampoline(params.fn):
19511884
callable_obj = <object>params.userData
19521885

19531886
return HostCallbackNode._create_with_params(
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.bindings cimport cydriver
6+
7+
8+
# True when ``fn`` is this module's GIL-acquiring trampoline, i.e. the host
# function installed for plain Python callables.
cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil
9+
10+
# Wrap ``ptr`` in a CUDA user object with ``destroy`` as its destructor and
# move the reference into ``graph``. Raises on driver failure.
cdef void _attach_user_object(
11+
cydriver.CUgraph graph, void* ptr,
12+
cydriver.CUhostFn destroy) except *
13+
14+
# Turn a Python callable or ctypes function pointer (plus optional
# ``user_data``) into the C function/argument pair written to ``out_fn`` and
# ``out_user_data``, attaching the referenced objects to ``graph``.
cdef void _attach_host_callback_to_graph(
15+
cydriver.CUgraph graph, object fn, object user_data,
16+
cydriver.CUhostFn* out_fn, void** out_user_data) except *
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cpython.ref cimport Py_INCREF
6+
7+
from libc.stdint cimport uintptr_t
8+
from libc.stdlib cimport malloc, free
9+
from libc.string cimport memcpy as c_memcpy
10+
11+
from cuda.bindings cimport cydriver
12+
13+
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
14+
15+
16+
# Expose Py_DECREF under a private name with a ``void*`` parameter so it can
# be called directly on the raw pointer a destructor callback receives.
cdef extern from "Python.h":
17+
void _py_decref "Py_DECREF" (void*)
18+
19+
20+
cdef void _py_host_trampoline(void* data) noexcept with gil:
    """Host-function entry point for Python callables.

    ``data`` is the callable itself, passed as an opaque pointer. The GIL
    is acquired on entry (``with gil``) and the callable is invoked with
    no arguments.
    """
    cdef object target = <object>data
    target()
22+
23+
24+
# Destructor callback for Python objects handed to a graph as user objects.
# Acquires the GIL (``with gil``) before touching the reference count.
cdef void _py_host_destructor(void* data) noexcept with gil:
25+
# Release the reference that was taken with Py_INCREF when the object was
# attached to the graph.
_py_decref(data)
26+
27+
28+
cdef bint _is_py_host_trampoline(cydriver.CUhostFn fn) noexcept nogil:
    """Return whether ``fn`` is the trampoline used for Python callables.

    Lets other modules recognise a host node created from a Python
    callable without importing the trampoline symbol itself.
    """
    cdef cydriver.CUhostFn trampoline = <cydriver.CUhostFn>_py_host_trampoline
    return fn == trampoline
30+
31+
32+
cdef void _attach_user_object(
        cydriver.CUgraph graph, void* ptr,
        cydriver.CUhostFn destroy) except *:
    """Hand ownership of ``ptr`` to ``graph`` via a CUDA user object.

    A user object wrapping ``ptr`` is created with ``destroy`` as its
    destructor and an initial reference count of one; that single
    reference is then moved into ``graph``.

    Failure handling, so that callers need no cleanup of their own:

    - If the user object could not be created, ``destroy(ptr)`` is
      called directly.
    - If the graph refused the reference, the reference is released
      (presumably causing the driver to run ``destroy``; confirm against
      the CUDA user-object documentation).

    In both cases the driver error is then raised by ``HANDLE_RETURN``.
    """
    cdef cydriver.CUuserObject handle = NULL
    cdef cydriver.CUresult status

    with nogil:
        status = cydriver.cuUserObjectCreate(
            &handle, ptr, destroy, 1,
            cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC)
        if status == cydriver.CUDA_SUCCESS:
            status = cydriver.cuGraphRetainUserObject(
                graph, handle, 1, cydriver.CU_GRAPH_USER_OBJECT_MOVE)
            if status != cydriver.CUDA_SUCCESS:
                # The graph did not take our reference; drop it ourselves.
                cydriver.cuUserObjectRelease(handle, 1)

    # No user object was ever created, so nothing else will clean up ptr.
    if status != cydriver.CUDA_SUCCESS and handle == NULL:
        destroy(ptr)

    HANDLE_RETURN(status)
56+
57+
58+
cdef void _attach_host_callback_to_graph(
        cydriver.CUgraph graph, object fn, object user_data,
        cydriver.CUhostFn* out_fn, void** out_user_data) except *:
    """Resolve a Python callable or ctypes CFuncPtr into a C callback pair.

    On successful return ``out_fn[0]`` and ``out_user_data[0]`` are ready
    to pass to ``cuGraphAddHostNode`` or ``cuLaunchHostFunc``.

    Two kinds of ``fn`` are handled:

    - ``ctypes`` function pointer: ``out_fn[0]`` is the pointer's raw
      address. ``user_data`` may be ``None`` (NULL is passed), an ``int``
      (used verbatim as a pointer; the caller manages its lifetime), or
      anything accepted by ``bytes(...)`` (copied into a malloc'd buffer
      that the graph frees).
    - Any other callable: ``out_fn[0]`` is the GIL-acquiring trampoline
      and ``out_user_data[0]`` is the callable itself.

    In both cases a reference to ``fn`` is taken and attached to ``graph``
    as a user object, so the graph keeps it alive.

    Raises
    ------
    ValueError
        If ``user_data`` is given together with a non-ctypes callable.
    MemoryError
        If the ``user_data`` copy cannot be allocated.
    """
    import ctypes as ct

    cdef void* fn_pyobj = NULL
    cdef void* data_copy = NULL
    cdef size_t nbytes = 0

    if isinstance(fn, ct._CFuncPtr):
        # The graph owns one reference to the ctypes wrapper so the native
        # thunk it points at stays valid.
        Py_INCREF(fn)
        fn_pyobj = <void*>fn
        _attach_user_object(
            graph, fn_pyobj,
            <cydriver.CUhostFn>_py_host_destructor)
        out_fn[0] = <cydriver.CUhostFn><uintptr_t>ct.cast(
            fn, ct.c_void_p).value

        if user_data is None:
            out_user_data[0] = NULL
        elif isinstance(user_data, int):
            # Raw address supplied by the caller; nothing to own here.
            out_user_data[0] = <void*><uintptr_t>user_data
        else:
            buf = bytes(user_data)
            nbytes = len(buf)
            # malloc(0) is allowed to return NULL, which would be
            # misreported as an allocation failure for empty input, so
            # always request at least one byte.
            data_copy = malloc(nbytes if nbytes > 0 else 1)
            if data_copy == NULL:
                raise MemoryError(
                    "failed to allocate user_data buffer")
            if nbytes > 0:
                c_memcpy(data_copy, <const char*>buf, nbytes)
            # On failure the helper frees data_copy and raises, so the
            # output pointer is only published once ownership has moved.
            _attach_user_object(
                graph, data_copy,
                <cydriver.CUhostFn>free)
            out_user_data[0] = data_copy
    else:
        if user_data is not None:
            raise ValueError(
                "user_data is only supported with ctypes "
                "function pointers")
        # The callable itself is the trampoline's argument; the graph owns
        # one reference to it.
        Py_INCREF(fn)
        fn_pyobj = <void*>fn
        out_fn[0] = <cydriver.CUhostFn>_py_host_trampoline
        out_user_data[0] = fn_pyobj
        _attach_user_object(
            graph, fn_pyobj,
            <cydriver.CUhostFn>_py_host_destructor)

cuda_core/tests/graph/test_basic.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,45 @@ def test_graph_capture_errors(init_cuda):
163163
with pytest.raises(RuntimeError, match="^Graph has not finished building."):
164164
gb.complete()
165165
gb.end_building().complete()
166+
167+
168+
def test_graph_capture_callback_python(init_cuda):
    """A Python callable recorded during capture runs once per graph launch."""
    calls = []

    def record():
        calls.append(42)

    stream = Device().create_stream()
    builder = stream.create_graph_builder().begin_building()

    # user_data is meaningful only for ctypes function pointers.
    with pytest.raises(ValueError, match="user_data is only supported"):
        builder.callback(record, user_data=b"hello")

    builder.callback(record)
    finished = builder.end_building()
    graph = finished.complete()

    graph.launch(stream)
    stream.sync()

    assert calls == [42]
187+
188+
189+
def test_graph_capture_callback_ctypes(init_cuda):
    """A ctypes callback recorded during capture receives the copied user_data."""
    import ctypes

    observed = [0]

    def _read_first_byte(ptr):
        observed[0] = ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8))[0]

    host_fn_type = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
    # Keep the wrapper bound to a local for the duration of the test.
    read_byte = host_fn_type(_read_first_byte)

    stream = Device().create_stream()
    builder = stream.create_graph_builder().begin_building()
    builder.callback(read_byte, user_data=b"\xab")
    finished = builder.end_building()
    graph = finished.complete()

    graph.launch(stream)
    stream.sync()

    assert observed[0] == 0xAB

0 commit comments

Comments
 (0)