Skip to content

Commit 55667d7

Browse files
feat: add AggregatedGenerator
1 parent f072c2e commit 55667d7

9 files changed

Lines changed: 331 additions & 455 deletions

File tree

graphgen/bases/base_generator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ class BaseGenerator(ABC):
1313

1414
llm_client: BaseLLMClient
1515

16+
@staticmethod
1617
@abstractmethod
1718
def build_prompt(
18-
self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
19+
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
1920
) -> str:
2021
"""Build prompt for LLM based on the given batch"""
2122

23+
@staticmethod
2224
@abstractmethod
23-
def parse_response(self, response: str) -> Any:
25+
def parse_response(response: str) -> Any:
2426
"""Parse the LLM response and return the generated QAs"""
2527

2628
async def generate(

graphgen/configs/aggregated_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ quiz_and_judge: # quiz and test whether the LLM masters the knowledge points
1313
partition: # graph partition configuration
1414
method: ece # ece is a custom partition method based on comprehension loss
1515
method_params:
16-
max_units_per_community: 10 # max nodes and edges per community
16+
max_units_per_community: 20 # max nodes and edges per community
1717
max_tokens_per_community: 10240 # max tokens per community
1818
unit_sampling: max_loss # edge sampling strategy, support: random, max_loss, min_loss
1919
generate:
Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,127 @@
1+
from dataclasses import dataclass
2+
from typing import Any
3+
14
from graphgen.bases import BaseGenerator
5+
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
6+
from graphgen.utils import compute_content_hash, detect_main_language, logger
27

38

@dataclass
class AggregatedGenerator(BaseGenerator):
    """
    Two-step aggregated QA generator.

    Step 1 (rephrase): rewrite the batch's nodes and edges into one coherent
    passage that keeps the original meaning; that passage becomes the answer.
    Step 2 (question generation): derive a question for the rephrased passage.
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the REPHRASE prompt for a batch.

        :param batch: pair of (nodes, edges) taken from the graph; each node is
            (name, data) and each edge is (src, dst, data), where data carries
            a "description" entry
        :return: prompt string for the answer-rephrasing step
        """
        nodes, edges = batch

        entity_lines = []
        for idx, (name, data) in enumerate(nodes, start=1):
            entity_lines.append(f"{idx}. {name}: {data['description']}")
        entities_str = "\n".join(entity_lines)

        relation_lines = []
        for idx, (src, dst, data) in enumerate(edges, start=1):
            relation_lines.append(f"{idx}. {src} -- {dst}: {data['description']}")
        relations_str = "\n".join(relation_lines)

        # Prompt language follows the dominant language of the batch content.
        language = detect_main_language(entities_str + relations_str)

        # TODO: configure add_context — optionally prepend the original source
        # chunks (looked up via each unit's first source_id) to the prompt.
        return AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            language=language, entities=entities_str, relationships=relations_str
        )

    @staticmethod
    def parse_rephrased_text(response: str) -> str:
        """
        Extract the rephrased passage from an LLM response.

        :param response: raw LLM output
        :return: rephrased text, stripped of surrounding double quotes
        """
        for marker in ("Rephrased Text:", "重述文本:"):
            if marker in response:
                return response.split(marker)[1].strip().strip('"')
        # No marker present: treat the whole response as the passage.
        return response.strip().strip('"')

    @staticmethod
    def _build_prompt_for_question_generation(answer: str) -> str:
        """
        Build the QUESTION GENERATION prompt for a rephrased answer.

        :param answer: rephrased passage produced by step 1
        :return: prompt string for the question-generation step
        """
        lang = detect_main_language(answer)
        template = AGGREGATED_GENERATION_PROMPT[lang]["QUESTION_GENERATION"]
        return template.format(answer=answer)

    @staticmethod
    def parse_response(response: str) -> dict:
        """
        Extract the generated question from an LLM response.

        :param response: raw LLM output
        :return: dict with a single "question" entry
        """
        question = response.strip()
        for prefix in ("Question:", "问题:"):
            if response.startswith(prefix):
                question = response[len(prefix):].strip()
                break
        return {
            "question": question,
        }

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> dict[str, Any]:
        """
        Run the two-step pipeline for one batch.

        :param batch: pair of (nodes, edges) from the graph
        :return: QA pairs keyed by the question's content hash, each value a
            dict with "question" and "answer" entries
        """
        # Step 1: rephrase the batch into a coherent passage (the answer).
        rephrase_response = await self.llm_client.generate_answer(
            self.build_prompt(batch)
        )
        answer = self.parse_rephrased_text(rephrase_response)

        # Step 2: generate a question for the rephrased passage.
        question_response = await self.llm_client.generate_answer(
            self._build_prompt_for_question_generation(answer)
        )
        question = self.parse_response(question_response)["question"]

        logger.info("Question: %s", question)
        logger.info("Answer: %s", answer)

        return {
            compute_content_hash(question): {
                "question": question,
                "answer": answer,
            }
        }

graphgen/models/generator/atomic_generator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
from dataclasses import dataclass
12
from typing import Any
23

3-
from graphgen.utils import compute_content_hash
44
from graphgen.bases import BaseGenerator
55
from graphgen.templates import ATOMIC_GENERATION_PROMPT
6-
from graphgen.utils import detect_main_language, logger
6+
from graphgen.utils import compute_content_hash, detect_main_language, logger
77

88

9+
@dataclass
910
class AtomicGenerator(BaseGenerator):
11+
@staticmethod
1012
def build_prompt(
11-
self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
13+
batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
1214
) -> str:
1315
nodes, edges = batch
1416
context = ""
@@ -21,7 +23,8 @@ def build_prompt(
2123
prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context)
2224
return prompt
2325

24-
def parse_response(self, response: str) -> dict:
26+
@staticmethod
27+
def parse_response(response: str) -> dict:
2528
"""
2629
AtomicGenerator normally generates one QA pair per response.
2730
So we just need to parse one QA pair from the response.

0 commit comments

Comments
 (0)