|
5 | 5 | from typing import Dict, cast |
6 | 6 |
|
7 | 7 | import gradio as gr |
8 | | -from jieba.lac_small.predict import results |
9 | 8 |
|
10 | 9 | from graphgen.bases.base_storage import StorageNameSpace |
11 | 10 | from graphgen.bases.datatypes import Chunk |
|
19 | 18 | from graphgen.operators import ( |
20 | 19 | build_kg, |
21 | 20 | chunk_documents, |
| 21 | + generate_qas, |
22 | 22 | judge_statement, |
23 | 23 | partition_kg, |
24 | 24 | quiz, |
25 | 25 | read_files, |
26 | 26 | search_all, |
27 | 27 | ) |
28 | | -from graphgen.utils import ( |
29 | | - async_to_sync_method, |
30 | | - compute_content_hash, |
31 | | - format_generation_results, |
32 | | - logger, |
33 | | -) |
| 28 | +from graphgen.utils import async_to_sync_method, compute_content_hash, logger |
34 | 29 |
|
35 | 30 | sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
36 | 31 |
|
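The consolidated import in the hunk above keeps async_to_sync_method, which is applied to the async entry points (see clear() in the hunk below). A minimal sketch of such an adapter, assuming graphgen's version behaves like the common asyncio.run-based decorator; the real implementation is not shown in this diff:

import asyncio
from functools import wraps

def async_to_sync_method(func):
    # Wrap an async method so callers can invoke it synchronously.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # asyncio.run drives the coroutine on a fresh event loop;
        # graphgen's actual decorator may reuse an existing loop instead.
        return asyncio.run(func(self, *args, **kwargs))
    return wrapper
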
@@ -239,52 +234,17 @@ async def generate(self, partition_config: Dict, generate_config: Dict): |
239 | 234 | batches = await partition_kg(self.graph_storage, partition_config) |
240 | 235 |
|
241 | 236 | # Step 2: generate QA pairs |
242 | | - mode = generate_config["mode"] |
243 | | - logger.info("[Generation] mode: %s, batches: %d", mode, len(batches)) |
244 | | - # results = generate_qa_pairs(generate_config) |
245 | | - # if mode == "atomic": |
246 | | - # results = await traverse_graph_for_atomic( |
247 | | - # self.synthesizer_llm_client, |
248 | | - # self.tokenizer_instance, |
249 | | - # self.graph_storage, |
250 | | - # partition_config["method_params"], |
251 | | - # self.text_chunks_storage, |
252 | | - # self.progress_bar, |
253 | | - # ) |
254 | | - # elif mode == "multi_hop": |
255 | | - # results = await traverse_graph_for_multi_hop( |
256 | | - # self.synthesizer_llm_client, |
257 | | - # self.tokenizer_instance, |
258 | | - # self.graph_storage, |
259 | | - # partition_config["method_params"], |
260 | | - # self.text_chunks_storage, |
261 | | - # self.progress_bar, |
262 | | - # ) |
263 | | - # elif mode == "aggregated": |
264 | | - # results = await traverse_graph_for_aggregated( |
265 | | - # self.synthesizer_llm_client, |
266 | | - # self.tokenizer_instance, |
267 | | - # self.graph_storage, |
268 | | - # partition_config["method_params"], |
269 | | - # self.text_chunks_storage, |
270 | | - # self.progress_bar, |
271 | | - # ) |
272 | | - # elif mode == "cot": |
273 | | - # results = await generate_cot( |
274 | | - # self.graph_storage, |
275 | | - # self.synthesizer_llm_client, |
276 | | - # method_params=partition_config["method_params"], |
277 | | - # ) |
278 | | - # else: |
279 | | - # raise ValueError(f"Unknown generation mode: {mode}") |
280 | | - |
281 | | - # Step 3: format |
282 | | - # results = format_generation_results( |
283 | | - # results, output_data_format=generate_config["data_format"] |
284 | | - # ) |
285 | | - # |
286 | | - # await self.qa_storage.upsert(results) |
287 | | - # await self.qa_storage.index_done_callback() |
| 237 | + results = await generate_qas( |
| 238 | + self.synthesizer_llm_client, batches, generate_config |
| 239 | + ) |
| 240 | + |
| 241 | + if not results: |
| 242 | + logger.warning("No QA pairs generated") |
| 243 | + return |
| 244 | + |
| 245 | + # Step 3: store the generated QA pairs |
| 246 | + await self.qa_storage.upsert(results) |
| 247 | + await self.qa_storage.index_done_callback() |
288 | 248 |
|
289 | 249 | @async_to_sync_method |
290 | 250 | async def clear(self): |
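
The new Step 2 collapses the commented-out per-mode traversal into a single generate_qas operator. Only the call site signature (llm_client, batches, generate_config) is visible in this diff; below is a minimal sketch of a mode-dispatching implementation, assuming the four modes named in the removed comments (atomic, multi_hop, aggregated, cot) and a hypothetical llm_client.generate_answer method:

from typing import Any, Dict, List

async def generate_qas(
    llm_client: Any, batches: List[Any], generate_config: Dict
) -> List[Dict]:
    mode = generate_config["mode"]
    if mode not in ("atomic", "multi_hop", "aggregated", "cot"):
        raise ValueError(f"Unknown generation mode: {mode}")
    results: List[Dict] = []
    for batch in batches:
        # One synthesis call per partitioned batch; prompt construction
        # would be mode-specific in any real implementation.
        answer = await llm_client.generate_answer(  # assumed client API
            f"[{mode}] Generate QA pairs for: {batch}"
        )
        results.append({"mode": mode, "qa": answer})
    return results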
|
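Step 3 writes through the qa_storage interface (a StorageNameSpace, per the import in the first hunk). Only upsert() and index_done_callback() appear at the call site; the in-memory stand-in below illustrates that contract, with the record shape and flush behaviour being assumptions:

from typing import Dict, List

class InMemoryQAStorage:
    def __init__(self) -> None:
        self._data: List[Dict] = []

    async def upsert(self, results: List[Dict]) -> None:
        # Append newly generated QA pairs; a real backend would
        # deduplicate and persist them.
        self._data.extend(results)

    async def index_done_callback(self) -> None:
        # Hook for flushing indexes once a write batch completes.
        print(f"indexed {len(self._data)} QA records")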