Skip to content

Commit f779f30

Browse files
committed
Optimize GraphDef.nodes() and edges() to try a single driver call
Pre-allocate vectors to 128 entries and pass them on the first call. Only fall back to a second call if the graph exceeds 128 nodes/edges. Made-with: Cursor
1 parent 36527da commit f779f30

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -332,18 +332,20 @@ cdef class GraphDef:
332332
set of GraphNode
333333
All nodes in the graph.
334334
"""
335-
cdef size_t num_nodes = 0
335+
cdef vector[cydriver.CUgraphNode] nodes_vec
336+
nodes_vec.resize(128)
337+
cdef size_t num_nodes = 128
336338
337339
with nogil:
338-
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL, &num_nodes))
340+
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
339341
340342
if num_nodes == 0:
341343
return set()
342344
343-
cdef vector[cydriver.CUgraphNode] nodes_vec
344-
nodes_vec.resize(num_nodes)
345-
with nogil:
346-
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
345+
if num_nodes > 128:
346+
nodes_vec.resize(num_nodes)
347+
with nogil:
348+
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
347349
348350
return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
349351
@@ -356,21 +358,12 @@ cdef class GraphDef:
356358
Each element is a (from_node, to_node) pair representing
357359
a dependency edge in the graph.
358360
"""
359-
cdef size_t num_edges = 0
360-
361-
with nogil:
362-
IF CUDA_CORE_BUILD_MAJOR >= 13:
363-
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges))
364-
ELSE:
365-
HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))
366-
367-
if num_edges == 0:
368-
return set()
369-
370361
cdef vector[cydriver.CUgraphNode] from_nodes
371362
cdef vector[cydriver.CUgraphNode] to_nodes
372-
from_nodes.resize(num_edges)
373-
to_nodes.resize(num_edges)
363+
from_nodes.resize(128)
364+
to_nodes.resize(128)
365+
cdef size_t num_edges = 128
366+
374367
with nogil:
375368
IF CUDA_CORE_BUILD_MAJOR >= 13:
376369
HANDLE_RETURN(cydriver.cuGraphGetEdges(
@@ -379,6 +372,20 @@ cdef class GraphDef:
379372
HANDLE_RETURN(cydriver.cuGraphGetEdges(
380373
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
381374
375+
if num_edges == 0:
376+
return set()
377+
378+
if num_edges > 128:
379+
from_nodes.resize(num_edges)
380+
to_nodes.resize(num_edges)
381+
with nogil:
382+
IF CUDA_CORE_BUILD_MAJOR >= 13:
383+
HANDLE_RETURN(cydriver.cuGraphGetEdges(
384+
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges))
385+
ELSE:
386+
HANDLE_RETURN(cydriver.cuGraphGetEdges(
387+
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
388+
382389
return set(
383390
(GraphNode._create(self._h_graph, from_nodes[i]),
384391
GraphNode._create(self._h_graph, to_nodes[i]))

cuda_core/tests/graph/test_graphdef.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -704,7 +704,7 @@ def test_identity_preservation(init_cuda):
704704
objects rather than duplicates."""
705705
g = GraphDef()
706706
a = g.empty()
707-
b = a.join()
707+
b = g.empty()
708708

709709
# nodes()
710710
assert any(x is a for x in g.nodes())

0 commit comments

Comments
 (0)