
Commit b4431a2

feat: add AtomicGenerator
1 parent 69e0d6a commit b4431a2

19 files changed

Lines changed: 282 additions & 131 deletions

graphgen/bases/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+from .base_generator import BaseGenerator
 from .base_kg_builder import BaseKGBuilder
 from .base_llm_client import BaseLLMClient
 from .base_partitioner import BasePartitioner

graphgen/bases/base_generator.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Any
+
+from graphgen.bases.base_llm_client import BaseLLMClient
+
+
+@dataclass
+class BaseGenerator(ABC):
+    """
+    Generate QAs based on given prompts.
+    """
+
+    llm_client: BaseLLMClient
+
+    @abstractmethod
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        """Build prompt for LLM based on the given batch"""
+
+    @abstractmethod
+    def parse_response(self, response: str) -> Any:
+        """Parse the LLM response and return the generated QAs"""
+
+    async def generate(
+        self,
+        batch: tuple[
+            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
+        ],
+    ) -> dict[str, Any]:
+        """
+        Generate QAs based on a given batch.
+        :param batch
+        :return: QA pairs
+        """
+        result = {}
+        prompt = self.build_prompt(batch)
+        response = await self.llm_client.generate_answer(prompt)
+        qa_pairs = self.parse_response(response)  # generate one or more QA pairs
+        result.update(qa_pairs)
+        return result
+
+    @staticmethod
+    def format_generation_results(
+        results: list[dict], output_data_format: str
+    ) -> list[dict[str, Any]]:
+        if output_data_format == "Alpaca":
+            results = [
+                {
+                    "instruction": v["question"],
+                    "input": "",
+                    "output": v["answer"],
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "Sharegpt":
+            results = [
+                {
+                    "conversations": [
+                        {"from": "human", "value": v["question"]},
+                        {"from": "gpt", "value": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        elif output_data_format == "ChatML":
+            results = [
+                {
+                    "messages": [
+                        {"role": "user", "content": v["question"]},
+                        {"role": "assistant", "content": v["answer"]},
+                    ]
+                }
+                for item in results
+                for k, v in item.items()
+            ]
+        else:
+            raise ValueError(f"Unknown output data format: {output_data_format}")
+        return results
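The BaseGenerator contract is easiest to see with a toy subclass. The sketch below is illustrative only: EchoGenerator and the sample results list are hypothetical, and nothing here beyond the BaseGenerator interface itself comes from this commit.

from graphgen.bases import BaseGenerator

# Hypothetical concrete generator; the real ones are added under graphgen/models below.
class EchoGenerator(BaseGenerator):
    def build_prompt(self, batch) -> str:
        nodes, _edges = batch
        return f"Ask one question about: {', '.join(name for name, _ in nodes)}"

    def parse_response(self, response: str) -> dict:
        # Key by any stable id; AtomicGenerator (below) hashes the question text.
        return {"qa-0": {"question": response, "answer": "stub answer"}}

# format_generation_results is a staticmethod and can be exercised directly:
results = [{"qa-0": {"question": "What is X?", "answer": "X is ..."}}]
print(BaseGenerator.format_generation_results(results, output_data_format="Alpaca"))
# [{'instruction': 'What is X?', 'input': '', 'output': 'X is ...'}]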

graphgen/bases/base_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, List, Tuple
+from typing import Any, List

 from graphgen.bases.base_storage import BaseGraphStorage
 from graphgen.bases.datatypes import Community

graphgen/graphgen.py

Lines changed: 13 additions & 53 deletions
@@ -5,7 +5,6 @@
 from typing import Dict, cast

 import gradio as gr
-from jieba.lac_small.predict import results

 from graphgen.bases.base_storage import StorageNameSpace
 from graphgen.bases.datatypes import Chunk
@@ -19,18 +18,14 @@
 from graphgen.operators import (
     build_kg,
     chunk_documents,
+    generate_qas,
     judge_statement,
     partition_kg,
     quiz,
     read_files,
     search_all,
 )
-from graphgen.utils import (
-    async_to_sync_method,
-    compute_content_hash,
-    format_generation_results,
-    logger,
-)
+from graphgen.utils import async_to_sync_method, compute_content_hash, logger

 sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

@@ -239,52 +234,17 @@ async def generate(self, partition_config: Dict, generate_config: Dict):
         batches = await partition_kg(self.graph_storage, partition_config)

         # Step 2: generate QA pairs
-        mode = generate_config["mode"]
-        logger.info("[Generation] mode: %s, batches: %d", mode, len(batches))
-        # results = generate_qa_pairs(generate_config)
-        # if mode == "atomic":
-        #     results = await traverse_graph_for_atomic(
-        #         self.synthesizer_llm_client,
-        #         self.tokenizer_instance,
-        #         self.graph_storage,
-        #         partition_config["method_params"],
-        #         self.text_chunks_storage,
-        #         self.progress_bar,
-        #     )
-        # elif mode == "multi_hop":
-        #     results = await traverse_graph_for_multi_hop(
-        #         self.synthesizer_llm_client,
-        #         self.tokenizer_instance,
-        #         self.graph_storage,
-        #         partition_config["method_params"],
-        #         self.text_chunks_storage,
-        #         self.progress_bar,
-        #     )
-        # elif mode == "aggregated":
-        #     results = await traverse_graph_for_aggregated(
-        #         self.synthesizer_llm_client,
-        #         self.tokenizer_instance,
-        #         self.graph_storage,
-        #         partition_config["method_params"],
-        #         self.text_chunks_storage,
-        #         self.progress_bar,
-        #     )
-        # elif mode == "cot":
-        #     results = await generate_cot(
-        #         self.graph_storage,
-        #         self.synthesizer_llm_client,
-        #         method_params=partition_config["method_params"],
-        #     )
-        # else:
-        #     raise ValueError(f"Unknown generation mode: {mode}")
-
-        # Step 3: format
-        # results = format_generation_results(
-        #     results, output_data_format=generate_config["data_format"]
-        # )
-        #
-        # await self.qa_storage.upsert(results)
-        # await self.qa_storage.index_done_callback()
+        results = await generate_qas(
+            self.synthesizer_llm_client, batches, generate_config
+        )
+
+        if not results:
+            logger.warning("No QA pairs generated")
+            return
+
+        # Step 3: store the generated QA pairs
+        await self.qa_storage.upsert(results)
+        await self.qa_storage.index_done_callback()

     @async_to_sync_method
     async def clear(self):
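The shape of generate_config is left implicit here; the removed code read generate_config["mode"] and generate_config["data_format"], so at least those keys exist. A hedged sketch of a config for the new step, with assumed values:

# Hypothetical generate_config. Only "mode" and "data_format" are evidenced by
# this diff (via the removed code); the candidate values are inferred from the
# new generator classes and the format_generation_results branches.
generate_config = {
    "mode": "atomic",         # or multi_hop / aggregated / cot
    "data_format": "Alpaca",  # or Sharegpt / ChatML
}

With this change, GraphGen.generate reduces to partition_kg, then generate_qas, then qa_storage.upsert: the per-mode dispatch that was commented out here moves down into generate_qas and the generator classes.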

graphgen/models/__init__.py

Lines changed: 7 additions & 2 deletions
@@ -1,4 +1,10 @@
 from .evaluator import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator
+from .generator import (
+    AggregatedGenerator,
+    AtomicGenerator,
+    CoTGenerator,
+    MultiHopGenerator,
+)
 from .kg_builder import LightRAGKGBuilder
 from .llm.openai_client import OpenAIClient
 from .llm.topk_token_model import TopkTokenModel
@@ -14,6 +20,5 @@
 from .search.web.bing_search import BingSearch
 from .search.web.google_search import GoogleSearch
 from .splitter import ChineseRecursiveTextSplitter, RecursiveCharacterSplitter
-from .storage.json_storage import JsonKVStorage, JsonListStorage
-from .storage.networkx_storage import NetworkXStorage
+from .storage import JsonKVStorage, JsonListStorage, NetworkXStorage
 from .tokenizer import Tokenizer
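With these re-exports, callers no longer need to know the storage or generator submodule layout; assuming the package is importable, the following works from the package root:

from graphgen.models import AtomicGenerator, JsonKVStorage, NetworkXStorage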
graphgen/models/generator/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+from .aggregated_generator import AggregatedGenerator
+from .atomic_generator import AtomicGenerator
+from .cot_generator import CoTGenerator
+from .multi_hop_generator import MultiHopGenerator
graphgen/models/generator/aggregated_generator.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class AggregatedGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
graphgen/models/generator/atomic_generator.py

Lines changed: 49 additions & 0 deletions

@@ -0,0 +1,49 @@
+from typing import Any
+
+from baselines.EntiGraph.tasks.baseline_task import compute_content_hash
+from graphgen.bases import BaseGenerator
+from graphgen.templates import ATOMIC_GENERATION_PROMPT
+from graphgen.utils import detect_main_language, logger
+
+
+class AtomicGenerator(BaseGenerator):
+    def build_prompt(
+        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
+    ) -> str:
+        nodes, edges = batch
+        context = ""
+        for node in nodes:
+            context += f"- {node[0]}: {node[1]['description']}\n"
+        for edge in edges:
+            context += f"- {edge[0]} - {edge[1]}: {edge[2]['description']}\n"
+        language = detect_main_language(context)
+
+        prompt = ATOMIC_GENERATION_PROMPT[language].format(context=context)
+        return prompt
+
+    def parse_response(self, response: str) -> dict:
+        """
+        AtomicGenerator normally generates one QA pair per response.
+        So we just need to parse one QA pair from the response.
+        :param response:
+        :return:
+        """
+        if "Question:" in response and "Answer:" in response:
+            question = response.split("Question:")[1].split("Answer:")[0].strip()
+            answer = response.split("Answer:")[1].strip()
+        elif "问题:" in response and "答案:" in response:
+            question = response.split("问题:")[1].split("答案:")[0].strip()
+            answer = response.split("答案:")[1].strip()
+        else:
+            logger.warning("Failed to parse response: %s", response)
+            return None, None
+        question = question.strip('"')
+        answer = answer.strip('"')
+        logger.info("Question: %s", question)
+        logger.info("Answer: %s", answer)
+        return {
+            compute_content_hash(question): {
+                "question": question,
+                "answer": answer,
+            }
+        }
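AtomicGenerator.parse_response expects the model to emit literal "Question:"/"Answer:" markers (or the Chinese 问题:/答案: pair). The sketch below re-creates the English branch standalone, using hashlib.md5 as a stand-in for compute_content_hash (whose hashing scheme is not shown in this diff). Unlike the committed code, which returns (None, None) on a parse failure and would then crash BaseGenerator.generate at result.update(qa_pairs), it returns an empty dict.

from hashlib import md5

def parse_atomic_response(response: str) -> dict:
    # Mirrors AtomicGenerator.parse_response, English markers only.
    if "Question:" not in response or "Answer:" not in response:
        return {}  # the committed code returns (None, None) here
    question = response.split("Question:")[1].split("Answer:")[0].strip().strip('"')
    answer = response.split("Answer:")[1].strip().strip('"')
    # Stand-in for compute_content_hash; the real helper may hash differently.
    key = md5(question.encode("utf-8")).hexdigest()
    return {key: {"question": question, "answer": answer}}

sample = "Question: What does AtomicGenerator emit? Answer: One QA pair per response."
print(parse_atomic_response(sample))
# {'<md5 of question>': {'question': 'What does AtomicGenerator emit?',
#                        'answer': 'One QA pair per response.'}}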
graphgen/models/generator/cot_generator.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class CoTGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
graphgen/models/generator/multi_hop_generator.py

Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+from graphgen.bases import BaseGenerator
+
+
+class MultiHopGenerator(BaseGenerator):
+    def build_prompt(self, batch) -> str:
+        pass
+
+    def parse_response(self, response: str):
+        pass
