Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions graphgen/bases/base_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,24 +51,3 @@ def community2batch(
if edge_data:
edges_data.append((u, v, edge_data))
return nodes_data, edges_data

@staticmethod
def _build_adjacency_list(
    nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
    """
    Build an undirected adjacency list and a symmetric edge set.

    Every node in ``nodes`` gets an entry in the adjacency list, even when it
    has no edges. Self-loops (``u == v``) are skipped. Every other edge is
    recorded in both directions, in the adjacency list and in the edge set.
    Edge endpoints that do not appear in ``nodes`` are added to the adjacency
    list on the fly instead of raising ``KeyError``.

    :param nodes: ``(node_id, data)`` tuples; only ``node_id`` is used
    :param edges: ``(u, v, data)`` tuples; ``data`` is ignored
    :return: adjacency list, edge set
    """
    adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
    edge_set: set[tuple[str, str]] = set()
    for u, v, _ in edges:
        if u == v:
            # A self-loop adds no neighbor information.
            continue
        # setdefault tolerates edges whose endpoints are absent from `nodes`.
        adj.setdefault(u, []).append(v)
        adj.setdefault(v, []).append(u)
        edge_set.add((u, v))
        edge_set.add((v, u))
    return adj, edge_set
4 changes: 4 additions & 0 deletions graphgen/bases/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def upsert_edge(
def delete_node(self, node_id: str):
raise NotImplementedError

@abstractmethod
def get_neighbors(self, node_id: str) -> List[str]:
    """
    Return the IDs of the nodes adjacent to ``node_id``.

    This is an abstract declaration: the base implementation only raises,
    so every concrete storage has to override it.

    :param node_id: ID of the node whose neighbors are requested
    :return: list of neighbor node IDs
    :raises NotImplementedError: always, in this base declaration
    """
    raise NotImplementedError

@abstractmethod
def reload(self):
raise NotImplementedError
Expand Down
6 changes: 6 additions & 0 deletions graphgen/common/init_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def upsert_edge(
def delete_node(self, node_id: str):
    """Delete ``node_id`` by delegating to the wrapped ``self.graph`` storage."""
    return self.graph.delete_node(node_id)

def get_neighbors(self, node_id: str) -> List[str]:
    """
    Return the neighbor IDs of ``node_id`` from the wrapped ``self.graph`` storage.

    :param node_id: ID of the node whose neighbors are requested
    :return: whatever ``self.graph.get_neighbors`` returns for that node
    """
    return self.graph.get_neighbors(node_id)

def reload(self):
return self.graph.reload()

Expand Down Expand Up @@ -245,6 +248,9 @@ def upsert_edge(
def delete_node(self, node_id: str):
    """Invoke ``delete_node`` on the remote actor and wait for its result via ``ray.get``."""
    return ray.get(self.actor.delete_node.remote(node_id))

def get_neighbors(self, node_id: str) -> List[str]:
    """
    Invoke ``get_neighbors`` on the remote actor and wait for its result via ``ray.get``.

    :param node_id: ID of the node whose neighbors are requested
    :return: the value produced by the actor's ``get_neighbors`` call
    """
    return ray.get(self.actor.get_neighbors.remote(node_id))

def reload(self):
return ray.get(self.actor.reload.remote())

Expand Down
7 changes: 3 additions & 4 deletions graphgen/models/partitioner/bfs_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def partition(
nodes = g.get_all_nodes()
edges = g.get_all_edges()

adj, _ = self._build_adjacency_list(nodes, edges)

used_n: set[str] = set()
used_e: set[frozenset[str]] = set()

Expand Down Expand Up @@ -55,15 +53,16 @@ def partition(
used_n.add(it)
comm_n.append(it)
cnt += 1
for nei in adj[it]:
for nei in g.get_neighbors(it):
e_key = frozenset((it, nei))
if e_key not in used_e:
queue.append((EDGE_UNIT, e_key))
else:
if it in used_e:
continue
used_e.add(it)
comm_e.append(tuple(sorted(it)))
u, v = sorted(it)
comm_e.append((u, v))
cnt += 1
# push nodes that are not visited
for n in it:
Expand Down
7 changes: 3 additions & 4 deletions graphgen/models/partitioner/dfs_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ def partition(
nodes = g.get_all_nodes()
edges = g.get_all_edges()

adj, _ = self._build_adjacency_list(nodes, edges)

used_n: set[str] = set()
used_e: set[frozenset[str]] = set()

Expand Down Expand Up @@ -55,7 +53,7 @@ def partition(
used_n.add(it)
comm_n.append(it)
cnt += 1
for nei in adj[it]:
for nei in g.get_neighbors(it):
e_key = frozenset((it, nei))
if e_key not in used_e:
stack.append((EDGE_UNIT, e_key))
Expand All @@ -64,7 +62,8 @@ def partition(
if it in used_e:
continue
used_e.add(it)
comm_e.append(tuple(sorted(it)))
u, v = sorted(it)
comm_e.append((u, v))
cnt += 1
# push neighboring nodes
for n in it:
Expand Down
3 changes: 1 addition & 2 deletions graphgen/models/partitioner/ece_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def partition(
nodes: List[Tuple[str, dict]] = g.get_all_nodes()
edges: List[Tuple[str, str, dict]] = g.get_all_edges()

adj, _ = self._build_adjacency_list(nodes, edges)
node_dict = dict(nodes)
edge_dict = {frozenset((u, v)): d for u, v, d in edges}

Expand Down Expand Up @@ -118,7 +117,7 @@ def _add_unit(u):

neighbors: List[Tuple[str, Any, dict]] = []
if cur_type == NODE_UNIT:
for nb_id in adj.get(cur_id, []):
for nb_id in g.get_neighbors(cur_id):
e_key = frozenset((cur_id, nb_id))
if e_key not in used_e and e_key not in community_edges:
neighbors.append((EDGE_UNIT, e_key, edge_dict[e_key]))
Expand Down
8 changes: 8 additions & 0 deletions graphgen/models/storage/graph/kuzu_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,14 @@ def delete_node(self, node_id: str):
self._conn.execute(query, {"id": node_id})
print(f"Node {node_id} deleted from KuzuDB.")

def get_neighbors(self, node_id: str) -> List[str]:
    """
    Return the distinct IDs of all ``Entity`` nodes linked to ``node_id``.

    The ``Relation`` pattern in the query is undirected, so neighbors on
    either end of a relation are matched.

    :param node_id: ID of the entity whose neighbors are requested
    :return: list of neighbor node IDs (first column of each non-empty row)
    """
    query = """
MATCH (a:Entity {id: $id})-[:Relation]-(b:Entity)
RETURN DISTINCT b.id
"""
    neighbor_ids: List[str] = []
    for record in self._conn.execute(query, {"id": node_id}):
        # Skip empty rows; only the first column carries the neighbor ID.
        if not record:
            continue
        neighbor_ids.append(record[0])
    return neighbor_ids

def clear(self):
"""Clear all data but keep schema (or drop tables)."""
self._conn.execute("MATCH (n) DETACH DELETE n")
Expand Down
12 changes: 12 additions & 0 deletions graphgen/models/storage/graph/networkx_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,18 @@ def delete_node(self, node_id: str):
else:
print(f"Node {node_id} not found in the graph for deletion.")

def get_neighbors(self, node_id: str) -> List[str]:
    """
    Get the neighbors of a node based on the specified node_id.

    A node that is not in the graph has no neighbors, so an empty list is
    returned for it. The miss is reported through ``logging`` at debug level
    instead of being printed, so library callers do not get stdout noise.

    :param node_id: The node_id to get neighbors for
    :return: List of neighbor node IDs (empty if the node does not exist)
    """
    # Local import keeps this method self-contained regardless of the
    # module's existing import block.
    import logging

    if self._graph.has_node(node_id):
        return list(self._graph.neighbors(node_id))
    logging.getLogger(__name__).debug("Node %s not found in the graph.", node_id)
    return []
Comment on lines +200 to +201
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Printing to standard output in a library function is generally discouraged. It can clutter the output of the application using the library and may not be desirable in all contexts (e.g., when running in production or as part of a larger pipeline). It's better to either use a proper logging framework or simply return an empty list, letting the caller decide how to handle the 'node not found' case. The empty list sufficiently signals that there are no neighbors, which is true if the node doesn't exist.

Suggested change
print(f"Node {node_id} not found in the graph.")
return []
return []


def clear(self):
"""
Clear the graph by removing all nodes and edges.
Expand Down