Skip to content

Commit dd9108e

Browse files
committed
Revert "Publish the graph API as cuda.core.graph (#1858)"
This reverts commit c6aea12.
1 parent d393729 commit dd9108e

31 files changed

+184
-238
lines changed

cuda_core/cuda/core/__init__.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -31,6 +31,12 @@
3131
from cuda.core import system, utils
3232
from cuda.core._device import Device
3333
from cuda.core._event import Event, EventOptions
34+
from cuda.core._graph import (
35+
Graph,
36+
GraphBuilder,
37+
GraphCompleteOptions,
38+
GraphDebugPrintOptions,
39+
)
3440
from cuda.core._graphics import GraphicsResource
3541
from cuda.core._launch_config import LaunchConfig
3642
from cuda.core._launcher import launch
@@ -63,12 +69,3 @@
6369
StreamOptions,
6470
)
6571
from cuda.core._tensor_map import TensorMapDescriptor, TensorMapDescriptorOptions
66-
from cuda.core.graph import (
67-
Condition,
68-
Graph,
69-
GraphAllocOptions,
70-
GraphBuilder,
71-
GraphCompleteOptions,
72-
GraphDebugPrintOptions,
73-
GraphDef,
74-
)

cuda_core/cuda/core/_device.pyx

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

@@ -23,6 +23,7 @@ from cuda.core._resource_handles cimport (
2323
as_cu,
2424
)
2525

