@@ -30,8 +30,6 @@ GraphNode hierarchy:
3030
3131from __future__ import annotations
3232
33- from cpython.ref cimport Py_INCREF
34-
3533from libc.stddef cimport size_t
3634from libc.stdint cimport uintptr_t
3735from libc.stdlib cimport malloc, free
@@ -102,16 +100,11 @@ cdef bint _check_node_get_params():
102100 return _has_cuGraphNodeGetParams
103101
104102
105- cdef extern from " Python.h" :
106- void _py_decref " Py_DECREF" (void * )
107-
108-
109- cdef void _py_host_trampoline(void * data) noexcept with gil:
110- (< object > data)()
111-
112-
113- cdef void _py_host_destructor(void * data) noexcept with gil:
114- _py_decref(data)
103+ from cuda.core._graph._utils cimport (
104+ _attach_host_callback_to_graph,
105+ _attach_user_object,
106+ _is_py_host_trampoline,
107+ )
115108
116109
117110cdef void _destroy_event_handle_copy(void * ptr) noexcept nogil:
@@ -124,30 +117,6 @@ cdef void _destroy_kernel_handle_copy(void* ptr) noexcept nogil:
124117 del p
125118
126119
127- cdef void _attach_user_object(
128- cydriver.CUgraph graph, void * ptr,
129- cydriver.CUhostFn destroy) except * :
130- """ Create a CUDA user object and transfer ownership to the graph.
131-
132- On success the graph owns the resource (via MOVE semantics).
133- On failure the destroy callback is invoked to clean up ptr,
134- then a CUDAError is raised — callers need no try/except.
135- """
136- cdef cydriver.CUuserObject user_obj = NULL
137- cdef cydriver.CUresult ret
138- with nogil:
139- ret = cydriver.cuUserObjectCreate(
140- & user_obj, ptr, destroy, 1 ,
141- cydriver.CU_USER_OBJECT_NO_DESTRUCTOR_SYNC)
142- if ret == cydriver.CUDA_SUCCESS:
143- ret = cydriver.cuGraphRetainUserObject(
144- graph, user_obj, 1 , cydriver.CU_GRAPH_USER_OBJECT_MOVE)
145- if ret != cydriver.CUDA_SUCCESS:
146- cydriver.cuUserObjectRelease(user_obj, 1 )
147- if ret != cydriver.CUDA_SUCCESS:
148- if user_obj == NULL :
149- destroy(ptr)
150- HANDLE_RETURN(ret)
151120
152121
153122cdef class Condition:
@@ -1270,56 +1239,20 @@ cdef class GraphNode:
12701239 cdef cydriver.CUgraphNode pred_node = as_cu(self ._h_node)
12711240 cdef cydriver.CUgraphNode* deps = NULL
12721241 cdef size_t num_deps = 0
1273- cdef void* c_user_data = NULL
1274- cdef object callable_obj = None
1275- cdef void* fn_pyobj = NULL
12761242
12771243 if pred_node != NULL:
12781244 deps = & pred_node
12791245 num_deps = 1
12801246
1281- if isinstance(fn , ct._CFuncPtr ):
1282- Py_INCREF(fn)
1283- fn_pyobj = < void * > fn
1284- _attach_user_object(
1285- as_cu(h_graph), fn_pyobj,
1286- < cydriver.CUhostFn> _py_host_destructor)
1287- node_params.fn = < cydriver.CUhostFn>< uintptr_t> ct.cast(
1288- fn, ct.c_void_p).value
1289-
1290- if user_data is not None :
1291- if isinstance (user_data, int ):
1292- c_user_data = < void * >< uintptr_t> user_data
1293- else :
1294- buf = bytes(user_data)
1295- c_user_data = malloc(len (buf))
1296- if c_user_data == NULL :
1297- raise MemoryError (
1298- " failed to allocate user_data buffer" )
1299- c_memcpy(c_user_data, < const char * > buf, len (buf))
1300- _attach_user_object(
1301- as_cu(h_graph), c_user_data,
1302- < cydriver.CUhostFn> free)
1303-
1304- node_params.userData = c_user_data
1305- else :
1306- if user_data is not None :
1307- raise ValueError (
1308- " user_data is only supported with ctypes "
1309- " function pointers" )
1310- callable_obj = fn
1311- Py_INCREF(fn)
1312- fn_pyobj = < void * > fn
1313- node_params.fn = < cydriver.CUhostFn> _py_host_trampoline
1314- node_params.userData = fn_pyobj
1315- _attach_user_object(
1316- as_cu(h_graph), fn_pyobj,
1317- < cydriver.CUhostFn> _py_host_destructor)
1247+ _attach_host_callback_to_graph(
1248+ as_cu(h_graph ), fn , user_data ,
1249+ &node_params.fn , &node_params.userData )
13181250
13191251 with nogil:
13201252 HANDLE_RETURN(cydriver.cuGraphAddHostNode(
13211253 &new_node , as_cu(h_graph ), deps , num_deps , &node_params ))
13221254
1255+ cdef object callable_obj = fn if not isinstance (fn, ct._CFuncPtr) else None
13231256 self._succ_cache = None
13241257 return HostCallbackNode._create_with_params(
13251258 create_graph_node_handle(new_node , h_graph ), callable_obj ,
@@ -1947,7 +1880,7 @@ cdef class HostCallbackNode(GraphNode):
19471880 HANDLE_RETURN(cydriver.cuGraphHostNodeGetParams(node, & params))
19481881
19491882 cdef object callable_obj = None
1950- if params.fn == < cydriver.CUhostFn > _py_host_trampoline :
1883+ if _is_py_host_trampoline( params.fn) :
19511884 callable_obj = < object > params.userData
19521885
19531886 return HostCallbackNode._create_with_params(
0 commit comments