Skip to content

Commit 15d0036

Browse files
committed
Add cheap containment test and early type check for AdjacencySetProxy
Add _AdjacencySetCore.contains() that checks membership by comparing raw CUgraphNode handles at the C level, avoiding Python object construction. Uses a 16-element stack buffer for a single driver call in the common case. Move the type check in update() inline next to the extend loop so invalid input is rejected immediately. Made-with: Cursor
1 parent 9766e54 commit 15d0036

File tree

1 file changed

+31
-6
lines changed

1 file changed

+31
-6
lines changed

cuda_core/cuda/core/_graph/_graph_def/_adjacency_set_proxy.pyx

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class AdjacencySetProxy(MutableSet):
3939
def __contains__(self, x):
4040
if not isinstance(x, GraphNode):
4141
return False
42-
return x in (<_AdjacencySetCore>self._core).query()
42+
return (<_AdjacencySetCore>self._core).contains(<GraphNode>x)
4343

4444
def __iter__(self):
4545
return iter((<_AdjacencySetCore>self._core).query())
@@ -87,13 +87,13 @@ class AdjacencySetProxy(MutableSet):
8787
if isinstance(other, GraphNode):
8888
nodes.append(other)
8989
else:
90-
nodes.extend(other)
90+
for n in other:
91+
if not isinstance(n, GraphNode):
92+
raise TypeError(
93+
f"expected GraphNode, got {type(n).__name__}")
94+
nodes.append(n)
9195
if not nodes:
9296
return
93-
for n in nodes:
94-
if not isinstance(n, GraphNode):
95-
raise TypeError(
96-
f"expected GraphNode, got {type(n).__name__}")
9797
new = [n for n in nodes if n not in self]
9898
if new:
9999
(<_AdjacencySetCore>self._core).add_edges(new)
@@ -156,6 +156,31 @@ cdef class _AdjacencySetCore:
156156
return [GraphNode._create(self._h_graph, nodes_vec[i])
157157
for i in range(count)]
158158

159+
cdef bint contains(self, GraphNode other):
160+
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
161+
cdef cydriver.CUgraphNode target = as_cu(other._h_node)
162+
if c_node == NULL or target == NULL:
163+
return False
164+
cdef cydriver.CUgraphNode buf[16]
165+
cdef size_t count = 16
166+
cdef size_t i
167+
with nogil:
168+
HANDLE_RETURN(self._query_fn(c_node, buf, &count))
169+
if count <= 16:
170+
for i in range(count):
171+
if buf[i] == target:
172+
return True
173+
else:
174+
cdef vector[cydriver.CUgraphNode] nodes_vec
175+
nodes_vec.resize(count)
176+
with nogil:
177+
HANDLE_RETURN(self._query_fn(c_node, nodes_vec.data(), &count))
178+
assert count == nodes_vec.size()
179+
for i in range(count):
180+
if nodes_vec[i] == target:
181+
return True
182+
return False
183+
159184
cdef Py_ssize_t count(self):
160185
cdef cydriver.CUgraphNode c_node = as_cu(self._h_node)
161186
if c_node == NULL:

0 commit comments

Comments
 (0)