22import os
33import time
44from dataclasses import dataclass , field
5- from typing import Dict , List , Union , cast
5+ from typing import Dict , cast
66
77import gradio as gr
88from tqdm .asyncio import tqdm as tqdm_async
1616 OpenAIModel ,
1717 Tokenizer ,
1818 TraverseStrategy ,
19- read_file ,
20- split_chunks ,
2119)
22-
23- from .operators import (
20+ from graphgen .operators import (
2421 extract_kg ,
2522 generate_cot ,
2623 judge_statement ,
2724 quiz ,
25+ read_files ,
2826 search_all ,
27+ split_chunks ,
2928 traverse_graph_for_aggregated ,
3029 traverse_graph_for_atomic ,
3130 traverse_graph_for_multi_hop ,
3231)
33- from .utils import (
32+ from graphgen .utils import (
33+ async_to_sync_method ,
3434 compute_content_hash ,
35- create_event_loop ,
3635 detect_main_language ,
3736 format_generation_results ,
3837 logger ,
@@ -106,15 +105,25 @@ def __post_init__(self):
106105 namespace = f"qa-{ self .unique_id } " ,
107106 )
108107
109- async def async_split_chunks (self , data : List [Union [List , Dict ]]) -> dict :
110- # TODO: configurable whether to use coreference resolution
108+ @async_to_sync_method
109+ async def insert (self ):
110+ """
111+ insert chunks into the graph
112+ """
113+
114+ input_file = self .config ["read" ]["input_file" ]
115+
116+ # Step 1: Read files
117+ data = read_files (input_file )
111118 if len (data ) == 0 :
112- return {}
119+ logger .warning ("No data to process" )
120+ return
121+
122+ # TODO: configurable whether to use coreference resolution
113123
124+ # Step 2: Split chunks and filter existing ones
114125 inserting_chunks = {}
115126 assert isinstance (data , list ) and isinstance (data [0 ], dict )
116-
117- # compute hash for each document
118127 new_docs = {
119128 compute_content_hash (doc ["content" ], prefix = "doc-" ): {
120129 "content" : doc ["content" ]
@@ -123,9 +132,10 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
123132 }
124133 _add_doc_keys = await self .full_docs_storage .filter_keys (list (new_docs .keys ()))
125134 new_docs = {k : v for k , v in new_docs .items () if k in _add_doc_keys }
135+
126136 if len (new_docs ) == 0 :
127137 logger .warning ("All docs are already in the storage" )
128- return {}
138+ return
129139 logger .info ("[New Docs] inserting %d docs" , len (new_docs ))
130140
131141 cur_index = 1
@@ -162,29 +172,16 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict:
162172 inserting_chunks = {
163173 k : v for k , v in inserting_chunks .items () if k in _add_chunk_keys
164174 }
165- await self .full_docs_storage .upsert (new_docs )
166- await self .text_chunks_storage .upsert (inserting_chunks )
167-
168- return inserting_chunks
169-
170- def insert (self ):
171- loop = create_event_loop ()
172- loop .run_until_complete (self .async_insert ())
173-
174- async def async_insert (self ):
175- """
176- insert chunks into the graph
177- """
178-
179- input_file = self .config ["read" ]["input_file" ]
180- data = read_file (input_file )
181- inserting_chunks = await self .async_split_chunks (data )
182175
183176 if len (inserting_chunks ) == 0 :
184177 logger .warning ("All chunks are already in the storage" )
185178 return
179+
186180 logger .info ("[New Chunks] inserting %d chunks" , len (inserting_chunks ))
181+ await self .full_docs_storage .upsert (new_docs )
182+ await self .text_chunks_storage .upsert (inserting_chunks )
187183
184+ # Step 3: Extract entities and relations from chunks
188185 logger .info ("[Entity and Relation Extraction]..." )
189186 _add_entities_and_relations = await extract_kg (
190187 llm_client = self .synthesizer_llm_client ,
@@ -214,11 +211,8 @@ async def _insert_done(self):
214211 tasks .append (cast (StorageNameSpace , storage_instance ).index_done_callback ())
215212 await asyncio .gather (* tasks )
216213
217- def search (self ):
218- loop = create_event_loop ()
219- loop .run_until_complete (self .async_search ())
220-
221- async def async_search (self ):
214+ @async_to_sync_method
215+ async def search (self ):
222216 logger .info (
223217 "Search is %s" , "enabled" if self .search_config ["enabled" ] else "disabled"
224218 )
@@ -254,11 +248,8 @@ async def async_search(self):
254248 # TODO: fix insert after search
255249 await self .async_insert ()
256250
257- def quiz (self ):
258- loop = create_event_loop ()
259- loop .run_until_complete (self .async_quiz ())
260-
261- async def async_quiz (self ):
251+ @async_to_sync_method
252+ async def quiz (self ):
262253 max_samples = self .config ["quiz_and_judge_strategy" ]["quiz_samples" ]
263254 await quiz (
264255 self .synthesizer_llm_client ,
@@ -268,11 +259,8 @@ async def async_quiz(self):
268259 )
269260 await self .rephrase_storage .index_done_callback ()
270261
271- def judge (self ):
272- loop = create_event_loop ()
273- loop .run_until_complete (self .async_judge ())
274-
275- async def async_judge (self ):
262+ @async_to_sync_method
263+ async def judge (self ):
276264 re_judge = self .config ["quiz_and_judge_strategy" ]["re_judge" ]
277265 _update_relations = await judge_statement (
278266 self .trainee_llm_client ,
@@ -282,11 +270,8 @@ async def async_judge(self):
282270 )
283271 await _update_relations .index_done_callback ()
284272
285- def traverse (self ):
286- loop = create_event_loop ()
287- loop .run_until_complete (self .async_traverse ())
288-
289- async def async_traverse (self ):
273+ @async_to_sync_method
274+ async def traverse (self ):
290275 output_data_type = self .config ["output_data_type" ]
291276
292277 if output_data_type == "atomic" :
@@ -326,11 +311,12 @@ async def async_traverse(self):
326311 await self .qa_storage .upsert (results )
327312 await self .qa_storage .index_done_callback ()
328313
329- def generate_reasoning (self , method_params ):
330- loop = create_event_loop ()
331- loop .run_until_complete (self .async_generate_reasoning (method_params ))
314+ # def generate_reasoning(self, method_params):
315+ # loop = create_event_loop()
316+ # loop.run_until_complete(self.async_generate_reasoning(method_params))
332317
333- async def async_generate_reasoning (self , method_params ):
318+ @async_to_sync_method
319+ async def generate_reasoning (self , method_params ):
334320 results = await generate_cot (
335321 self .graph_storage ,
336322 self .synthesizer_llm_client ,
@@ -344,11 +330,8 @@ async def async_generate_reasoning(self, method_params):
344330 await self .qa_storage .upsert (results )
345331 await self .qa_storage .index_done_callback ()
346332
347- def clear (self ):
348- loop = create_event_loop ()
349- loop .run_until_complete (self .async_clear ())
350-
351- async def async_clear (self ):
333+ @async_to_sync_method
334+ async def clear (self ):
352335 await self .full_docs_storage .drop ()
353336 await self .text_chunks_storage .drop ()
354337 await self .search_storage .drop ()
0 commit comments