Skip to content

Commit 811ec27

Browse files
authored
Add explicit CUDA graph construction API (GraphDef, GraphNode) (#1772)
* Add explicit CUDA graph construction API (GraphDef, GraphNode) Introduces GraphDef and GraphNode types for explicit CUDA graph construction, with a full node hierarchy, shared instantiation helper with GraphCompleteOptions support, and comprehensive tests. Made-with: Cursor * Skip alloc_managed tests on devices without managed memory pool support Made-with: Cursor
1 parent 3f92376 commit 811ec27

16 files changed

+5475
-141
lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ decltype(&cuLibraryLoadData) p_cuLibraryLoadData = nullptr;
5656
decltype(&cuLibraryUnload) p_cuLibraryUnload = nullptr;
5757
decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel = nullptr;
5858

59+
// Graph
60+
decltype(&cuGraphDestroy) p_cuGraphDestroy = nullptr;
61+
5962
// Linker
6063
decltype(&cuLinkDestroy) p_cuLinkDestroy = nullptr;
6164

@@ -919,6 +922,57 @@ LibraryHandle get_kernel_library(const KernelHandle& h) noexcept {
919922
return get_box(h)->h_library;
920923
}
921924

925+
// ============================================================================
926+
// Graph Handles
927+
// ============================================================================
928+
929+
namespace {
930+
struct GraphBox {
931+
CUgraph resource;
932+
GraphHandle h_parent; // Keeps parent alive for child/branch graphs
933+
};
934+
} // namespace
935+
936+
GraphHandle create_graph_handle(CUgraph graph) {
937+
auto box = std::shared_ptr<const GraphBox>(
938+
new GraphBox{graph, {}},
939+
[](const GraphBox* b) {
940+
GILReleaseGuard gil;
941+
p_cuGraphDestroy(b->resource);
942+
delete b;
943+
}
944+
);
945+
return GraphHandle(box, &box->resource);
946+
}
947+
948+
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent) {
949+
auto box = std::make_shared<const GraphBox>(GraphBox{graph, h_parent});
950+
return GraphHandle(box, &box->resource);
951+
}
952+
953+
namespace {
954+
struct GraphNodeBox {
955+
CUgraphNode resource;
956+
GraphHandle h_graph;
957+
};
958+
} // namespace
959+
960+
static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
961+
const CUgraphNode* p = h.get();
962+
return reinterpret_cast<const GraphNodeBox*>(
963+
reinterpret_cast<const char*>(p) - offsetof(GraphNodeBox, resource)
964+
);
965+
}
966+
967+
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
968+
auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
969+
return GraphNodeHandle(box, &box->resource);
970+
}
971+
972+
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
973+
return h ? get_box(h)->h_graph : GraphHandle{};
974+
}
975+
922976
// ============================================================================
923977
// Graphics Resource Handles
924978
// ============================================================================

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ extern decltype(&cuLibraryLoadData) p_cuLibraryLoadData;
9292
extern decltype(&cuLibraryUnload) p_cuLibraryUnload;
9393
extern decltype(&cuLibraryGetKernel) p_cuLibraryGetKernel;
9494

95+
// Graph
96+
extern decltype(&cuGraphDestroy) p_cuGraphDestroy;
97+
9598
// Linker
9699
extern decltype(&cuLinkDestroy) p_cuLinkDestroy;
97100

@@ -144,6 +147,8 @@ using EventHandle = std::shared_ptr<const CUevent>;
144147
using MemoryPoolHandle = std::shared_ptr<const CUmemoryPool>;
145148
using LibraryHandle = std::shared_ptr<const CUlibrary>;
146149
using KernelHandle = std::shared_ptr<const CUkernel>;
150+
using GraphHandle = std::shared_ptr<const CUgraph>;
151+
using GraphNodeHandle = std::shared_ptr<const CUgraphNode>;
147152
using GraphicsResourceHandle = std::shared_ptr<const CUgraphicsResource>;
148153
using NvrtcProgramHandle = std::shared_ptr<const nvrtcProgram>;
149154
using NvvmProgramHandle = std::shared_ptr<const NvvmProgramValue>;
@@ -382,6 +387,33 @@ KernelHandle create_kernel_handle_ref(CUkernel kernel);
382387
// Returns empty handle if the kernel has no library dependency.
383388
LibraryHandle get_kernel_library(const KernelHandle& h) noexcept;
384389

