@@ -332,18 +332,20 @@ cdef class GraphDef:
332332 set of GraphNode
333333 All nodes in the graph.
334334 """
335- cdef size_t num_nodes = 0
335+ cdef vector[cydriver.CUgraphNode] nodes_vec
336+ nodes_vec.resize(128)
337+ cdef size_t num_nodes = 128
336338
337339 with nogil:
338- HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), NULL , &num_nodes))
340+ HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data() , &num_nodes))
339341
340342 if num_nodes == 0:
341343 return set()
342344
343- cdef vector[cydriver.CUgraphNode] nodes_vec
344- nodes_vec.resize(num_nodes)
345- with nogil:
346- HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
345+ if num_nodes > 128:
346+ nodes_vec.resize(num_nodes)
347+ with nogil:
348+ HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
347349
348350 return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
349351
@@ -356,21 +358,12 @@ cdef class GraphDef:
356358 Each element is a (from_node, to_node) pair representing
357359 a dependency edge in the graph.
358360 """
359- cdef size_t num_edges = 0
360-
361- with nogil:
362- IF CUDA_CORE_BUILD_MAJOR >= 13:
363- HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, NULL, &num_edges))
364- ELSE:
365- HANDLE_RETURN(cydriver.cuGraphGetEdges(as_cu(self._h_graph), NULL, NULL, &num_edges))
366-
367- if num_edges == 0:
368- return set()
369-
370361 cdef vector[cydriver.CUgraphNode] from_nodes
371362 cdef vector[cydriver.CUgraphNode] to_nodes
372- from_nodes.resize(num_edges)
373- to_nodes.resize(num_edges)
363+ from_nodes.resize(128)
364+ to_nodes.resize(128)
365+ cdef size_t num_edges = 128
366+
374367 with nogil:
375368 IF CUDA_CORE_BUILD_MAJOR >= 13:
376369 HANDLE_RETURN(cydriver.cuGraphGetEdges(
@@ -379,6 +372,20 @@ cdef class GraphDef:
379372 HANDLE_RETURN(cydriver.cuGraphGetEdges(
380373 as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
381374
375+ if num_edges == 0:
376+ return set()
377+
378+ if num_edges > 128:
379+ from_nodes.resize(num_edges)
380+ to_nodes.resize(num_edges)
381+ with nogil:
382+ IF CUDA_CORE_BUILD_MAJOR >= 13:
383+ HANDLE_RETURN(cydriver.cuGraphGetEdges(
384+ as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), NULL, &num_edges))
385+ ELSE:
386+ HANDLE_RETURN(cydriver.cuGraphGetEdges(
387+ as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
388+
382389 return set(
383390 (GraphNode._create(self._h_graph, from_nodes[i]),
384391 GraphNode._create(self._h_graph, to_nodes[i]))
0 commit comments