Commit 5064470

Add GraphNode identity cache for stable object round-trips (#1853)
* Reorganize graph test files for clarity

  Rename test files to reflect what they actually test:
  - test_basic -> test_graph_builder (stream capture tests)
  - test_conditional -> test_graph_builder_conditional
  - test_advanced -> test_graph_update (moved child_graph and stream_lifetime tests into test_graph_builder)
  - test_capture_alloc -> test_graph_memory_resource
  - test_explicit* -> test_graphdef*

  Made-with: Cursor

* Enhance Graph.update() and add whole-graph update tests

  - Extend Graph.update() to accept both GraphBuilder and GraphDef sources
  - Surface CUgraphExecUpdateResultInfo details on update failure instead of a generic CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE message
  - Release the GIL during cuGraphExecUpdate via nogil block
  - Add parametrized happy-path test covering both GraphBuilder and GraphDef
  - Add error-case tests: unfinished builder, topology mismatch, wrong type

  Made-with: Cursor

* Add AdjacencySet proxy for pred/succ and GraphNode.remove()

  Replace cached tuple-based pred/succ with mutable AdjacencySet backed by direct CUDA driver calls. Add GraphNode.remove() wrapping cuGraphDestroyNode.

  Made-with: Cursor

* Add edge mutation support and MutableSet interface for GraphNode adjacencies

  Enable adding/removing edges between graph nodes via AdjacencySet (a MutableSet proxy on GraphNode.pred/succ), node removal via discard(), and property setters for bulk edge replacement. Includes comprehensive mutation and interface tests.

  Closes part of #1330 (step 2: edge mutation on GraphDef).

  Made-with: Cursor

* Use requires_module mark for numpy version checks in mutation tests

  Replace inline skipif version check with requires_module(np, "2.1") from the shared test helpers, consistent with other test files.

  Made-with: Cursor

* Fix empty-graph return type: return set() instead of () for nodes/edges

  Made-with: Cursor

* Rename AdjacencySet to AdjacencySetProxy, add bulk ops and safety guards

  Rename class and file to AdjacencySetProxy to clarify write-through semantics. Add bulk-efficient clear(), __isub__(), __ior__() overrides and remove_edges() on the Cython core. Guard GraphNode.discard() against double-destroy via membership check. Filter duplicates in update(). Add error-path tests for wrong types, cross-graph edges, and self-edges.

  Made-with: Cursor

* Add destroy() method with handle invalidation, remove GRAPH_NODE_SENTINEL

  Replace discard() with destroy() which calls cuGraphDestroyNode and then zeroes the CUgraphNode resource in the handle box via invalidate_graph_node_handle. This prevents stale memory access on destroyed nodes. Properties (type, pred, succ, handle) degrade gracefully to None/empty for destroyed nodes.

  Remove the GRAPH_NODE_SENTINEL (0x1) approach in favor of using NULL for both sentinels and destroyed nodes, which is simpler and avoids the risk of passing 0x1 to driver APIs that treat it as a valid pointer.

  Made-with: Cursor

* Add GraphNode identity cache for stable Python object round-trips

  Nodes retrieved via GraphDef.nodes(), edges(), or pred/succ traversal now return the same Python object that was originally created, enabling identity checks with `is`. A C++ HandleRegistry deduplicates CUgraphNode handles, and a Cython WeakValueDictionary caches the Python wrapper objects.

  Made-with: Cursor

* Purge node cache on destroy to prevent stale identity lookups

  Made-with: Cursor

* Skip NULL nodes in graph_node_registry to fix sentinel identity collision

  Sentinel (entry) nodes use NULL as their CUgraphNode, so caching them under a NULL key caused all sentinels across different graphs to share the same handle. This made nodes built from the wrong graph's entry point, causing CUDA_ERROR_INVALID_VALUE for conditional nodes and hash collisions in equality tests.

  Made-with: Cursor

* Unregister destroyed nodes from C++ graph_node_registry

  When a node is destroyed, the driver may reuse its CUgraphNode pointer for a new node. Without unregistering the old entry, the registry returns a stale handle pointing to the wrong node type and graph.

  Made-with: Cursor

* Add dedicated test for node identity preservation through round-trips

  Made-with: Cursor

* Rename _node_cache/_cached to _node_registry/_registered

  Aligns Python-side terminology with the C++ graph_node_registry.

  Made-with: Cursor

* Fix unregister_handle and rename invalidate_graph_node_handle

  unregister_handle: remove the expired() guard that prevented erasure when the shared_ptr was still alive. This caused stale registry entries after destroy(), leading to CUDA_ERROR_INVALID_VALUE when the driver reused CUgraphNode pointer values.

  Rename invalidate_graph_node_handle -> invalidate_graph_node for consistency with the rest of the graph node API.

  Made-with: Cursor
1 parent 5777275 commit 5064470

7 files changed: +95 -37 lines changed

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 19 additions & 9 deletions
@@ -174,13 +174,8 @@ class HandleRegistry {
     }
 
     void unregister_handle(const Key& key) noexcept {
-        try {
-            std::lock_guard<std::mutex> lock(mutex_);
-            auto it = map_.find(key);
-            if (it != map_.end() && it->second.expired()) {
-                map_.erase(it);
-            }
-        } catch (...) {}
+        std::lock_guard<std::mutex> lock(mutex_);
+        map_.erase(key);
     }
 
     Handle lookup(const Key& key) {
@@ -969,17 +964,32 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
     );
 }
 
+static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;
+
 GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
+    if (node) {
+        if (auto h = graph_node_registry.lookup(node)) {
+            return h;
+        }
+    }
     auto box = std::make_shared<const GraphNodeBox>(GraphNodeBox{node, h_graph});
-    return GraphNodeHandle(box, &box->resource);
+    GraphNodeHandle h(box, &box->resource);
+    if (node) {
+        graph_node_registry.register_handle(node, h);
+    }
+    return h;
 }
 
 GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
     return h ? get_box(h)->h_graph : GraphHandle{};
 }
 
-void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept {
+void invalidate_graph_node(const GraphNodeHandle& h) noexcept {
     if (h) {
+        CUgraphNode node = get_box(h)->resource;
+        if (node) {
+            graph_node_registry.unregister_handle(node);
+        }
         get_box(h)->resource = nullptr;
     }
 }
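
The effect of dropping the `expired()` guard in `unregister_handle` can be modelled in a few lines of Python. This is only a sketch under stated assumptions: `Handle`, `Registry`, and the pointer value `0xA0` are hypothetical, and `weakref.ref` stands in for the weak pointer the C++ registry appears to store. It shows why a guarded erase leaves a stale entry while the destroyed node's handle is still referenced.

```python
import weakref


class Handle:
    """Hypothetical handle box: a node pointer plus its owning graph."""

    def __init__(self, ptr, graph):
        self.ptr = ptr
        self.graph = graph


class Registry:
    """Pointer -> weak reference to Handle (illustrative model only)."""

    def __init__(self):
        self._map = {}

    def register(self, ptr, handle):
        self._map[ptr] = weakref.ref(handle)

    def lookup(self, ptr):
        ref = self._map.get(ptr)
        return ref() if ref is not None else None

    def unregister_if_expired(self, ptr):
        # Old behaviour: erase only when the handle is already dead.
        ref = self._map.get(ptr)
        if ref is not None and ref() is None:
            del self._map[ptr]

    def unregister(self, ptr):
        # New behaviour: erase unconditionally.
        self._map.pop(ptr, None)


reg = Registry()
old = Handle(0xA0, "graph_1")   # still referenced by a Python wrapper
reg.register(0xA0, old)

# The node is destroyed while `old` is alive, so the guarded erase is a no-op.
reg.unregister_if_expired(0xA0)
stale = reg.lookup(0xA0)        # still returns the destroyed node's handle

# The unconditional erase removes the entry regardless of handle lifetime.
reg.unregister(0xA0)
after = reg.lookup(0xA0)
```

If the driver later hands out `0xA0` for a new node, the guarded version would return `old`, which belongs to the wrong node and possibly the wrong graph; the unconditional version forces a fresh handle to be created.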

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 1 addition & 1 deletion
@@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
 GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
 
 // Zero the CUgraphNode resource inside the handle, marking it invalid.
-void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept;
+void invalidate_graph_node(const GraphNodeHandle& h) noexcept;
 
 // ============================================================================
 // Graphics resource handle functions

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx

Lines changed: 45 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ from cuda.core._resource_handles cimport (
4848
create_graph_handle_ref,
4949
create_graph_node_handle,
5050
graph_node_get_graph,
51-
invalidate_graph_node_handle,
51+
invalidate_graph_node,
5252
)
5353
from cuda.core._utils.cuda_utils cimport HANDLE_RETURN, _parse_fill_value
5454

@@ -57,10 +57,19 @@ from cuda.core._graph._utils cimport (
     _attach_user_object,
 )
 
+import weakref
+
 from cuda.core import Device
 from cuda.core._graph._graph_def._adjacency_set_proxy import AdjacencySetProxy
 from cuda.core._utils.cuda_utils import driver, handle_return
 
+_node_registry = weakref.WeakValueDictionary()
+
+
+cdef inline GraphNode _registered(GraphNode n):
+    _node_registry[<uintptr_t>n._h_node.get()] = n
+    return n
+
 
 cdef class GraphNode:
     """Base class for all graph nodes.
@@ -144,7 +153,8 @@ cdef class GraphNode:
             return
         with nogil:
             HANDLE_RETURN(cydriver.cuGraphDestroyNode(node))
-        invalidate_graph_node_handle(self._h_node)
+        _node_registry.pop(<uintptr_t>self._h_node.get(), None)
+        invalidate_graph_node(self._h_node)
 
     @property
     def pred(self):
@@ -522,18 +532,29 @@ cdef inline ConditionalNode _make_conditional_node(
     n._cond_type = cond_type
     n._branches = branches
 
-    return n
+    return _registered(n)
 
 cdef inline GraphNode GN_create(GraphHandle h_graph, cydriver.CUgraphNode node):
+    cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
+
+    # Sentinel: virtual node to represent the graph entry point.
     if node == NULL:
         n = GraphNode.__new__(GraphNode)
-        (<GraphNode>n)._h_node = create_graph_node_handle(node, h_graph)
+        (<GraphNode>n)._h_node = h_node
         return n
 
-    cdef GraphNodeHandle h_node = create_graph_node_handle(node, h_graph)
+    # Return a registered object or create and register a new one.
+    registered = _node_registry.get(<uintptr_t>h_node.get())
+    if registered is not None:
+        return <GraphNode>registered
+    else:
+        return _registered(GN_create_impl(h_node))
+
+
+cdef inline GraphNode GN_create_impl(GraphNodeHandle h_node):
     cdef cydriver.CUgraphNodeType node_type
     with nogil:
-        HANDLE_RETURN(cydriver.cuGraphNodeGetType(node, &node_type))
+        HANDLE_RETURN(cydriver.cuGraphNodeGetType(as_cu(h_node), &node_type))
 
     if node_type == cydriver.CU_GRAPH_NODE_TYPE_EMPTY:
         return EmptyNode._create_impl(h_node)
@@ -595,10 +616,10 @@ cdef inline KernelNode GN_launch(GraphNode self, LaunchConfig conf, Kernel ker,
     _attach_user_object(as_cu(h_graph), <void*>new KernelHandle(ker._h_kernel),
                         <cydriver.CUhostFn>_destroy_kernel_handle_copy)
 
-    return KernelNode._create_with_params(
+    return _registered(KernelNode._create_with_params(
         create_graph_node_handle(new_node, h_graph),
         conf.grid, conf.block, conf.shmem_size,
-        ker._h_kernel)
+        ker._h_kernel))
 
 
 cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
@@ -624,7 +645,7 @@ cdef inline EmptyNode GN_join(GraphNode self, tuple nodes):
         HANDLE_RETURN(cydriver.cuGraphAddEmptyNode(
             &new_node, as_cu(h_graph), deps_ptr, num_deps))
 
-    return EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph))
+    return _registered(EmptyNode._create_impl(create_graph_node_handle(new_node, h_graph)))
 
 
 cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
@@ -700,9 +721,9 @@ cdef inline AllocNode GN_alloc(GraphNode self, size_t size, object options):
         HANDLE_RETURN(cydriver.cuGraphAddMemAllocNode(
             &new_node, as_cu(h_graph), deps, num_deps, &alloc_params))
 
-    return AllocNode._create_with_params(
+    return _registered(AllocNode._create_with_params(
         create_graph_node_handle(new_node, h_graph), alloc_params.dptr, size,
-        device_id, memory_type, tuple(peer_ids))
+        device_id, memory_type, tuple(peer_ids)))
 
 
 cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
@@ -720,7 +741,7 @@ cdef inline FreeNode GN_free(GraphNode self, cydriver.CUdeviceptr c_dptr):
         HANDLE_RETURN(cydriver.cuGraphAddMemFreeNode(
             &new_node, as_cu(h_graph), deps, num_deps, c_dptr))
 
-    return FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr)
+    return _registered(FreeNode._create_with_params(create_graph_node_handle(new_node, h_graph), c_dptr))
 
 
 cdef inline MemsetNode GN_memset(
@@ -755,9 +776,9 @@ cdef inline MemsetNode GN_memset(
             &new_node, as_cu(h_graph), deps, num_deps,
             &memset_params, ctx))
 
-    return MemsetNode._create_with_params(
+    return _registered(MemsetNode._create_with_params(
         create_graph_node_handle(new_node, h_graph), c_dst,
-        val, elem_size, width, height, pitch)
+        val, elem_size, width, height, pitch))
 
 
 cdef inline MemcpyNode GN_memcpy(
@@ -816,9 +837,9 @@ cdef inline MemcpyNode GN_memcpy(
         HANDLE_RETURN(cydriver.cuGraphAddMemcpyNode(
             &new_node, as_cu(h_graph), deps, num_deps, &params, ctx))
 
-    return MemcpyNode._create_with_params(
+    return _registered(MemcpyNode._create_with_params(
         create_graph_node_handle(new_node, h_graph), c_dst, c_src, size,
-        c_dst_type, c_src_type)
+        c_dst_type, c_src_type))
 
 
 cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
@@ -843,8 +864,8 @@ cdef inline ChildGraphNode GN_embed(GraphNode self, GraphDef child_def):
 
     cdef GraphHandle h_embedded = create_graph_handle_ref(embedded_graph, h_graph)
 
-    return ChildGraphNode._create_with_params(
-        create_graph_node_handle(new_node, h_graph), h_embedded)
+    return _registered(ChildGraphNode._create_with_params(
+        create_graph_node_handle(new_node, h_graph), h_embedded))
 
 
 cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
@@ -865,8 +886,8 @@ cdef inline EventRecordNode GN_record_event(GraphNode self, Event ev):
     _attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
                         <cydriver.CUhostFn>_destroy_event_handle_copy)
 
-    return EventRecordNode._create_with_params(
-        create_graph_node_handle(new_node, h_graph), ev._h_event)
+    return _registered(EventRecordNode._create_with_params(
+        create_graph_node_handle(new_node, h_graph), ev._h_event))
 
 
 cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
@@ -887,8 +908,8 @@ cdef inline EventWaitNode GN_wait_event(GraphNode self, Event ev):
     _attach_user_object(as_cu(h_graph), <void*>new EventHandle(ev._h_event),
                         <cydriver.CUhostFn>_destroy_event_handle_copy)
 
-    return EventWaitNode._create_with_params(
-        create_graph_node_handle(new_node, h_graph), ev._h_event)
+    return _registered(EventWaitNode._create_with_params(
+        create_graph_node_handle(new_node, h_graph), ev._h_event))
 
 
 cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_data):
@@ -914,6 +935,6 @@ cdef inline HostCallbackNode GN_callback(GraphNode self, object fn, object user_
             &new_node, as_cu(h_graph), deps, num_deps, &node_params))
 
     cdef object callable_obj = fn if not isinstance(fn, ct._CFuncPtr) else None
-    return HostCallbackNode._create_with_params(
+    return _registered(HostCallbackNode._create_with_params(
         create_graph_node_handle(new_node, h_graph), callable_obj,
-        node_params.fn, node_params.userData)
+        node_params.fn, node_params.userData))

cuda_core/cuda/core/_resource_handles.pxd

Lines changed: 1 addition & 1 deletion
@@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand
 # Graph node handles
 cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
 cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
-cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil
+cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil
 
 # Graphics resource handles
 cdef GraphicsResourceHandle create_graphics_resource_handle(

cuda_core/cuda/core/_resource_handles.pyx

Lines changed: 1 addition & 1 deletion
@@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
         cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
     GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" (
         const GraphNodeHandle& h) noexcept nogil
-    void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" (
+    void invalidate_graph_node "cuda_core::invalidate_graph_node" (
         const GraphNodeHandle& h) noexcept nogil
 
 # Graphics resource handles

cuda_core/tests/graph/test_graphdef.py

Lines changed: 27 additions & 0 deletions
@@ -661,6 +661,7 @@ def test_node_type_preserved_by_nodes(node_spec):
     matched = [n for n in all_nodes if n == node]
     assert len(matched) == 1
     assert isinstance(matched[0], spec.roundtrip_class)
+    assert matched[0] is node
 
 
 def test_node_type_preserved_by_pred_succ(node_spec):
@@ -670,6 +671,7 @@ def test_node_type_preserved_by_pred_succ(node_spec):
     matched = [s for s in predecessor.succ if s == node]
     assert len(matched) == 1
     assert isinstance(matched[0], spec.roundtrip_class)
+    assert matched[0] is node
 
 
 def test_node_attrs(node_spec):
@@ -697,6 +699,31 @@ def test_node_attrs_preserved_by_nodes(node_spec):
         assert getattr(retrieved, attr) == getattr(node, attr), f"{spec.name}.{attr} not preserved by nodes()"
 
 
+def test_identity_preservation(init_cuda):
+    """Round-trips through nodes(), edges(), and pred/succ return extant
+    objects rather than duplicates."""
+    g = GraphDef()
+    a = g.join()
+    b = a.join()
+
+    # nodes()
+    assert any(x is a for x in g.nodes())
+    assert any(x is b for x in g.nodes())
+
+    # succ/pred
+    a.succ = {b}
+    (b2,) = a.succ
+    assert b2 is b
+
+    (a2,) = b.pred
+    assert a2 is a
+
+    # edges()
+    ((a2, b2),) = g.edges()
+    assert a2 is a
+    assert b2 is b
+
+
 # =============================================================================
 # GraphDef basics
 # =============================================================================

cuda_core/tests/graph/test_graphdef_mutation.py

Lines changed: 1 addition & 1 deletion
@@ -380,7 +380,7 @@ def test_convert_linear_to_fan_in(init_cuda):
     for node in g.nodes():
         if isinstance(node, MemsetNode):
             node.pred.clear()
-        elif isinstance(node, KernelNode) and node != reduce_node:
+        elif isinstance(node, KernelNode) and node is not reduce_node:
             node.succ.add(reduce_node)
 
     assert len(g.edges()) == 8