390+
// ============================================================================
391+
// Graph handle functions
392+
// ============================================================================
393+
394+
// Wrap an externally-created CUgraph with RAII cleanup.
395+
// When the last reference is released, cuGraphDestroy is called automatically.
396+
// The caller must have already created the graph via cuGraphCreate.
397+
GraphHandle create_graph_handle(CUgraph graph);
398+
399+
// Create a non-owning graph handle that keeps h_parent alive.
400+
// Use for graphs owned by a child/conditional node in a parent graph.
401+
// The child graph will NOT be destroyed when this handle is released,
402+
// but h_parent will be prevented from destruction while this handle exists.
403+
GraphHandle create_graph_handle_ref(CUgraph graph, const GraphHandle& h_parent);
404+
405+
// ============================================================================
406+
// Graph node handle functions
407+
// ============================================================================
408+
409+
// Create a node handle. Nodes are owned by their parent graph (not
410+
// independently destroyable). The GraphHandle dependency ensures the
411+
// graph outlives any node reference.
412+
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph);
413+
414+
// Extract the owning graph handle from a node handle.
415+
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
416+
385417
// ============================================================================
386418
// Graphics resource handle functions
387419
// ============================================================================
@@ -478,6 +510,14 @@ inline CUkernel as_cu(const KernelHandle& h) noexcept {
478510
return h ? *h : nullptr;
479511
}
480512

513+
inline CUgraph as_cu(const GraphHandle& h) noexcept {
514+
return h ? *h : nullptr;
515+
}
516+
517+
inline CUgraphNode as_cu(const GraphNodeHandle& h) noexcept {
518+
return h ? *h : nullptr;
519+
}
520+
481521
inline CUgraphicsResource as_cu(const GraphicsResourceHandle& h) noexcept {
482522
return h ? *h : nullptr;
483523
}
@@ -528,6 +568,14 @@ inline std::intptr_t as_intptr(const KernelHandle& h) noexcept {
528568
return reinterpret_cast<std::intptr_t>(as_cu(h));
529569
}
530570

571+
inline std::intptr_t as_intptr(const GraphHandle& h) noexcept {
572+
return reinterpret_cast<std::intptr_t>(as_cu(h));
573+
}
574+
575+
inline std::intptr_t as_intptr(const GraphNodeHandle& h) noexcept {
576+
return reinterpret_cast<std::intptr_t>(as_cu(h));
577+
}
578+
531579
inline std::intptr_t as_intptr(const GraphicsResourceHandle& h) noexcept {
532580
return reinterpret_cast<std::intptr_t>(as_cu(h));
533581
}
@@ -606,6 +654,17 @@ inline PyObject* as_py(const KernelHandle& h) noexcept {
606654
return detail::make_py("cuda.bindings.driver", "CUkernel", as_intptr(h));
607655
}
608656

657+
inline PyObject* as_py(const GraphHandle& h) noexcept {
658+
return detail::make_py("cuda.bindings.driver", "CUgraph", as_intptr(h));
659+
}
660+
661+
inline PyObject* as_py(const GraphNodeHandle& h) noexcept {
662+
if (!as_intptr(h)) {
663+
Py_RETURN_NONE;
664+
}
665+
return detail::make_py("cuda.bindings.driver", "CUgraphNode", as_intptr(h));
666+
}
667+
609668
inline PyObject* as_py(const NvrtcProgramHandle& h) noexcept {
610669
return detail::make_py("cuda.bindings.nvrtc", "nvrtcProgram", as_intptr(h));
611670
}

cuda_core/cuda/core/_graph/__init__.py

Lines changed: 77 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,43 @@ class GraphDebugPrintOptions:
9191
extra_topo_info: bool = False
9292
conditional_node_params: bool = False
9393

94+
def _to_flags(self) -> int:
95+
"""Convert options to CUDA driver API flags (internal use)."""
96+
flags = 0
97+
if self.verbose:
98+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
99+
if self.runtime_types:
100+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
101+
if self.kernel_node_params:
102+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
103+
if self.memcpy_node_params:
104+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
105+
if self.memset_node_params:
106+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
107+
if self.host_node_params:
108+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
109+
if self.event_node_params:
110+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
111+
if self.ext_semas_signal_node_params:
112+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
113+
if self.ext_semas_wait_node_params:
114+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
115+
if self.kernel_node_attributes:
116+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
117+
if self.handles:
118+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
119+
if self.mem_alloc_node_params:
120+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
121+
if self.mem_free_node_params:
122+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
123+
if self.batch_mem_op_node_params:
124+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
125+
if self.extra_topo_info:
126+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
127+
if self.conditional_node_params:
128+
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS
129+
return flags
130+
94131

95132
@dataclass
96133
class GraphCompleteOptions:
@@ -118,6 +155,44 @@ class GraphCompleteOptions:
118155
use_node_priority: bool = False
119156

120157