26+
from cuda.core._graph import GraphBuilder
2627
from cuda.core._stream import IsStreamT, Stream, StreamOptions
2728
from cuda.core._utils.clear_error_support import assert_type
2829
from cuda.core._utils.cuda_utils import (
@@ -1361,17 +1362,15 @@ class Device:
13611362
self._check_context_initialized()
13621363
handle_return(runtime.cudaDeviceSynchronize())
13631364

1364-
def create_graph_builder(self) -> "GraphBuilder":
1365-
"""Create a new :obj:`~graph.GraphBuilder` object.
1365+
def create_graph_builder(self) -> GraphBuilder:
1366+
"""Create a new :obj:`~_graph.GraphBuilder` object.
13661367

13671368
Returns
13681369
-------
1369-
:obj:`~graph.GraphBuilder`
1370+
:obj:`~_graph.GraphBuilder`
13701371
Newly created graph builder object.
13711372

13721373
"""
1373-
from cuda.core.graph._graph_builder import GraphBuilder
1374-
13751374
self._check_context_initialized()
13761375
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
13771376

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from cuda.core._graph._graph_builder import (
6+
Graph,
7+
GraphBuilder,
8+
GraphCompleteOptions,
9+
GraphDebugPrintOptions,
10+
_instantiate_graph,
11+
)
12+
13+
__all__ = [
14+
"Graph",
15+
"GraphBuilder",
16+
"GraphCompleteOptions",
17+
"GraphDebugPrintOptions",
18+
"_instantiate_graph",
19+
]

cuda_core/cuda/core/graph/_graph_builder.pyx renamed to cuda_core/cuda/core/_graph/_graph_builder.pyx

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ from libc.stdint cimport intptr_t
99

1010
from cuda.bindings cimport cydriver
1111

12-
from cuda.core.graph._utils cimport _attach_host_callback_to_graph
12+
from cuda.core._graph._utils cimport _attach_host_callback_to_graph
1313
from cuda.core._resource_handles cimport as_cu
1414
from cuda.core._stream cimport Stream
1515
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN
@@ -23,7 +23,7 @@ from cuda.core._utils.cuda_utils import (
2323

2424
@dataclass
2525
class GraphDebugPrintOptions:
26-
"""Options for debug_dot_print().
26+
"""Customizable options for :obj:`_graph.GraphBuilder.debug_dot_print()`
2727
2828
Attributes
2929
----------
@@ -119,7 +119,7 @@ class GraphDebugPrintOptions:
119119

120120
@dataclass
121121
class GraphCompleteOptions:
122-
"""Options for graph instantiation.
122+
"""Customizable options for :obj:`_graph.GraphBuilder.complete()`
123123

124124
Attributes
125125
----------
@@ -182,13 +182,13 @@ def _instantiate_graph(h_graph, options: GraphCompleteOptions | None = None) ->
182182

183183

184184
class GraphBuilder:
185-
"""A graph under construction by stream capture.
185+
"""Represents a graph under construction.
186186
187187
A graph groups a set of CUDA kernels and other CUDA operations together and executes
188188
them with a specified dependency tree. It speeds up the workflow by combining the
189189
driver activities associated with CUDA kernel launches and CUDA API calls.
190190
191-
Directly creating a :obj:`~graph.GraphBuilder` is not supported due
191+
Directly creating a :obj:`~_graph.GraphBuilder` is not supported due
192192
to ambiguity. New graph builders should instead be created through a
193193
:obj:`~_device.Device`, or a :obj:`~_stream.stream` object.
194194
@@ -326,16 +326,16 @@ class GraphBuilder:
326326
return self
327327

328328
def complete(self, options: GraphCompleteOptions | None = None) -> "Graph":
329-
"""Completes the graph builder and returns the built :obj:`~graph.Graph` object.
329+
"""Completes the graph builder and returns the built :obj:`~_graph.Graph` object.
330330

331331
Parameters
332332
----------
333-
options : :obj:`~graph.GraphCompleteOptions`, optional
333+
options : :obj:`~_graph.GraphCompleteOptions`, optional
334334
Customizable dataclass for the graph builder completion options.
335335

336336
Returns
337337
-------
338-
graph : :obj:`~graph.Graph`
338+
graph : :obj:`~_graph.Graph`
339339
The newly built graph.
340340

341341
"""
@@ -351,7 +351,7 @@ class GraphBuilder:
351351
----------
352352
path : str
353353
File path to use for writting debug DOT output
354-
options : :obj:`~graph.GraphDebugPrintOptions`, optional
354+
options : :obj:`~_graph.GraphDebugPrintOptions`, optional
355355
Customizable dataclass for the debug print options.
356356
357357
"""
@@ -373,7 +373,7 @@ class GraphBuilder:
373373

374374
Returns
375375
-------
376-
graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
376+
graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
377377
A tuple of split graph builders. The first graph builder in the tuple
378378
is always the original graph builder.
379379

@@ -400,12 +400,12 @@ class GraphBuilder:
400400

401401
Parameters
402402
----------
403-
*graph_builders : :obj:`~graph.GraphBuilder`
403+
*graph_builders : :obj:`~_graph.GraphBuilder`
404404
The graph builders to join.
405405

406406
Returns
407407
-------
408-
graph_builder : :obj:`~graph.GraphBuilder`
408+
graph_builder : :obj:`~_graph.GraphBuilder`
409409
The newly joined graph builder.
410410

411411
"""
@@ -521,7 +521,7 @@ class GraphBuilder:
521521

522522
Returns
523523
-------
524-
graph_builder : :obj:`~graph.GraphBuilder`
524+
graph_builder : :obj:`~_graph.GraphBuilder`
525525
The newly created conditional graph builder.
526526

527527
"""
@@ -552,7 +552,7 @@ class GraphBuilder:
552552

553553
Returns
554554
-------
555-
graph_builders : tuple[:obj:`~graph.GraphBuilder`, :obj:`~graph.GraphBuilder`]
555+
graph_builders : tuple[:obj:`~_graph.GraphBuilder`, :obj:`~_graph.GraphBuilder`]
556556
A tuple of two new graph builders, one for the if branch and one for the else branch.
557557

558558
"""
@@ -586,7 +586,7 @@ class GraphBuilder:
586586

587587
Returns
588588
-------
589-
graph_builders : tuple[:obj:`~graph.GraphBuilder`, ...]
589+
graph_builders : tuple[:obj:`~_graph.GraphBuilder`, ...]
590590
A tuple of new graph builders, one for each branch.
591591

592592
"""
@@ -617,7 +617,7 @@ class GraphBuilder:
617617

618618
Returns
619619
-------
620-
graph_builder : :obj:`~graph.GraphBuilder`
620+
graph_builder : :obj:`~_graph.GraphBuilder`
621621
The newly created while loop graph builder.
622622

623623
"""
@@ -643,13 +643,13 @@ class GraphBuilder:
643643
self._mnff.close()
644644

645645
def add_child(self, child_graph: GraphBuilder):
646-
"""Adds the child :obj:`~graph.GraphBuilder` builder into self.
646+
"""Adds the child :obj:`~_graph.GraphBuilder` builder into self.
647647
648648
The child graph builder will be added as a child node to the parent graph builder.
649649
650650
Parameters
651651
----------
652-
child_graph : :obj:`~graph.GraphBuilder`
652+
child_graph : :obj:`~_graph.GraphBuilder`
653653
The child graph builder. Must have finished building.
654654
"""
655655
if not child_graph._building_ended:
@@ -737,13 +737,13 @@ class GraphBuilder:
737737

738738

739739
class Graph:
740-
"""An executable graph.
740+
"""Represents an executable graph.
741741
742742
A graph groups a set of CUDA kernels and other CUDA operations together and executes
743743
them with a specified dependency tree. It speeds up the workflow by combining the
744744
driver activities associated with CUDA kernel launches and CUDA API calls.
745745
746-
Graphs must be built using a :obj:`~graph.GraphBuilder` object.
746+
Graphs must be built using a :obj:`~_graph.GraphBuilder` object.
747747
748748
"""
749749

@@ -793,12 +793,12 @@ class Graph:
793793

794794
Parameters
795795
----------
796-
source : :obj:`~graph.GraphBuilder` or :obj:`~graph.GraphDef`
796+
source : :obj:`~_graph.GraphBuilder` or :obj:`~_graph._graph_def.GraphDef`
797797
The graph definition to update from. A GraphBuilder must have
798798
finished building.
799799

800800
"""
801-
from cuda.core.graph import GraphDef
801+
from cuda.core._graph._graph_def import GraphDef
802802

803803
cdef cydriver.CUgraph cu_graph
804804
cdef cydriver.CUgraphExec cu_exec = <cydriver.CUgraphExec><intptr_t>int(self._mnff.graph)

cuda_core/cuda/core/graph/__init__.pxd renamed to cuda_core/cuda/core/_graph/_graph_def/__init__.pxd

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from cuda.core.graph._graph_def cimport Condition, GraphDef
6-
from cuda.core.graph._graph_node cimport GraphNode
7-
from cuda.core.graph._subclasses cimport (
5+
from cuda.core._graph._graph_def._graph_def cimport Condition, GraphDef
6+
from cuda.core._graph._graph_def._graph_node cimport GraphNode
7+
from cuda.core._graph._graph_def._subclasses cimport (
88
AllocNode,
99
ChildGraphNode,
1010
ConditionalNode,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""Explicit CUDA graph construction — GraphDef, GraphNode, and node subclasses."""
6+
7+
from cuda.core._graph._graph_def._graph_def import (
8+
Condition,
9+
GraphAllocOptions,
10+
GraphDef,
11+
)
12+
from cuda.core._graph._graph_def._graph_node import GraphNode
13+
from cuda.core._graph._graph_def._subclasses import (
14+
AllocNode,
15+
ChildGraphNode,
16+
ConditionalNode,
17+
EmptyNode,
18+
EventRecordNode,
19+
EventWaitNode,
20+
FreeNode,
21+
HostCallbackNode,
22+
IfElseNode,
23+
IfNode,
24+
KernelNode,
25+
MemcpyNode,
26+
MemsetNode,
27+
SwitchNode,
28+
WhileNode,
29+
)
30+
31+
__all__ = [
32+
"AllocNode",
33+
"ChildGraphNode",
34+
"Condition",
35+
"ConditionalNode",
36+
"EmptyNode",
37+
"EventRecordNode",
38+
"EventWaitNode",
39+
"FreeNode",
40+
"GraphAllocOptions",
41+
"GraphDef",
42+
"GraphNode",
43+
"HostCallbackNode",
44+
"IfElseNode",
45+
"IfNode",
46+
"KernelNode",
47+
"MemcpyNode",
48+
"MemsetNode",
49+
"SwitchNode",
50+
"WhileNode",
51+
]

cuda_core/cuda/core/graph/_adjacency_set_proxy.pyx renamed to cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from libc.stddef cimport size_t
88
from libcpp.vector cimport vector
99
from cuda.bindings cimport cydriver
10-
from cuda.core.graph._graph_node cimport GraphNode
10+
from cuda.core._graph._graph_def._graph_node cimport GraphNode
1111
from cuda.core._resource_handles cimport (
1212
GraphHandle,
1313
GraphNodeHandle,
File renamed without changes.

cuda_core/cuda/core/graph/_graph_def.pyx renamed to cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ from libcpp.vector cimport vector
1212

1313
from cuda.bindings cimport cydriver
1414

15-
from cuda.core.graph._graph_node cimport GraphNode
15+
from cuda.core._graph._graph_def._graph_node cimport GraphNode
1616
from cuda.core._resource_handles cimport (
1717
GraphHandle,
1818
as_cu,
@@ -29,7 +29,7 @@ from cuda.core._utils.cuda_utils import driver
2929

3030

3131
cdef class Condition:
32-
"""A condition variable for conditional graph nodes.
32+
"""Wraps a CUgraphConditionalHandle.
3333
3434
Created by :meth:`GraphDef.create_condition` and passed to
3535
conditional-node builder methods (``if_cond``, ``if_else``,
@@ -91,7 +91,7 @@ class GraphAllocOptions:
9191
9292
9393
cdef class GraphDef:
94-
"""A graph definition.
94+
"""Represents a CUDA graph definition (CUgraph).
9595

9696
A GraphDef is used to construct a graph explicitly by adding nodes
9797
and specifying dependencies. Once construction is complete, call
@@ -287,15 +287,15 @@ cdef class GraphDef:
287287

288288
Parameters
289289
----------
290-
options : :obj:`~graph.GraphCompleteOptions`, optional
290+
options : :obj:`~_graph.GraphCompleteOptions`, optional
291291
Customizable dataclass for graph instantiation options.
292292

293293
Returns
294294
-------
295295
Graph
296296
An executable graph that can be launched on a stream.
297297
"""
298-
from cuda.core.graph._graph_builder import _instantiate_graph
298+
from cuda.core._graph._graph_builder import _instantiate_graph
299299
300300
return _instantiate_graph(
301301
driver.CUgraph(as_intptr(self._h_graph)), options)
@@ -310,7 +310,7 @@ cdef class GraphDef:
310310
options : GraphDebugPrintOptions, optional
311311
Customizable options for the debug print.
312312
"""
313-
from cuda.core.graph._graph_builder import GraphDebugPrintOptions
313+
from cuda.core._graph._graph_builder import GraphDebugPrintOptions
314314
315315
cdef unsigned int flags = 0
316316
if options is not None:
File renamed without changes.

0 commit comments

Comments
 (0)