Skip to content

Commit 8d71fdc

Browse files
feat: add CoTGenerator
1 parent a132627 commit 8d71fdc

12 files changed

Lines changed: 309 additions & 308 deletions

File tree

graphgen/bases/base_partitioner.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ async def partition(
2121
:return: List of communities
2222
"""
2323

24-
@abstractmethod
25-
def split_communities(self, communities: List[Community]) -> List[Community]:
26-
"""
27-
Split large communities into smaller ones based on max_size.
28-
:param communities
29-
:return:
30-
"""
31-
3224
@staticmethod
3325
async def community2batch(
3426
communities: List[Community], g: BaseGraphStorage

graphgen/configs/cot_config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ partition: # graph partition configuration
1212
method: leiden # leiden is a partitioner detection algorithm
1313
method_params:
1414
max_size: 20 # Maximum size of communities
15-
use_lcc: false
16-
random_seed: 42
15+
use_lcc: false # whether to use the largest connected component
16+
random_seed: 42 # random seed for partitioning
1717
generate:
1818
mode: cot # atomic, aggregated, multi_hop, cot
1919
data_format: Sharegpt # Alpaca, Sharegpt, ChatML
Lines changed: 117 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,122 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
14
from graphgen.bases import BaseGenerator
5+
from graphgen.templates import COT_GENERATION_PROMPT
6+
from graphgen.utils import compute_content_hash, detect_main_language, logger
27

38

9+
@dataclass
class CoTGenerator(BaseGenerator):
    """
    Generate Chain-of-Thought (CoT) QA pairs from a graph community batch.

    Two LLM round-trips per batch:
      1. design a question and a reasoning-path template from the batch,
      2. generate the final CoT answer following that template.
    """

    @staticmethod
    def _format_batch(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> tuple[str, str]:
        """
        Render the batch's nodes and edges as numbered description lists.

        :param batch: (nodes, edges) where each node is (name, data) and each
            edge is (src, dst, data); data dicts must carry a 'description' key.
        :return: (entities_str, relationships_str)
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relationships_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        return entities_str, relationships_str

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the prompt asking the LLM to design a CoT question and
        reasoning-path template for this batch.

        :param batch: (nodes, edges) community batch
        :return: formatted prompt string
        """
        entities_str, relationships_str = CoTGenerator._format_batch(batch)
        # Prompt language follows the dominant language of the batch content.
        language = detect_main_language(entities_str + relationships_str)
        return COT_GENERATION_PROMPT[language]["COT_TEMPLATE_DESIGN"].format(
            entities=entities_str, relationships=relationships_str
        )

    @staticmethod
    def build_prompt_for_cot_generation(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]],
        question: str,
        reasoning_path: str,
    ) -> str:
        """
        Build the prompt asking the LLM to produce the final CoT answer.

        :param batch: (nodes, edges) community batch
        :param question: question produced by the template-design step
        :param reasoning_path: reasoning template produced by the design step
        :return: formatted prompt string
        """
        entities_str, relationships_str = CoTGenerator._format_batch(batch)
        language = detect_main_language(entities_str + relationships_str)
        return COT_GENERATION_PROMPT[language]["COT_GENERATION"].format(
            entities=entities_str,
            relationships=relationships_str,
            question=question,
            reasoning_template=reasoning_path,
        )

    @staticmethod
    def parse_response(response: str) -> dict:
        """
        Extract the question and reasoning path from the template-design
        response. Supports both English and Chinese section markers.

        :param response: raw LLM response
        :return: {'question': ..., 'reasoning_path': ...}, or {} when the
            response does not contain the expected markers.
        """
        if "Question:" in response and "Reasoning-Path Design:" in response:
            question = (
                response.split("Question:")[1]
                .split("Reasoning-Path Design:")[0]
                .strip()
            )
            reasoning_path = response.split("Reasoning-Path Design:")[1].strip()
        elif "问题:" in response and "推理路径设计:" in response:
            question = response.split("问题:")[1].split("推理路径设计:")[0].strip()
            reasoning_path = response.split("推理路径设计:")[1].strip()
        else:
            logger.warning("Failed to parse CoT template: %s", response)
            return {}

        # The model sometimes wraps sections in literal quotes; strip them.
        question = question.strip('"')
        reasoning_path = reasoning_path.strip('"')
        logger.info("CoT Question: %s", question)
        logger.info("CoT Reasoning Path: %s", reasoning_path)
        return {
            "question": question,
            "reasoning_path": reasoning_path,
        }

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Generate CoT QA pairs for a given batch.

        :param batch: (nodes, edges) community batch
        :return: mapping of content hash -> {question, answer, reasoning_path};
            empty dict when the template-design response could not be parsed.
        """
        prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(prompt)
        parsed = self.parse_response(response)
        if not parsed:
            # Template design failed to parse; skip this batch instead of
            # raising KeyError on the missing 'question' key.
            return {}
        question, reasoning_path = parsed["question"], parsed["reasoning_path"]
        prompt = self.build_prompt_for_cot_generation(batch, question, reasoning_path)
        cot_answer = await self.llm_client.generate_answer(prompt)
        logger.info("CoT Answer: %s", cot_answer)
        return {
            compute_content_hash(question): {
                "question": question,
                "answer": cot_answer,
                "reasoning_path": reasoning_path,
            }
        }

graphgen/models/generator/multi_hop_generator.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,6 @@ def build_prompt(
3434

3535
@staticmethod
3636
def parse_response(response: str) -> dict:
37-
"""
38-
AtomicGenerator normally generates one QA pair per response.
39-
So we just need to parse one QA pair from the response.
40-
:param response:
41-
:return:
42-
"""
4337
if "Question:" in response and "Answer:" in response:
4438
question = response.split("Question:")[1].split("Answer:")[0].strip()
4539
answer = response.split("Answer:")[1].strip()

graphgen/models/partitioner/bfs_partitioner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,3 @@ async def partition(
7676
)
7777

7878
return communities
79-
80-
def split_communities(self, communities: List[Community]) -> List[Community]:
81-
raise NotImplementedError("BFSPartitioner does not need to split communities.")

graphgen/models/partitioner/dfs_partitioner.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,3 @@ async def partition(
7373
)
7474

7575
return communities
76-
77-
def split_communities(self, communities: List[Community]) -> List[Community]:
78-
raise NotImplementedError("DFSPartitioner does not need to split communities.")
Lines changed: 77 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,95 +1,120 @@
11
from collections import defaultdict
22
from dataclasses import dataclass
3-
from typing import Any, Dict, List
3+
from typing import Any, Dict, List, Set, Tuple
44

5-
from graphgen.models.storage.networkx_storage import NetworkXStorage
5+
import igraph as ig
6+
from leidenalg import ModularityVertexPartition, find_partition
67

8+
from graphgen.bases import BaseGraphStorage, BasePartitioner
9+
from graphgen.bases.datatypes import Community
710

8-
@dataclass
9-
class LeidenPartitioner:
10-
"""Class for partitioner detection algorithms."""
1111

12-
graph_storage: NetworkXStorage = None
13-
method: str = "leiden"
14-
method_params: Dict[str, Any] = None
12+
@dataclass
class LeidenPartitioner(BasePartitioner):
    """
    Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
    """

    async def partition(
        self,
        g: BaseGraphStorage,
        max_size: int = 20,
        use_lcc: bool = False,
        random_seed: int = 42,
        **kwargs: Any,
    ) -> List[Community]:
        """
        Leiden Partition follows these steps:
        1. export the graph from graph storage
        2. use the leiden algorithm to detect communities, get {node: community_id}
        3. split large communities if max_size is given
        4. convert {node: community_id} to List[Community]

        :param g: graph storage to partition
        :param max_size: maximum size of each community; if None or <=0, no limit
        :param use_lcc: whether to use the largest connected component only
        :param random_seed: seed for the Leiden algorithm
        :param kwargs: extra parameters, accepted for interface compatibility (unused)
        :return: list of detected communities
        """
        nodes = await g.get_all_nodes()  # List[Tuple[str, dict]]
        edges = await g.get_all_edges()  # List[Tuple[str, str, dict]]

        node2cid: Dict[str, int] = await self._run_leiden(
            nodes, edges, use_lcc, random_seed
        )

        if max_size is not None and max_size > 0:
            node2cid = await self._split_communities(node2cid, max_size)

        # Invert the mapping: community id -> member nodes.
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for n, cid in node2cid.items():
            cid2nodes[cid].append(n)

        communities: List[Community] = []
        # NOTE: loop variable renamed from `nodes` to `comm_nodes` to avoid
        # shadowing the node list exported from storage above.
        for cid, comm_nodes in cid2nodes.items():
            node_set: Set[str] = set(comm_nodes)
            # Keep only edges whose both endpoints fall inside this community.
            comm_edges: List[Tuple[str, str]] = [
                (u, v) for u, v, _ in edges if u in node_set and v in node_set
            ]
            communities.append(Community(id=cid, nodes=comm_nodes, edges=comm_edges))
        return communities

    @staticmethod
    async def _run_leiden(
        nodes: List[Tuple[str, dict]],
        edges: List[Tuple[str, str, dict]],
        use_lcc: bool = False,
        random_seed: int = 42,
    ) -> Dict[str, int]:
        """
        Run the Leiden algorithm and return {node_name: community_id}.

        The igraph graph is built from edges only, so nodes without any edge
        never enter the partition (`nodes` is accepted for interface symmetry
        but not consulted here).
        """
        # build igraph
        ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)

        # remove isolated nodes
        ig_graph.delete_vertices(ig_graph.vs.select(_degree_eq=0))

        node2cid: Dict[str, int] = {}
        if use_lcc:
            # Partition only the largest connected component.
            lcc = ig_graph.components().giant()
            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
            for part_id, cluster in enumerate(partition):
                for v in cluster:
                    node2cid[lcc.vs[v]["name"]] = part_id
        else:
            # Partition each connected component independently, offsetting
            # community ids so they stay globally unique.
            offset = 0
            for component in ig_graph.components():
                subgraph = ig_graph.induced_subgraph(component)
                partition = find_partition(
                    subgraph, ModularityVertexPartition, seed=random_seed
                )
                for part_id, cluster in enumerate(partition):
                    for v in cluster:
                        original_node = subgraph.vs[v]["name"]
                        node2cid[original_node] = part_id + offset
                offset += len(partition)
        return node2cid

    @staticmethod
    async def _split_communities(
        node2cid: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.

        Community ids are renumbered densely from 0; node order within each
        community follows the insertion order of `node2cid`.
        """
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for n, cid in node2cid.items():
            cid2nodes[cid].append(n)

        new_mapping: Dict[str, int] = {}
        new_cid = 0
        for members in cid2nodes.values():
            if len(members) <= max_size:
                for n in members:
                    new_mapping[n] = new_cid
                new_cid += 1
            else:
                # Chunk oversized communities into consecutive max_size slices.
                for start in range(0, len(members), max_size):
                    chunk = members[start : start + max_size]
                    for n in chunk:
                        new_mapping[n] = new_cid
                    new_cid += 1
        return new_mapping

0 commit comments

Comments
 (0)