diff --git a/cuda_core/cuda/core/graph/__init__.py b/cuda_core/cuda/core/graph/__init__.py index 57a6988d861..3f810986282 100644 --- a/cuda_core/cuda/core/graph/__init__.py +++ b/cuda_core/cuda/core/graph/__init__.py @@ -2,32 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from cuda.core.graph._graph_builder import ( - Graph, - GraphBuilder, - GraphCompleteOptions, - GraphDebugPrintOptions, -) -from cuda.core.graph._graph_definition import ( - GraphAllocOptions, - GraphCondition, - GraphDefinition, -) -from cuda.core.graph._graph_node import GraphNode -from cuda.core.graph._subclasses import ( - AllocNode, - ChildGraphNode, - ConditionalNode, - EmptyNode, - EventRecordNode, - EventWaitNode, - FreeNode, - HostCallbackNode, - IfElseNode, - IfNode, - KernelNode, - MemcpyNode, - MemsetNode, - SwitchNode, - WhileNode, -) +from ._graph_builder import * +from ._graph_def import * +from ._graph_node import * +from ._subclasses import * diff --git a/cuda_core/cuda/core/graph/_graph_builder.pyx b/cuda_core/cuda/core/graph/_graph_builder.pyx index 42370029706..5b304a24d38 100644 --- a/cuda_core/cuda/core/graph/_graph_builder.pyx +++ b/cuda_core/cuda/core/graph/_graph_builder.pyx @@ -21,6 +21,9 @@ from cuda.core._utils.cuda_utils import ( handle_return, ) +__all__ = ['Graph', 'GraphBuilder', 'GraphCompleteOptions', 'GraphDebugPrintOptions'] + + @dataclass class GraphDebugPrintOptions: """Options for debug_dot_print(). diff --git a/cuda_core/cuda/core/graph/_graph_definition.pyx b/cuda_core/cuda/core/graph/_graph_definition.pyx index 195a1e300b0..a185d39685a 100644 --- a/cuda_core/cuda/core/graph/_graph_definition.pyx +++ b/cuda_core/cuda/core/graph/_graph_definition.pyx @@ -27,6 +27,8 @@ from dataclasses import dataclass from cuda.core._utils.cuda_utils import driver +__all__ = ['GraphCondition', 'GraphAllocOptions', 'GraphDef'] + cdef class GraphCondition: """A condition variable for conditional graph nodes. diff --git a/cuda_core/cuda/core/graph/_graph_node.pyx b/cuda_core/cuda/core/graph/_graph_node.pyx index 553304885be..bd10bfa007f 100644 --- a/cuda_core/cuda/core/graph/_graph_node.pyx +++ b/cuda_core/cuda/core/graph/_graph_node.pyx @@ -61,6 +61,8 @@ import weakref from cuda.core.graph._adjacency_set_proxy import AdjacencySetProxy from cuda.core._utils.cuda_utils import driver +__all__ = ['GraphNode'] + # See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object) _node_registry = weakref.WeakValueDictionary() diff --git a/cuda_core/cuda/core/graph/_subclasses.pyx b/cuda_core/cuda/core/graph/_subclasses.pyx index ef08bb30856..86cf9eea53e 100644 --- a/cuda_core/cuda/core/graph/_subclasses.pyx +++ b/cuda_core/cuda/core/graph/_subclasses.pyx @@ -34,6 +34,24 @@ from cuda.core.graph._utils cimport _is_py_host_trampoline from cuda.core._utils.cuda_utils import driver, handle_return +__all__ = [ + 'AllocNode', + 'ChildGraphNode', + 'ConditionalNode', + 'EmptyNode', + 'EventRecordNode', + 'EventWaitNode', + 'FreeNode', + 'HostCallbackNode', + 'IfElseNode', + 'IfNode', + 'KernelNode', + 'MemcpyNode', + 'MemsetNode', + 'SwitchNode', + 'WhileNode', +] + cdef bint _has_cuGraphNodeGetParams = False cdef bint _version_checked = False