Skip to content

Commit f9dc25f

Browse files
feat: add DFSPartitioner & BFSPartitioner
1 parent c98b9b1 commit f9dc25f

34 files changed

Lines changed: 486 additions & 128 deletions

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .base_kg_builder import BaseKGBuilder
22
from .base_llm_client import BaseLLMClient
3+
from .base_partitioner import BasePartitioner
34
from .base_reader import BaseReader
45
from .base_splitter import BaseSplitter
56
from .base_storage import (

graphgen/bases/base_partitioner.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from abc import ABC, abstractmethod
2+
from dataclasses import dataclass
3+
from typing import Any, List
4+
5+
from graphgen.bases.base_storage import BaseGraphStorage
6+
from graphgen.bases.datatypes import Community
7+
8+
9+
@dataclass
class BasePartitioner(ABC):
    """Abstract base class for graph partitioners: Graph -> Communities."""

    @abstractmethod
    async def partition(
        self,
        g: "BaseGraphStorage",
        **kwargs: Any,
    ) -> List["Community"]:
        """
        Partition the graph into communities.

        :param g: Graph storage instance.
        :param kwargs: Additional parameters for partitioning.
        :return: List of communities.
        """

    @abstractmethod
    def split_communities(self, communities: List["Community"]) -> List["Community"]:
        """
        Split large communities into smaller ones based on max_size.

        :param communities: Communities to split.
        :return: Communities resized per the implementation's max_size policy.
        """

    @staticmethod
    def _build_adjacency_list(
        nodes: List[tuple[str, dict]], edges: List[tuple[str, str, dict]]
    ) -> tuple[dict[str, List[str]], set[tuple[str, str]]]:
        """
        Build an undirected adjacency list and edge set from nodes and edges.

        A repeated edge — including the reverse orientation of an edge already
        seen — is skipped, so a neighbour is never recorded twice for the same
        undirected edge. Assumes every edge endpoint appears in ``nodes``
        (a missing endpoint raises ``KeyError``).

        :param nodes: ``(node_id, node_data)`` tuples.
        :param edges: ``(source_id, target_id, edge_data)`` tuples.
        :return: adjacency list mapping node_id -> neighbour ids, and a set
                 holding each undirected edge in both orientations.
        """
        adj: dict[str, List[str]] = {n[0]: [] for n in nodes}
        edge_set: set[tuple[str, str]] = set()
        for src, dst, *_ in edges:
            # Skip duplicates (either orientation) so adjacency lists do not
            # accumulate the same neighbour multiple times per edge.
            if (src, dst) in edge_set:
                continue
            adj[src].append(dst)
            adj[dst].append(src)
            edge_set.add((src, dst))
            edge_set.add((dst, src))
        return adj, edge_set

graphgen/bases/base_storage.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async def get_node(self, node_id: str) -> Union[dict, None]:
7878
async def update_node(self, node_id: str, node_data: dict[str, str]):
7979
raise NotImplementedError
8080

81-
async def get_all_nodes(self) -> Union[list[dict], None]:
81+
async def get_all_nodes(self) -> Union[list[tuple[str, dict]], None]:
8282
raise NotImplementedError
8383

8484
async def get_edge(
@@ -91,7 +91,7 @@ async def update_edge(
9191
):
9292
raise NotImplementedError
9393

94-
async def get_all_edges(self) -> Union[list[dict], None]:
94+
async def get_all_edges(self) -> Union[list[tuple[str, str, dict]], None]:
9595
raise NotImplementedError
9696

9797
async def get_node_edges(

graphgen/bases/datatypes.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,11 @@ class Token:
3030
@property
3131
def logprob(self) -> float:
3232
return math.log(self.prob)
33+
34+
35+
@dataclass
class Community:
    """A group of graph nodes and edges produced by a graph partitioner."""

    # Community identifier; int or string depending on the partitioner.
    id: Union[int, str]
    # Identifiers of the nodes belonging to this community.
    nodes: List[str] = field(default_factory=list)
    # Edges inside the community — presumably (src, dst[, data]) tuples;
    # TODO confirm against the partitioner implementations.
    edges: List[tuple] = field(default_factory=list)
    # Arbitrary extra information attached by the partitioner.
    metadata: dict = field(default_factory=dict)

graphgen/configs/atomic_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ partition: # graph partition configuration
1717
edge_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
1818
expand_method: max_width # expand method, support: max_width, max_depth
1919
isolated_node_strategy: ignore # strategy for isolated nodes, support: ignore, add
20-
max_depth: 3 # maximum depth for graph traversal
21-
max_extra_edges: 5 # max edges per direction (if expand_method="max_width")
20+
max_depth: 1 # maximum depth for graph traversal
21+
max_extra_edges: 0 # max edges per direction (if expand_method="max_width")
2222
max_tokens: 256 # restricts input length (if expand_method="max_tokens")
2323
loss_strategy: only_edge # defines loss computation focus, support: only_edge, both
2424
generate:

graphgen/configs/cot_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ search: # web search configuration
99
quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1010
enabled: false
1111
partition: # graph partition configuration
12-
method: leiden # leiden is a community detection algorithm
12+
method: leiden # leiden is a community detection algorithm used for graph partitioning
1313
method_params:
1414
max_size: 20 # Maximum size of communities
1515
use_lcc: false

graphgen/evaluate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .utils import logger, set_logger
1414

1515
sys_path = os.path.abspath(os.path.dirname(__file__))
16-
set_logger(os.path.join(sys_path, "cache", "logs", "evaluate.log"))
16+
set_logger(os.path.join(sys_path, "cache", "logs", "evaluator.log"))
1717

1818
load_dotenv()
1919

graphgen/graphgen.py

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,11 @@
1818
from graphgen.operators import (
1919
build_kg,
2020
chunk_documents,
21-
generate_cot,
2221
judge_statement,
22+
partition_kg,
2323
quiz,
2424
read_files,
2525
search_all,
26-
traverse_graph_for_aggregated,
27-
traverse_graph_for_atomic,
28-
traverse_graph_for_multi_hop,
2926
)
3027
from graphgen.utils import (
3128
async_to_sync_method,
@@ -237,54 +234,55 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
237234

238235
@async_to_sync_method
239236
async def generate(self, partition_config: Dict, generate_config: Dict):
237+
pass
240238
# Step 1: partition the graph
241-
# TODO: implement graph partitioning, e.g. Partitioner().partition(self.graph_storage)
242-
mode = generate_config["mode"]
243-
if mode == "atomic":
244-
results = await traverse_graph_for_atomic(
245-
self.synthesizer_llm_client,
246-
self.tokenizer_instance,
247-
self.graph_storage,
248-
partition_config["method_params"],
249-
self.text_chunks_storage,
250-
self.progress_bar,
251-
)
252-
elif mode == "multi_hop":
253-
results = await traverse_graph_for_multi_hop(
254-
self.synthesizer_llm_client,
255-
self.tokenizer_instance,
256-
self.graph_storage,
257-
partition_config["method_params"],
258-
self.text_chunks_storage,
259-
self.progress_bar,
260-
)
261-
elif mode == "aggregated":
262-
results = await traverse_graph_for_aggregated(
263-
self.synthesizer_llm_client,
264-
self.tokenizer_instance,
265-
self.graph_storage,
266-
partition_config["method_params"],
267-
self.text_chunks_storage,
268-
self.progress_bar,
269-
)
270-
elif mode == "cot":
271-
results = await generate_cot(
272-
self.graph_storage,
273-
self.synthesizer_llm_client,
274-
method_params=partition_config["method_params"],
275-
)
276-
else:
277-
raise ValueError(f"Unknown generation mode: {mode}")
239+
# mode = generate_config["mode"]
240+
# batches = partition_kg(self.graph_storage, partition_config)
241+
# if mode == "atomic":
242+
# results = await traverse_graph_for_atomic(
243+
# self.synthesizer_llm_client,
244+
# self.tokenizer_instance,
245+
# self.graph_storage,
246+
# partition_config["method_params"],
247+
# self.text_chunks_storage,
248+
# self.progress_bar,
249+
# )
250+
# elif mode == "multi_hop":
251+
# results = await traverse_graph_for_multi_hop(
252+
# self.synthesizer_llm_client,
253+
# self.tokenizer_instance,
254+
# self.graph_storage,
255+
# partition_config["method_params"],
256+
# self.text_chunks_storage,
257+
# self.progress_bar,
258+
# )
259+
# elif mode == "aggregated":
260+
# results = await traverse_graph_for_aggregated(
261+
# self.synthesizer_llm_client,
262+
# self.tokenizer_instance,
263+
# self.graph_storage,
264+
# partition_config["method_params"],
265+
# self.text_chunks_storage,
266+
# self.progress_bar,
267+
# )
268+
# elif mode == "cot":
269+
# results = await generate_cot(
270+
# self.graph_storage,
271+
# self.synthesizer_llm_client,
272+
# method_params=partition_config["method_params"],
273+
# )
274+
# else:
275+
# raise ValueError(f"Unknown generation mode: {mode}")
278276
# Step 2: generate QA pairs
279277
# TODO
280278

281279
# Step 3: format
282-
results = format_generation_results(
283-
results, output_data_format=generate_config["data_format"]
284-
)
285-
286-
await self.qa_storage.upsert(results)
287-
await self.qa_storage.index_done_callback()
280+
# results = format_generation_results(
281+
# results, output_data_format=generate_config["data_format"]
282+
# )
283+
#
284+
# await self.qa_storage.upsert(results)
285+
# await self.qa_storage.index_done_callback()
288286

289287
@async_to_sync_method
290288
async def clear(self):

graphgen/models/__init__.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
from .community.community_detector import CommunityDetector
2-
from .evaluate.length_evaluator import LengthEvaluator
3-
from .evaluate.mtld_evaluator import MTLDEvaluator
4-
from .evaluate.reward_evaluator import RewardEvaluator
5-
from .evaluate.uni_evaluator import UniEvaluator
6-
from .kg_builder.light_rag_kg_builder import LightRAGKGBuilder
1+
from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
2+
from .kg_builder import LightRAGKGBuilder
73
from .llm.openai_client import OpenAIClient
84
from .llm.topk_token_model import TopkTokenModel
5+
from .partitioner import (
6+
BFSPartitioner,
7+
DFSPartitioner,
8+
ECEPartitioner,
9+
LeidenPartitioner,
10+
)
911
from .reader import CsvReader, JsonlReader, JsonReader, TxtReader
1012
from .search.db.uniprot_search import UniProtSearch
1113
from .search.kg.wiki_search import WikiSearch

graphgen/models/community/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)