Merged
25 commits
dc92437
Reorganize graph test files for clarity
Andy-Jost Mar 31, 2026
281ed82
Enhance Graph.update() and add whole-graph update tests
Andy-Jost Mar 31, 2026
7854b76
Add AdjacencySet proxy for pred/succ and GraphNode.remove()
Andy-Jost Mar 31, 2026
5fbd288
Add edge mutation support and MutableSet interface for GraphNode adja…
Andy-Jost Apr 2, 2026
aa84e26
Use requires_module mark for numpy version checks in mutation tests
Andy-Jost Apr 2, 2026
b27dd93
Fix empty-graph return type: return set() instead of () for nodes/edges
Andy-Jost Apr 2, 2026
8554d30
Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards
Andy-Jost Apr 2, 2026
9813c20
Add destroy() method with handle invalidation, remove GRAPH_NODE_SENT…
Andy-Jost Apr 2, 2026
6411881
Add GraphNode identity cache for stable Python object round-trips
Andy-Jost Apr 2, 2026
7a3dbb4
Purge node cache on destroy to prevent stale identity lookups
Andy-Jost Apr 2, 2026
91b3b4e
Skip NULL nodes in graph_node_registry to fix sentinel identity colli…
Andy-Jost Apr 2, 2026
1b7743d
Unregister destroyed nodes from C++ graph_node_registry
Andy-Jost Apr 3, 2026
84f0b30
Add dedicated test for node identity preservation through round-trips
Andy-Jost Apr 3, 2026
64d6c2d
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
6b36e47
Add handle= to all GraphNode subclass __repr__ for debugging
Andy-Jost Apr 3, 2026
a40be9a
Merge branch 'main' into graph-node-identity
Andy-Jost Apr 3, 2026
729af49
Rename _node_cache/_cached to _node_registry/_registered
Andy-Jost Apr 3, 2026
42131b6
Fix unregister_handle and rename invalidate_graph_node_handle
Andy-Jost Apr 3, 2026
9766e54
Merge branch 'graph-node-identity' into graph-node-repr
Andy-Jost Apr 3, 2026
15d0036
Add cheap containment test and early type check for AdjacencySetProxy
Andy-Jost Apr 3, 2026
347693f
Add GraphDef.empty(), stack-buffer query optimization, and registry test
Andy-Jost Apr 3, 2026
641a089
Document the two-level handle and object registry design
Andy-Jost Apr 3, 2026
8370687
Fix import formatting in test_registry_cleanup
Andy-Jost Apr 3, 2026
36527da
Merge origin/main into graph-node-repr
Andy-Jost Apr 3, 2026
f779f30
Optimize GraphDef.nodes() and edges() to try a single driver call
Andy-Jost Apr 3, 2026
51 changes: 51 additions & 0 deletions cuda_core/cuda/core/_cpp/REGISTRY_DESIGN.md
@@ -0,0 +1,51 @@
# Handle and Object Registries

When Python-managed objects round-trip through the CUDA driver (e.g.,
querying a graph's nodes and getting back raw `CUgraphNode` pointers),
we need to recover the original Python object rather than creating a
duplicate.

This document describes the approach used to achieve this. The pattern
is driven mainly by needs arising in the context of CUDA graphs, but
it is general and can be extended to other object types as needs arise.

This solves the same problem as pybind11's `registered_instances` map
and is sometimes called the Identity Map pattern. Two registries work
together to map a raw driver handle all the way back to the original
Python object. Both use weak references so they do not prevent
cleanup. Entries are removed either explicitly (via `destroy()` or a
Box destructor) or implicitly when the weak reference expires.

## Level 1: Driver Handle -> Resource Handle (C++)

`HandleRegistry` in `resource_handles.cpp` maps a raw CUDA handle
(e.g., `CUevent`, `CUkernel`, `CUgraphNode`) to a `weak_ptr` to its
resource handle (the `shared_ptr` that owns it). When a `_ref`
constructor receives a raw handle, it checks the registry first. If a
live entry is found, it returns the existing `shared_ptr`, preserving
the Box and its metadata (e.g., `EventBox` carries timing/IPC flags,
`KernelBox` carries the library dependency).

Without this level, a round-tripped handle would produce a new Box
with default metadata, losing information that was set at creation.

Instances: `event_registry`, `kernel_registry`, `graph_node_registry`.
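
The lookup-before-construct behavior can be modeled in a few lines of
Python. This is only an illustrative sketch, not the C++
implementation: `EventBox`, `create_event`, `event_ref`, and the
`timing_enabled` field are hypothetical stand-ins, and `weakref.ref`
plays the role of `weak_ptr`.

```python
import weakref


class EventBox:
    """Hypothetical stand-in for a C++ Box: raw handle plus creation metadata."""

    def __init__(self, raw_handle, timing_enabled):
        self.raw_handle = raw_handle
        self.timing_enabled = timing_enabled


class HandleRegistry:
    """Maps a raw handle to a weak reference to its Box (models weak_ptr)."""

    def __init__(self):
        self._entries = {}

    def register(self, box):
        self._entries[box.raw_handle] = weakref.ref(box)

    def unregister(self, raw_handle):
        # Explicit removal, e.g. from destroy() or a Box destructor.
        self._entries.pop(raw_handle, None)

    def lookup(self, raw_handle):
        ref = self._entries.get(raw_handle)
        box = ref() if ref is not None else None  # analogous to weak_ptr::lock()
        if box is None:
            self._entries.pop(raw_handle, None)  # drop an expired entry
        return box


event_registry = HandleRegistry()


def create_event(raw_handle, timing_enabled):
    """Creation path: build the Box with its metadata and register it."""
    box = EventBox(raw_handle, timing_enabled)
    event_registry.register(box)
    return box


def event_ref(raw_handle):
    """Models a `_ref` constructor: reuse the live Box if one is registered."""
    box = event_registry.lookup(raw_handle)
    if box is not None:
        return box
    # Unknown handle: nothing to recover, so fall back to default metadata.
    return EventBox(raw_handle, timing_enabled=False)
```

In this model, a handle that was created with `timing_enabled=True`
comes back from `event_ref` as the same Box with the flag intact,
while an unregistered handle yields a Box with defaults.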

## Level 2: Resource Handle -> Python Object (Cython)

`_node_registry` in `_graph_node.pyx` is a `WeakValueDictionary`
mapping a resource address (`shared_ptr::get()`) to a Python
`GraphNode` object. When `GraphNode._create` receives a handle from
Level 1, it checks this registry. If found, it returns the existing
Python object.

Without this level, each driver round-trip would produce a distinct
Python object for the same logical node, resulting in surprising
behavior:

```python
a = g.empty()   # g is a GraphDef
b = g.empty()
a.succ = {b}
b2, = a.succ    # queries driver, gets back CUgraphNode for b
assert b2 is b  # fails without Level 2 registry
```
3 changes: 3 additions & 0 deletions cuda_core/cuda/core/_cpp/resource_handles.cpp
@@ -388,6 +388,7 @@ ContextHandle get_event_context(const EventHandle& h) noexcept {
return h ? get_box(h)->h_context : ContextHandle{};
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUevent, EventHandle> event_registry;

EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags,
@@ -894,6 +895,7 @@ static const KernelBox* get_box(const KernelHandle& h) {
);
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUkernel, KernelHandle> kernel_registry;

KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) {
@@ -964,6 +966,7 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
);
}

// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;

GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
52 changes: 42 additions & 10 deletions cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx
@@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet):
def __contains__(self, x):
if not isinstance(x, GraphNode):
return False
return x in (<_AdjacencySetCore>self._core).query()
return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)

def __iter__(self):
return iter((<_AdjacencySetCore>self._core).query())
@@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet):
if isinstance(other, GraphNode):
nodes.append(other)
else:
nodes.extend(other)
for n in other:
if not isinstance(n, GraphNode):
raise TypeError(
f"expected GraphNode, got {type(n).__name__}")
nodes.append(n)
if not nodes:
return
for n in nodes:
if not isinstance(n, GraphNode):
raise TypeError(
f"expected GraphNode, got {type(n).__name__}")
new = [n for n in nodes if n not in self]
if new:
(<_AdjacencySetCore>self._core).add_edges(new)
@@ -143,11 +143,14 @@ cdef class _AdjacencySetCore:
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
if c_node == NULL:
return []
cdef size_t count = 0
cdef cydriver.CUgraphNode buf[16]
cdef size_t count = 16
cdef size_t i
with nogil:
HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
if count == 0:
return []
HANDLE_RETURN(self._query_fn(c_node, buf, &count))
if count <= 16:
return [GraphNode._create(self._h_graph, buf[i])
for i in range(count)]
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(count)
with nogil:
@@ -156,6 +159,35 @@ cdef class _AdjacencySetCore:
return [GraphNode._create(self._h_graph, nodes_vec[i])
for i in range(count)]
Comment on lines 142 to 160
Contributor Author

@Andy-Jost Andy-Jost Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to return a generator as suggested here: `nodes_vec` is a stack object, so a generator would lead to a use-after-free unless we define a cdef class for the iterator. Instead, I added an optimized `contains` method that avoids reconstructing Python objects for containment checks.
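
For readers skimming the hunk, the shape of that containment check can be sketched in plain Python. This is a rough model, not the Cython code: `make_fake_query` is a hypothetical stand-in for the driver query, and it assumes the query reports the total neighbor count even when the supplied buffer is too small, which is what the `count <= 16` test in the diff appears to rely on.

```python
STACK_CAPACITY = 16  # mirrors the buf[16] stack buffer in the diff


def make_fake_query(neighbors):
    """Hypothetical driver query: fills `buf` and returns the total count.

    Assumption: the total is reported even if `buf` is too small to hold it.
    """
    def query(buf):
        n = min(len(buf), len(neighbors))
        buf[:n] = neighbors[:n]
        return len(neighbors)
    return query


def contains(query, target):
    """Compare raw handles directly; no Python node objects are created."""
    buf = [None] * STACK_CAPACITY
    count = query(buf)

    # Fast path for small sets: a single query call was enough.
    if count <= STACK_CAPACITY:
        return target in buf[:count]

    # Fallback for large sets: size a buffer to fit and query again.
    big = [None] * count
    count = query(big)
    return target in big[:count]
```

With at most 16 neighbors the check costs one query call; larger sets pay for a second call with a buffer sized to fit.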


cdef bint contains(self, GraphNode other):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
cdef cydriver.CUgraphNode target = as_cu(other._h_node)
if c_node == NULL or target == NULL:
return False
cdef cydriver.CUgraphNode buf[16]
cdef size_t count = 16
cdef size_t i
with nogil:
HANDLE_RETURN(self._query_fn(c_node, buf, &count))

# Fast path for small sets.
if count <= 16:
for i in range(count):
if buf[i] == target:
return True
return False

# Fallback for large sets.
cdef vector[cydriver.CUgraphNode] nodes_vec
nodes_vec.resize(count)
with nogil:
HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
assert count == nodes_vec.size()
for i in range(count):
if nodes_vec[i] == target:
return True
return False

cdef Py_ssize_t count(self):
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
if c_node == NULL:
10 changes: 10 additions & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx
@@ -159,6 +159,16 @@ cdef class GraphDef:
"""
return self._entry.launch(config, kernel, *args)

def empty(self) -> "EmptyNode":
"""Add an entry-point empty node (no dependencies).

Returns
-------
EmptyNode
A new EmptyNode with no dependencies.
"""
return self._entry.join()

def join(self, *nodes) -> "EmptyNode":
"""Create an empty node that depends on all given nodes.

1 change: 1 addition & 0 deletions cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx
@@ -63,6 +63,7 @@ from cuda.core import Device
from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
from cuda.core._utils.cuda_utils import driver, handle_return

# See _cpp/REGISTRY_DESIGN.md (Level 2: Resource Handle -> Python Object)
_node_registry = weakref.WeakValueDictionary()


55 changes: 29 additions & 26 deletions cuda_core/cuda/core/_graph/_graph_def/_subclasses.pyx
@@ -58,8 +58,7 @@ cdef class EmptyNode(GraphNode):
return n

def __repr__(self) -> str:
cdef Py_ssize_t n = len(self.pred)
return f"<EmptyNode with {n} {'pred' if n == 1 else 'preds'}>"
return f"<EmptyNode handle=0x{as_intptr(self._h_node):x}>"


cdef class KernelNode(GraphNode):
@@ -108,7 +107,8 @@ cdef class KernelNode(GraphNode):
h_kernel)

def __repr__(self) -> str:
return (f"<KernelNode grid={self._grid} block={self._block}>")
return (f"<KernelNode handle=0x{as_intptr(self._h_node):x}"
f" kernel=0x{as_intptr(self._h_kernel):x}>")

@property
def grid(self) -> tuple:
@@ -207,7 +207,8 @@ cdef class AllocNode(GraphNode):
<int>params.poolProps.location.id, memory_type, tuple(peer_ids))

def __repr__(self) -> str:
return f"<AllocNode dptr=0x{self._dptr:x} size={self._bytesize}>"
return (f"<AllocNode handle=0x{as_intptr(self._h_node):x}"
f" dptr=0x{self._dptr:x} size={self._bytesize}>")

@property
def dptr(self) -> int:
@@ -273,7 +274,7 @@ cdef class FreeNode(GraphNode):
return FreeNode._create_with_params(h_node, dptr)

def __repr__(self) -> str:
return f"<FreeNode dptr=0x{self._dptr:x}>"
return f"<FreeNode handle=0x{as_intptr(self._h_node):x} dptr=0x{self._dptr:x}>"

@property
def dptr(self) -> int:
@@ -328,8 +329,8 @@ cdef class MemsetNode(GraphNode):
params.elementSize, params.width, params.height, params.pitch)

def __repr__(self) -> str:
return (f"<MemsetNode dptr=0x{self._dptr:x} "
f"value={self._value} elem={self._element_size}>")
return (f"<MemsetNode handle=0x{as_intptr(self._h_node):x}"
f" dptr=0x{self._dptr:x} value={self._value}>")

@property
def dptr(self) -> int:
@@ -416,8 +417,8 @@ cdef class MemcpyNode(GraphNode):
def __repr__(self) -> str:
cdef str dt = "H" if self._dst_type == cydriver.CU_MEMORYTYPE_HOST else "D"
cdef str st = "H" if self._src_type == cydriver.CU_MEMORYTYPE_HOST else "D"
return (f"<MemcpyNode dst=0x{self._dst:x}({dt}) "
f"src=0x{self._src:x}({st}) size={self._size}>")
return (f"<MemcpyNode handle=0x{as_intptr(self._h_node):x}"
f" dst=0x{self._dst:x}({dt}) src=0x{self._src:x}({st}) size={self._size}>")

@property
def dst(self) -> int:
@@ -465,12 +466,8 @@ cdef class ChildGraphNode(GraphNode):
return ChildGraphNode._create_with_params(h_node, h_child)

def __repr__(self) -> str:
cdef cydriver.CUgraph g = as_cu(self._h_child_graph)
cdef size_t num_nodes = 0
with nogil:
HANDLE_RETURN(cydriver.cuGraphGetNodes(g, NULL, &num_nodes))
cdef Py_ssize_t n = <Py_ssize_t>num_nodes
return f"<ChildGraphNode with {n} {'subnode' if n == 1 else 'subnodes'}>"
return (f"<ChildGraphNode handle=0x{as_intptr(self._h_node):x}"
f" child=0x{as_intptr(self._h_child_graph):x}>")

@property
def child_graph(self) -> "GraphDef":
@@ -507,7 +504,8 @@ cdef class EventRecordNode(GraphNode):
return EventRecordNode._create_with_params(h_node, h_event)

def __repr__(self) -> str:
return f"<EventRecordNode event=0x{as_intptr(self._h_event):x}>"
return (f"<EventRecordNode handle=0x{as_intptr(self._h_node):x}"
f" event=0x{as_intptr(self._h_event):x}>")

@property
def event(self) -> Event:
@@ -544,7 +542,8 @@ cdef class EventWaitNode(GraphNode):
return EventWaitNode._create_with_params(h_node, h_event)

def __repr__(self) -> str:
return f"<EventWaitNode event=0x{as_intptr(self._h_event):x}>"
return (f"<EventWaitNode handle=0x{as_intptr(self._h_node):x}"
f" event=0x{as_intptr(self._h_event):x}>")

@property
def event(self) -> Event:
@@ -591,8 +590,10 @@ cdef class HostCallbackNode(GraphNode):
def __repr__(self) -> str:
if self._callable is not None:
name = getattr(self._callable, '__name__', '?')
return f"<HostCallbackNode callback={name}>"
return f"<HostCallbackNode cfunc=0x{<uintptr_t>self._fn:x}>"
return (f"<HostCallbackNode handle=0x{as_intptr(self._h_node):x}"
f" callback={name}>")
return (f"<HostCallbackNode handle=0x{as_intptr(self._h_node):x}"
f" cfunc=0x{<uintptr_t>self._fn:x}>")

@property
def callback_fn(self):
@@ -672,7 +673,7 @@ cdef class ConditionalNode(GraphNode):
return n

def __repr__(self) -> str:
return "<ConditionalNode>"
return f"<ConditionalNode handle=0x{as_intptr(self._h_node):x}>"

@property
def condition(self) -> Condition | None:
@@ -709,7 +710,8 @@ cdef class IfNode(ConditionalNode):
"""An if-conditional node (1 branch, executes when condition is non-zero)."""

def __repr__(self) -> str:
return f"<IfNode condition=0x{<unsigned long long>self._condition._c_handle:x}>"
return (f"<IfNode handle=0x{as_intptr(self._h_node):x}"
f" condition=0x{<unsigned long long>self._condition._c_handle:x}>")

@property
def then(self) -> "GraphDef":
@@ -721,7 +723,8 @@ cdef class IfElseNode(ConditionalNode):
"""An if-else conditional node (2 branches)."""

def __repr__(self) -> str:
return f"<IfElseNode condition=0x{<unsigned long long>self._condition._c_handle:x}>"
return (f"<IfElseNode handle=0x{as_intptr(self._h_node):x}"
f" condition=0x{<unsigned long long>self._condition._c_handle:x}>")

@property
def then(self) -> "GraphDef":
@@ -738,7 +741,8 @@ cdef class WhileNode(ConditionalNode):
"""A while-loop conditional node (1 branch, repeats while condition is non-zero)."""

def __repr__(self) -> str:
return f"<WhileNode condition=0x{<unsigned long long>self._condition._c_handle:x}>"
return (f"<WhileNode handle=0x{as_intptr(self._h_node):x}"
f" condition=0x{<unsigned long long>self._condition._c_handle:x}>")

@property
def body(self) -> "GraphDef":
@@ -750,6 +754,5 @@ cdef class SwitchNode(ConditionalNode):
"""A switch conditional node (N branches, selected by condition value)."""

def __repr__(self) -> str:
cdef Py_ssize_t n = len(self._branches)
return (f"<SwitchNode condition=0x{<unsigned long long>self._condition._c_handle:x}"
f" with {n} {'branch' if n == 1 else 'branches'}>")
return (f"<SwitchNode handle=0x{as_intptr(self._h_node):x}"
f" condition=0x{<unsigned long long>self._condition._c_handle:x}>")