158+
def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) -> Graph:
159+
params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
160+
if options:
161+
flags = 0
162+
if options.auto_free_on_launch:
163+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
164+
if options.upload_stream:
165+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
166+
params.hUploadStream = options.upload_stream.handle
167+
if options.device_launch:
168+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
169+
if options.use_node_priority:
170+
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
171+
params.flags = flags
172+
173+
graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(h_graph, params)))
174+
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
175+
raise RuntimeError(
176+
"Instantiation failed for an unexpected reason which is described in the return value of the function."
177+
)
178+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
179+
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
180+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
181+
raise RuntimeError(
182+
"Instantiation for device launch failed because the graph contained an unsupported operation."
183+
)
184+
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
185+
raise RuntimeError("Instantiation for device launch failed due to the nodes belonging to different contexts.")
186+
elif (
187+
_py_major_minor >= (12, 8)
188+
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
189+
):
190+
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
191+
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
192+
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
193+
return graph
194+
195+
121196
class GraphBuilder:
122197
"""Represents a graph under construction.
123198
@@ -280,53 +355,7 @@ def complete(self, options: GraphCompleteOptions | None = None) -> Graph:
280355
if not self._building_ended:
281356
raise RuntimeError("Graph has not finished building.")
282357

283-
if (_driver_ver < 12000) or (_py_major_minor < (12, 0)):
284-
flags = 0
285-
if options:
286-
if options.auto_free_on_launch:
287-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
288-
if options.use_node_priority:
289-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
290-
return Graph._init(handle_return(driver.cuGraphInstantiateWithFlags(self._mnff.graph, flags)))
291-
292-
params = driver.CUDA_GRAPH_INSTANTIATE_PARAMS()
293-
if options:
294-
flags = 0
295-
if options.auto_free_on_launch:
296-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_AUTO_FREE_ON_LAUNCH
297-
if options.upload_stream:
298-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_UPLOAD
299-
params.hUploadStream = options.upload_stream.handle
300-
if options.device_launch:
301-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_DEVICE_LAUNCH
302-
if options.use_node_priority:
303-
flags |= driver.CUgraphInstantiate_flags.CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
304-
params.flags = flags
305-
306-
graph = Graph._init(handle_return(driver.cuGraphInstantiateWithParams(self._mnff.graph, params)))
307-
if params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_ERROR:
308-
# NOTE: Should never get here since the handle_return should have caught this case
309-
raise RuntimeError(
310-
"Instantiation failed for an unexpected reason which is described in the return value of the function."
311-
)
312-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_INVALID_STRUCTURE:
313-
raise RuntimeError("Instantiation failed due to invalid structure, such as cycles.")
314-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_NODE_OPERATION_NOT_SUPPORTED:
315-
raise RuntimeError(
316-
"Instantiation for device launch failed because the graph contained an unsupported operation."
317-
)
318-
elif params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_MULTIPLE_CTXS_NOT_SUPPORTED:
319-
raise RuntimeError(
320-
"Instantiation for device launch failed due to the nodes belonging to different contexts."
321-
)
322-
elif (
323-
_py_major_minor >= (12, 8)
324-
and params.result_out == driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_CONDITIONAL_HANDLE_UNUSED
325-
):
326-
raise RuntimeError("One or more conditional handles are not associated with conditional builders.")
327-
elif params.result_out != driver.CUgraphInstantiateResult.CUDA_GRAPH_INSTANTIATE_SUCCESS:
328-
raise RuntimeError(f"Graph instantiation failed with unexpected error code: {params.result_out}")
329-
return graph
358+
return _instantiate_graph(self._mnff.graph, options)
330359

331360
def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
332361
"""Generates a DOT debug file for the graph builder.
@@ -341,41 +370,7 @@ def debug_dot_print(self, path, options: GraphDebugPrintOptions | None = None):
341370
"""
342371
if not self._building_ended:
343372
raise RuntimeError("Graph has not finished building.")
344-
flags = 0
345-
if options:
346-
if options.verbose:
347-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_VERBOSE
348-
if options.runtime_types:
349-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_RUNTIME_TYPES
350-
if options.kernel_node_params:
351-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_PARAMS
352-
if options.memcpy_node_params:
353-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMCPY_NODE_PARAMS
354-
if options.memset_node_params:
355-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEMSET_NODE_PARAMS
356-
if options.host_node_params:
357-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HOST_NODE_PARAMS
358-
if options.event_node_params:
359-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EVENT_NODE_PARAMS
360-
if options.ext_semas_signal_node_params:
361-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_SIGNAL_NODE_PARAMS
362-
if options.ext_semas_wait_node_params:
363-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXT_SEMAS_WAIT_NODE_PARAMS
364-
if options.kernel_node_attributes:
365-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_KERNEL_NODE_ATTRIBUTES
366-
if options.handles:
367-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_HANDLES
368-
if options.mem_alloc_node_params:
369-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_ALLOC_NODE_PARAMS
370-
if options.mem_free_node_params:
371-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_MEM_FREE_NODE_PARAMS
372-
if options.batch_mem_op_node_params:
373-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_BATCH_MEM_OP_NODE_PARAMS
374-
if options.extra_topo_info:
375-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_EXTRA_TOPO_INFO
376-
if options.conditional_node_params:
377-
flags |= driver.CUgraphDebugDot_flags.CU_GRAPH_DEBUG_DOT_FLAGS_CONDITIONAL_NODE_PARAMS
378-
373+
flags = options._to_flags() if options else 0
379374
handle_return(driver.cuGraphDebugDotPrint(self._mnff.graph, path, flags))
380375

381376
def split(self, count: int) -> tuple[GraphBuilder, ...]:

0 commit comments

Comments
 (0)