Skip to content

Commit e09b6e4

Browse files
committed
Add edge mutation support and MutableSet interface for GraphNode adjacencies
Enable adding/removing edges between graph nodes via AdjacencySet (a MutableSet proxy on GraphNode.pred/succ), node removal via discard(), and property setters for bulk edge replacement. Includes comprehensive mutation and interface tests. Closes part of #1330 (step 2: edge mutation on GraphDef). Made-with: Cursor
1 parent 4c1ac10 commit e09b6e4

7 files changed

Lines changed: 512 additions & 9 deletions

File tree

cuda_core/cuda/core/_graph/_graph_def/_adjacency_set.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class AdjacencySet(MutableSet):
2828
def __init__(self, node, bint is_fwd):
2929
self._core = _AdjacencySetCore(node, is_fwd)
3030

31+
@classmethod
32+
def _from_iterable(cls, it):
33+
return set(it)
34+
3135
# --- abstract methods required by MutableSet ---
3236

3337
def __contains__(self, x):
@@ -45,6 +49,8 @@ class AdjacencySet(MutableSet):
4549
if not isinstance(value, GraphNode):
4650
raise TypeError(
4751
f"expected GraphNode, got {type(value).__name__}")
52+
if value in self:
53+
return
4854
(<_AdjacencySetCore>self._core).add_edge(<GraphNode>value)
4955

5056
def discard(self, value):

cuda_core/cuda/core/_graph/_graph_def/_graph_def.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ cdef class GraphDef:
314314
with nogil:
315315
HANDLE_RETURN(cydriver.cuGraphDebugDotPrint(as_cu(self._h_graph), c_path, flags))
316316
317-
def nodes(self) -> tuple:
317+
def nodes(self) -> set:
318318
"""Return all nodes in the graph.
319319

320320
Returns
@@ -335,9 +335,9 @@ cdef class GraphDef:
335335
with nogil:
336336
HANDLE_RETURN(cydriver.cuGraphGetNodes(as_cu(self._h_graph), nodes_vec.data(), &num_nodes))
337337
338-
return tuple(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
338+
return set(GraphNode._create(self._h_graph, nodes_vec[i]) for i in range(num_nodes))
339339
340-
def edges(self) -> tuple:
340+
def edges(self) -> set:
341341
"""Return all edges in the graph as (from_node, to_node) pairs.
342342

343343
Returns
@@ -369,7 +369,7 @@ cdef class GraphDef:
369369
HANDLE_RETURN(cydriver.cuGraphGetEdges(
370370
as_cu(self._h_graph), from_nodes.data(), to_nodes.data(), &num_edges))
371371
372-
return tuple(
372+
return set(
373373
(GraphNode._create(self._h_graph, from_nodes[i]),
374374
GraphNode._create(self._h_graph, to_nodes[i]))
375375
for i in range(num_edges)

cuda_core/cuda/core/_graph/_graph_def/_graph_node.pyx

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ cdef class GraphNode:
123123
"""
124124
return as_py(self._h_node)
125125

126-
def remove(self):
127-
"""Remove this node and all its edges from the parent graph."""
126+
def discard(self):
127+
"""Discard this node and remove all its edges from the parent graph."""
128128
cdef cydriver.CUgraphNode node = as_cu(self._h_node)
129129
with nogil:
130130
HANDLE_RETURN(cydriver.cuGraphDestroyNode(node))
@@ -134,11 +134,23 @@ cdef class GraphNode:
134134
"""A mutable set-like view of this node's predecessors."""
135135
return AdjacencySet(self, False)
136136

137+
@pred.setter
138+
def pred(self, value):
139+
p = AdjacencySet(self, False)
140+
p.clear()
141+
p.update(value)
142+
137143
@property
138144
def succ(self):
139145
"""A mutable set-like view of this node's successors."""
140146
return AdjacencySet(self, True)
141147

148+
@succ.setter
149+
def succ(self, value):
150+
s = AdjacencySet(self, True)
151+
s.clear()
152+
s.update(value)
153+
142154
def launch(self, config: LaunchConfig, kernel: Kernel, *args) -> KernelNode:
143155
"""Add a kernel launch node depending on this node.
144156

cuda_core/tests/graph/test_graphdef_lifetime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_branches_survive_parent_deletion(init_cuda, builder, expected_count):
6868
gc.collect()
6969

7070
for branch in branches:
71-
assert branch.nodes() == ()
71+
assert branch.nodes() == set()
7272

7373

7474
@pytest.mark.parametrize("builder, expected_count", _COND_BUILDERS)
@@ -108,7 +108,7 @@ def test_reconstructed_body_survives_parent_deletion(init_cuda):
108108
del g, condition, all_nodes, cond_nodes, branches
109109
gc.collect()
110110

111-
assert body.nodes() == ()
111+
assert body.nodes() == set()
112112

113113

114114
# =============================================================================

0 commit comments

Comments
 (0)