Skip to content

Commit 69e0d6a

Browse files
feat: add community2batch method
1 parent c0fa9f9 commit 69e0d6a

4 files changed

Lines changed: 54 additions & 21 deletions

File tree

graphgen/bases/base_partitioner.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from dataclasses import dataclass
3-
from typing import Any, List
3+
from typing import Any, List, Tuple
44

55
from graphgen.bases.base_storage import BaseGraphStorage
66
from graphgen.bases.datatypes import Community
@@ -29,6 +29,41 @@ def split_communities(self, communities: List[Community]) -> List[Community]:
2929
:return:
3030
"""
3131

32+
@staticmethod
33+
async def community2batch(
34+
communities: List[Community], g: BaseGraphStorage
35+
) -> list[
36+
tuple[
37+
list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
38+
]
39+
]:
40+
"""
41+
Convert communities to batches of nodes and edges.
42+
:param communities
43+
:param g: Graph storage instance
44+
:return: List of batches, each batch is a tuple of (nodes, edges)
45+
"""
46+
batches = []
47+
for comm in communities:
48+
nodes = comm.nodes
49+
edges = comm.edges
50+
nodes_data = []
51+
for node in nodes:
52+
node_data = await g.get_node(node)
53+
if node_data:
54+
nodes_data.append((node, node_data))
55+
edges_data = []
56+
for u, v in edges:
57+
edge_data = await g.get_edge(u, v)
58+
if edge_data:
59+
edges_data.append((u, v, edge_data))
60+
else:
61+
edge_data = await g.get_edge(v, u)
62+
if edge_data:
63+
edges_data.append((v, u, edge_data))
64+
batches.append((nodes_data, edges_data))
65+
return batches
66+
3267
@staticmethod
3368
def _build_adjacency_list(
3469
nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]

graphgen/configs/atomic_config.yaml

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,9 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1111
quiz_samples: 2 # number of quiz samples to generate
1212
re_judge: false # whether to re-judge the existing quiz samples
1313
partition: # graph partition configuration
14-
method: ece # ece is a custom partition method based on comprehension loss
14+
method: dfs # partition method, support: dfs, bfs, ece, leiden
1515
method_params:
16-
bidirectional: true # whether to traverse the graph in both directions
17-
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
18-
expand_method: max_width # expand method, support: max_width, max_depth
19-
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20-
max_depth: 1 # maximum depth for graph traversal
21-
max_extra_edges: 0 # max edges per direction (if expand_method="max_width")
22-
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
23-
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
16+
max_units_per_community: 1 # atomic partition, one node or edge per community
2417
generate:
2518
mode: atomic # atomic, aggregated, multi_hop, cot
2619
data_format: Alpaca # Alpaca, Sharegpt, ChatML

graphgen/graphgen.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Dict, cast
66

77
import gradio as gr
8+
from jieba.lac_small.predict import results
89

910
from graphgen.bases.base_storage import StorageNameSpace
1011
from graphgen.bases.datatypes import Chunk
@@ -234,10 +235,13 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
234235

235236
@async_to_sync_method
236237
async def generate(self, partition_config: Dict, generate_config: Dict):
237-
pass
238238
# Step 1: partition the graph
239-
# mode = generate_config["mode"]
240-
# batches = partition_kg(self.graph_storage, partition_config)
239+
batches = await partition_kg(self.graph_storage, partition_config)
240+
241+
# Step 2: generate QA pairs
242+
mode = generate_config["mode"]
243+
logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
244+
# results = generate_qa_pairs(generate_config)
241245
# if mode == "atomic":
242246
# results = await traverse_graph_for_atomic(
243247
# self.synthesizer_llm_client,
@@ -273,8 +277,6 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
273277
# )
274278
# else:
275279
# raise ValueError(f"Unknown generation mode: {mode}")
276-
# Step 2: generate QA pairs
277-
# TODO
278280

279281
# Step 3: format
280282
# results = format_generation_results(

graphgen/operators/partition/partition_kg.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List
1+
from typing import Any, List, Tuple
22

33
from graphgen.bases import BaseGraphStorage
44
from graphgen.bases.datatypes import Community
@@ -11,10 +11,12 @@
1111
from graphgen.utils import logger
1212

1313

14-
def partition_kg(
14+
async def partition_kg(
1515
kg_instance: BaseGraphStorage,
1616
partition_config: dict = None,
17-
) -> List[Community]:
17+
) -> list[
18+
tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
19+
]:
1820
method = partition_config["method"]
1921
method_params = partition_config["method_params"]
2022
if method == "bfs":
@@ -32,6 +34,7 @@ def partition_kg(
3234
else:
3335
raise ValueError(f"Unsupported partition method: {method}")
3436

35-
communities = partitioner.partition(g=kg_instance, **method_params)
36-
logger.info(f"Partitioned the graph into {len(communities)} communities.")
37-
return communities
37+
communities = await partitioner.partition(g=kg_instance, **method_params)
38+
logger.info("Partitioned the graph into %d communities.", len(communities))
39+
batches = await partitioner.community2batch(communities, g=kg_instance)
40+
return batches

0 commit comments

Comments (0)