Skip to content

Commit aec93bf

Browse files
committed
Merge branch 'graph-node-repr' into graph-public-api
2 parents 47276b4 + f779f30 commit aec93bf

13 files changed

+263
-118
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Handle and Object Registries
2+
3+
When Python-managed objects round-trip through the CUDA driver (e.g.,
4+
querying a graph's nodes and getting back raw `CUgraphNode` pointers),
5+
we need to recover the original Python object rather than creating a
6+
duplicate.
7+
8+
This document describes the approach used to achieve this. The pattern
9+
is driven mainly by needs arising in the context of CUDA graphs, but
10+
it is general and can be extended to other object types as needs arise.
11+
12+
This solves the same problem as pybind11's `registered_instances` map
13+
and is sometimes called the Identity Map pattern. Two registries work
14+
together to map a raw driver handle all the way back to the original
15+
Python object. Both use weak references so they
16+
do not prevent cleanup. Entries are removed either explicitly (via
17+
`destroy()` or a Box destructor) or implicitly when the weak reference
18+
expires.
19+
20+
## Level 1: Driver Handle -> Resource Handle (C++)
21+
22+
`HandleRegistry` in `resource_handles.cpp` maps a raw CUDA handle
23+
(e.g., `CUevent`, `CUkernel`, `CUgraphNode`) to the `weak_ptr` that
24+
owns it. When a `_ref` constructor receives a raw handle, it
25+
checks the registry first. If found, it returns the existing
26+
`shared_ptr`, preserving the Box and its metadata (e.g., `EventBox`
27+
carries timing/IPC flags, `KernelBox` carries the library dependency).
28+
29+
Without this level, a round-tripped handle would produce a new Box
30+
with default metadata, losing information that was set at creation.
31+
32+
Instances: `event_registry`, `kernel_registry`, `graph_node_registry`.
33+
34+
## Level 2: Resource Handle -> Python Object (Cython)
35+
36+
`_node_registry` in `_graph_node.pyx` is a `WeakValueDictionary`
37+
mapping a resource address (`shared_ptr::get()`) to a Python
38+
`GraphNode` object. When `GraphNode._create` receives a handle from
39+
Level 1, it checks this registry. If found, it returns the existing
40+
Python object.
41+
42+
Without this level, each driver round-trip would produce a distinct
43+
Python object for the same logical node, resulting in surprising
44+
behavior:
45+
46+
```python
47+
a = g.empty()
48+
a.succ = {b}
49+
b2, = a.succ # queries driver, gets back CUgraphNode for b
50+
assert b2 is b # fails without Level 2 registry
51+
```

cuda_core/cuda/core/_cpp/resource_handles.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -174,13 +174,8 @@ class HandleRegistry {
174174
}
175175

176176
void unregister_handle(const Key& key) noexcept {
177-
try {
178-
std::lock_guard<std::mutex> lock(mutex_);
179-
auto it = map_.find(key);
180-
if (it != map_.end() && it->second.expired()) {
181-
map_.erase(it);
182-
}
183-
} catch (...) {}
177+
std::lock_guard<std::mutex> lock(mutex_);
178+
map_.erase(key);
184179
}
185180

186181
Handle lookup(const Key& key) {
@@ -393,6 +388,7 @@ ContextHandle get_event_context(const EventHandle& h) noexcept {
393388
return h ? get_box(h)->h_context : ContextHandle{};
394389
}
395390

391+
// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
396392
static HandleRegistry<CUevent, EventHandle> event_registry;
397393

398394
EventHandle create_event_handle(const ContextHandle& h_ctx, unsigned int flags,
@@ -899,6 +895,7 @@ static const KernelBox* get_box(const KernelHandle& h) {
899895
);
900896
}
901897

898+
// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
902899
static HandleRegistry<CUkernel, KernelHandle> kernel_registry;
903900

904901
KernelHandle create_kernel_handle(const LibraryHandle& h_library, const char* name) {
@@ -969,6 +966,7 @@ static const GraphNodeBox* get_box(const GraphNodeHandle& h) {
969966
);
970967
}
971968

969+
// See REGISTRY_DESIGN.md (Level 1: Driver Handle -> Resource Handle)
972970
static HandleRegistry<CUgraphNode, GraphNodeHandle> graph_node_registry;
973971

974972
GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_graph) {
@@ -989,7 +987,7 @@ GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept {
989987
return h ? get_box(h)->h_graph : GraphHandle{};
990988
}
991989

992-
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept {
990+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept {
993991
if (h) {
994992
CUgraphNode node = get_box(h)->resource;
995993
if (node) {

cuda_core/cuda/core/_cpp/resource_handles.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ GraphNodeHandle create_graph_node_handle(CUgraphNode node, const GraphHandle& h_
416416
GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept;
417417

418418
// Zero the CUgraphNode resource inside the handle, marking it invalid.
419-
void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept;
419+
void invalidate_graph_node(const GraphNodeHandle& h) noexcept;
420420

421421
// ============================================================================
422422
// Graphics resource handle functions

cuda_core/cuda/core/_resource_handles.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ cdef GraphHandle create_graph_handle_ref(cydriver.CUgraph graph, const GraphHand
186186
# Graph node handles
187187
cdef GraphNodeHandle create_graph_node_handle(cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
188188
cdef GraphHandle graph_node_get_graph(const GraphNodeHandle& h) noexcept nogil
189-
cdef void invalidate_graph_node_handle(const GraphNodeHandle& h) noexcept nogil
189+
cdef void invalidate_graph_node(const GraphNodeHandle& h) noexcept nogil
190190

191191
# Graphics resource handles
192192
cdef GraphicsResourceHandle create_graphics_resource_handle(

cuda_core/cuda/core/_resource_handles.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ cdef extern from "_cpp/resource_handles.hpp" namespace "cuda_core":
159159
cydriver.CUgraphNode node, const GraphHandle& h_graph) except+ nogil
160160
GraphHandle graph_node_get_graph "cuda_core::graph_node_get_graph" (
161161
const GraphNodeHandle& h) noexcept nogil
162-
void invalidate_graph_node_handle "cuda_core::invalidate_graph_node_handle" (
162+
void invalidate_graph_node "cuda_core::invalidate_graph_node" (
163163
const GraphNodeHandle& h) noexcept nogil
164164

165165
# Graphics resource handles

cuda_core/cuda/core/graph/_adjacency_set_proxy.pyx

Lines changed: 42 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet):
3939
def __contains__(self, x):
4040
if not isinstance(x, GraphNode):
4141
return False
42-
return x in (<_AdjacencySetCore>self._core).query()
42+
return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)
4343

4444
def __iter__(self):
4545
return iter((<_AdjacencySetCore>self._core).query())
@@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet):
8787
if isinstance(other, GraphNode):
8888
nodes.append(other)
8989
else:
90-
nodes.extend(other)
90+
for n in other:
91+
if not isinstance(n, GraphNode):
92+
raise TypeError(
93+
f"expected GraphNode, got {type(n).__name__}")
94+
nodes.append(n)
9195
if not nodes:
9296
return
93-
for n in nodes:
94-
if not isinstance(n, GraphNode):
95-
raise TypeError(
96-
f"expected GraphNode, got {type(n).__name__}")
9797
new = [n for n in nodes if n not in self]
9898
if new:
9999
(<_AdjacencySetCore>self._core).add_edges(new)
@@ -143,11 +143,14 @@ cdef class _AdjacencySetCore:
143143
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
144144
if c_node == NULL:
145145
return []
146-
cdef size_t count = 0
146+
cdef cydriver.CUgraphNode buf[16]
147+
cdef size_t count = 16
148+
cdef size_t i
147149
with nogil:
148-
HANDLE_RETURN(self._query_fn(c_node, NULL, &count))
149-
if count == 0:
150-
return []
150+
HANDLE_RETURN(self._query_fn(c_node, buf, &count))
151+
if count <= 16:
152+
return [GraphNode._create(self._h_graph, buf[i])
153+
for i in range(count)]
151154
cdef vector[cydriver.CUgraphNode] nodes_vec
152155
nodes_vec.resize(count)
153156
with nogil:
@@ -156,6 +159,35 @@ cdef class _AdjacencySetCore:
156159
return [GraphNode._create(self._h_graph, nodes_vec[i])
157160
for i in range(count)]
158161

162+
cdef bint contains(self, GraphNode other):
163+
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
164+
cdef cydriver.CUgraphNode target = as_cu(other._h_node)
165+
if c_node == NULL or target == NULL:
166+
return False
167+
cdef cydriver.CUgraphNode buf[16]
168+
cdef size_t count = 16
169+
cdef size_t i
170+
with nogil:
171+
HANDLE_RETURN(self._query_fn(c_node, buf, &count))
172+
173+
# Fast path for small sets.
174+
if count <= 16:
175+
for i in range(count):
176+
if buf[i] == target:
177+
return True
178+
return False
179+
180+
# Fallback for large sets.
181+
cdef vector[cydriver.CUgraphNode] nodes_vec
182+
nodes_vec.resize(count)
183+
with nogil:
184+
HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
185+
assert count == nodes_vec.size()
186+
for i in range(count):
187+
if nodes_vec[i] == target:
188+
return True
189+
return False
190+
159191
cdef Py_ssize_t count(self):
160192
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
161193
if c_node == NULL:

cuda_core/cuda/core/graph/_graph_def.pyx

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,16 @@ cdef class GraphDef:
159159
"""
160160
return self._entry.launch(config, kernel, *args)
161161
162+
def empty(self) -> "EmptyNode":
163+
"""Add an entry-point empty node (no dependencies).
164+
165+
Returns
166+
-------
167+
EmptyNode
168+
A new EmptyNode with no dependencies.
169+
"""
170+
return self._entry.join()
171+
162172
def join(self, *nodes) -> "EmptyNode":
163173
"""Create an empty node that depends on all given nodes.
164174

@@ -322,18 +332,20 @@ cdef class GraphDef:
322332
set of GraphNode
323333
All nodes in the graph.
324334
"""
325-
cdef size_t num_nodes = 0
335+
cdef vector[cydriver.CUgraphNode] nodes_vec
336+
nodes_vec.resize(128)
337+
cdef size_t num_nodes = 128
326338
327339
with nogil:
328-
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))
340+
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
329341
330342
if num_nodes == 0:
331343
return set()
332344
333-
cdef vector[cydriver.CUgraphNode] nodes_vec
334-
nodes_vec.resize(num_nodes)
335-
with nogil:
336-
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
345+
if num_nodes > 128:
346+
nodes_vec.resize(num_nodes)
347+
with nogil:
348+
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
337349
338350
return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
339351
@@ -346,21 +358,12 @@ cdef class GraphDef:
346358
Each element is a (from_node, to_node) pair representing
347359
a dependency edge in the graph.
348360
"""
349-
cdef size_t num_edges = 0
350-
351-
with nogil:
352-
IF CUDA_CORE_BUILD_MAJOR >= 13:
353-
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges))
354-
ELSE:
355-
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))
356-
357-
if num_edges == 0:
358-
return set()
359-
360361
cdef vector[cydriver.CUgraphNode] from_nodes
361362
cdef vector[cydriver.CUgraphNode] to_nodes
362-
from_nodes.resize(num_edges)
363-
to_nodes.resize(num_edges)
363+
from_nodes.resize(128)
364+
to_nodes.resize(128)
365+
cdef size_t num_edges = 128
366+
364367
with nogil:
365368
IF CUDA_CORE_BUILD_MAJOR >= 13:
366369
HANDLE_RETURN(cydriver.cuGraphGetEdges(
@@ -369,6 +372,20 @@ cdef class GraphDef:
369372
HANDLE_RETURN(cydriver.cuGraphGetEdges(
370373
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
371374
375+
if num_edges == 0:
376+
return set()
377+
378+
if num_edges > 128:
379+
from_nodes.resize(num_edges)
380+
to_nodes.resize(num_edges)
381+
with nogil:
382+
IF CUDA_CORE_BUILD_MAJOR >= 13:
383+
HANDLE_RETURN(cydriver.cuGraphGetEdges(
384+
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges))
385+
ELSE:
386+
HANDLE_RETURN(cydriver.cuGraphGetEdges(
387+
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
388+
372389
return set(
373390
(GraphNode._create(self._h_graph, from_nodes[i]),
374391
GraphNode._create(self._h_graph, to_nodes[i]))

0 commit comments

Comments
 (0)