8 changes: 4 additions & 4 deletions graphgen/bases/base_storage.py
@@ -79,7 +79,7 @@ def get_node(self, node_id: str) -> Union[dict, None]:
         raise NotImplementedError
 
     @abstractmethod
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError
 
     @abstractmethod
@@ -96,7 +96,7 @@ def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
 
     @abstractmethod
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         raise NotImplementedError
 
@@ -113,12 +113,12 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
         raise NotImplementedError
 
     @abstractmethod
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         raise NotImplementedError
 
     @abstractmethod
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         raise NotImplementedError
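A side note on the new annotation: lowercase `any` is Python's builtin `any()` function, not `typing.Any`. It happens to be accepted at runtime (any object can serve as an annotation), but static checkers such as mypy and pyright reject it. A minimal sketch of the checker-friendly spelling — `BaseGraphStorage` is a stand-in name, since the real class name is not shown in this diff:

# Hypothetical sketch: `BaseGraphStorage` stands in for the real class name.
from abc import ABC, abstractmethod
from typing import Any


class BaseGraphStorage(ABC):
    @abstractmethod
    def update_node(self, node_id: str, node_data: dict[str, Any]):
        # node_data values may now be ints (e.g. "length"), not only str.
        raise NotImplementedError

    @abstractmethod
    def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, Any]
    ):
        raise NotImplementedError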
11 changes: 8 additions & 3 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -18,6 +18,7 @@ class LightRAGKGBuilder(BaseKGBuilder):
     def __init__(self, llm_client: BaseLLMWrapper, max_loop: int = 3):
         super().__init__(llm_client)
         self.max_loop = max_loop
+        self.tokenizer = llm_client.tokenizer
 
     async def extract(
         self, chunk: Chunk
@@ -134,6 +135,7 @@ async def merge_nodes(
             "entity_name": entity_name,
             "description": description,
             "source_id": source_id,
+            "length": self.tokenizer.count_tokens(description),
         }
         kg_instance.upsert_node(entity_name, node_data=node_data)
         return node_data
@@ -167,9 +169,11 @@ async def merge_edges(
             kg_instance.upsert_node(
                 insert_id,
                 node_data={
-                    "source_id": source_id,
-                    "description": description,
                     "entity_type": "UNKNOWN",
+                    "entity_name": insert_id,
+                    "description": description,
+                    "source_id": source_id,
+                    "length": self.tokenizer.count_tokens(description),
                 },
             )
 
@@ -182,12 +186,13 @@
             "tgt_id": tgt_id,
             "description": description,
             "source_id": source_id,  # for traceability
+            "length": self.tokenizer.count_tokens(description),
         }
 
         kg_instance.upsert_edge(
             src_id,
             tgt_id,
-            edge_data={"source_id": source_id, "description": description},
+            edge_data=edge_data,
         )
         return edge_data
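The net effect of these hunks: every node and edge payload written by the builder now carries a precomputed token count under `length`, so downstream partitioning can budget tokens without re-tokenizing descriptions. A minimal sketch of the payload shape — `count_tokens` is taken from the diff, while the `Tokenizer` protocol and the stand-in whitespace tokenizer are illustrative assumptions:

from typing import Any, Protocol


class Tokenizer(Protocol):
    def count_tokens(self, text: str) -> int: ...


def build_node_data(
    tokenizer: Tokenizer,
    entity_name: str,
    entity_type: str,
    description: str,
    source_id: str,
) -> dict[str, Any]:
    # The payload mixes str values with an int, which is why the storage
    # signatures had to widen from dict[str, str].
    return {
        "entity_type": entity_type,
        "entity_name": entity_name,
        "description": description,
        "source_id": source_id,
        # Computed once at build time; partitioners read it back later.
        "length": tokenizer.count_tokens(description),
    }


class WhitespaceTokenizer:
    """Stand-in for the sketch; the real tokenizer comes from llm_client."""
    def count_tokens(self, text: str) -> int:
        return len(text.split())


node = build_node_data(
    WhitespaceTokenizer(), "A", "CONCEPT", "a short description", "chunk-1"
)
assert node["length"] == 3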
2 changes: 1 addition & 1 deletion graphgen/models/partitioner/ece_partitioner.py
@@ -99,7 +99,7 @@ def _add_unit(u):
                 return False
             community_edges[i] = d
             used_e.add(i)
-            token_sum += d.get("length", 0)
+            token_sum += int(d.get("length", 0))
             return True
 
         _add_unit(seed_unit)
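The `int(...)` coercion guards against values read back from storage as strings — plausible for graphs persisted under the old `dict[str, str]` contract — and keeps the running token budget an int. A small illustration using plain dicts; `add_edge_tokens` is a hypothetical helper, not code from the PR:

def add_edge_tokens(token_sum: int, edge_data: dict) -> int:
    # int(...) normalizes both int and stringified values.
    return token_sum + int(edge_data.get("length", 0))


assert add_edge_tokens(10, {"length": 5}) == 15
assert add_edge_tokens(10, {"length": "5"}) == 15  # stringified legacy value
assert add_edge_tokens(10, {}) == 10               # legacy graph, no field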
8 changes: 4 additions & 4 deletions graphgen/models/storage/graph/kuzu_storage.py
@@ -215,7 +215,7 @@ def get_node(self, node_id: str) -> Any:
         data_str = result.get_next()[0]
         return self._safe_json_loads(data_str)
 
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         current_data = self.get_node(node_id)
         if current_data is None:
             print(f"Node {node_id} not found for update.")
@@ -263,7 +263,7 @@ def get_edge(self, source_node_id: str, target_node_id: str):
         return self._safe_json_loads(data_str)
 
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         current_data = self.get_edge(source_node_id, target_node_id)
         if current_data is None:
@@ -318,7 +318,7 @@ def get_node_edges(self, source_node_id: str) -> Any:
             edges.append((src, dst, data))
         return edges
 
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         """
         Insert or Update node.
         Kuzu supports MERGE clause (similar to Neo4j) to handle upserts.
@@ -336,7 +336,7 @@ def upsert_node(self, node_id: str, node_data: dict[str, str]):
         self._conn.execute(query, {"id": node_id, "data": json_data})
 
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
    ):
         """
         Insert or Update edge.
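Since this backend serializes the whole attribute dict into a single JSON column (`json_data`) and restores it via `_safe_json_loads`, an integer `length` survives the round trip with its type intact — which is why widening the hint to non-str values is safe here. A minimal sketch of that round trip using only the standard library; the variable names mirror the diff but the snippet is illustrative, not Kuzu API:

import json

# Node payload with a non-str value, as produced by the builder above.
node_data = {"description": "a node", "length": 7}

# Serialize into the single JSON string column the backend stores...
json_data = json.dumps(node_data, ensure_ascii=False)

# ...and restore, as _safe_json_loads would. The int survives intact.
restored = json.loads(json_data)
assert restored["length"] == 7
assert isinstance(restored["length"], int)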
8 changes: 4 additions & 4 deletions graphgen/models/storage/graph/networkx_storage.py
@@ -144,22 +144,22 @@ def get_node_edges(self, source_node_id: str) -> Union[list[tuple[str, str]], None]:
     def get_graph(self) -> nx.Graph:
         return self._graph
 
-    def upsert_node(self, node_id: str, node_data: dict[str, str]):
+    def upsert_node(self, node_id: str, node_data: dict[str, any]):
         self._graph.add_node(node_id, **node_data)
 
-    def update_node(self, node_id: str, node_data: dict[str, str]):
+    def update_node(self, node_id: str, node_data: dict[str, any]):
         if self._graph.has_node(node_id):
             self._graph.nodes[node_id].update(node_data)
         else:
             print(f"Node {node_id} not found in the graph for update.")
 
     def upsert_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     def update_edge(
-        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
+        self, source_node_id: str, target_node_id: str, edge_data: dict[str, any]
     ):
         if self._graph.has_edge(source_node_id, target_node_id):
             self._graph.edges[(source_node_id, target_node_id)].update(edge_data)
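For the NetworkX backend the widened hint simply matches what the library already allows: node and edge attributes can be arbitrary Python objects, and repeated `add_node`/`add_edge` calls merge attributes into the existing element, which is what makes these one-liners valid upserts. A quick demonstration of that behavior:

import networkx as nx

g = nx.Graph()
g.add_node("a", description="first", length=3)   # insert
g.add_node("a", length=5)                        # upsert: merges attributes
assert g.nodes["a"] == {"description": "first", "length": 5}

g.add_edge("a", "b", length=2)
g.edges[("a", "b")].update({"length": 4})        # in-place update
assert g.edges[("a", "b")]["length"] == 4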
41 changes: 2 additions & 39 deletions graphgen/operators/partition/partition_service.py
@@ -60,10 +60,8 @@ def partition(self) -> Iterable[pd.DataFrame]:
             partitioner = DFSPartitioner()
         elif method == "ece":
             logger.info("Partitioning knowledge graph using ECE method.")
-            # TODO: before ECE partitioning, we need to:
-            # 1. 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
-            # 2. pre-tokenize nodes and edges to get the token length
-            self._pre_tokenize()
+            # before ECE partitioning, we need to:
+            # 'quiz' and 'judge' to get the comprehension loss if unit_sampling is not random
             partitioner = ECEPartitioner()
         elif method == "leiden":
             logger.info("Partitioning knowledge graph using Leiden method.")
@@ -97,41 +95,6 @@
         )
         logger.info("Total communities partitioned: %d", count)
 
-    def _pre_tokenize(self) -> None:
-        """Pre-tokenize all nodes and edges to add token length information."""
-        logger.info("Starting pre-tokenization of nodes and edges...")
-
-        nodes = self.kg_instance.get_all_nodes()
-        edges = self.kg_instance.get_all_edges()
-
-        # Process nodes
-        for node_id, node_data in nodes:
-            if "length" not in node_data:
-                try:
-                    description = node_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    node_data["length"] = len(tokens)
-                    self.kg_instance.update_node(node_id, node_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize node %s: %s", node_id, e)
-                    node_data["length"] = 0
-
-        # Process edges
-        for u, v, edge_data in edges:
-            if "length" not in edge_data:
-                try:
-                    description = edge_data.get("description", "")
-                    tokens = self.tokenizer_instance.encode(description)
-                    edge_data["length"] = len(tokens)
-                    self.kg_instance.update_edge(u, v, edge_data)
-                except Exception as e:
-                    logger.warning("Failed to tokenize edge %s-%s: %s", u, v, e)
-                    edge_data["length"] = 0
-
-        # Persist changes
-        self.kg_instance.index_done_callback()
-        logger.info("Pre-tokenization completed.")
-
     def _attach_additional_data_to_node(self, batch: tuple) -> tuple:
         """
         Attach additional data from chunk_storage to nodes in the batch.
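With `_pre_tokenize` deleted, the token length is written once when the builder upserts a node or edge (see `light_rag_kg_builder.py` above) rather than recomputed in a full read-modify-write sweep at partition time; legacy graphs that lack the field fall back to `int(d.get("length", 0)) == 0` in the partitioner. A compressed sketch of the shift, condensed from the removed method with hypothetical parameter names:

# Before this PR: a separate O(nodes + edges) sweep at partition time that
# tokenized every description and wrote the result back to storage.
def pre_tokenize_pass(kg, tokenizer):
    for node_id, data in kg.get_all_nodes():
        if "length" not in data:
            data["length"] = len(tokenizer.encode(data.get("description", "")))
            kg.update_node(node_id, data)
    # ...plus the same loop over kg.get_all_edges()...

# After this PR: the KG builder attaches "length" as part of each upsert,
# so partitioning only reads the graph and never writes to it.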