@@ -68,7 +68,8 @@ def __init__(
6868 self .working_dir , namespace = "graph"
6969 )
7070 self .search_storage : JsonKVStorage = JsonKVStorage (
71- self .working_dir , namespace = "search"
71+ os .path .join (self .working_dir , "data" , "graphgen" , f"{ self .unique_id } " ),
72+ namespace = "search" ,
7273 )
7374 self .rephrase_storage : JsonKVStorage = JsonKVStorage (
7475 self .working_dir , namespace = "rephrase"
@@ -206,15 +207,23 @@ async def build_kg(self, inputs: List):
206207 # Step 3: store the new entities and relations
207208 await self .graph_storage .index_done_callback ()
208209
209- @op ("search" , deps = ["read" ], op_type = OpType .STREAMING )
210+ @op ("search" , deps = ["read" ], op_type = OpType .BATCH , batch_size = 64 )
210211 @async_to_sync_method
211- async def search (self , search_config : Dict , input_stream : Iterator ):
212+ async def search (self , search_config : Dict , inputs : List ):
213+ """
214+ search new documents from full_docs_storage
215+ input_stream: document IDs from full_docs_storage
216+ return: None
217+ """
212218 logger .info ("[Search] %s ..." , ", " .join (search_config ["data_sources" ]))
213219
214- seeds = await self .meta_storage .get_new_data (self .full_docs_storage )
215- if len (seeds ) == 0 :
216- logger .warning ("All documents are already been searched" )
217- return
220+ # Step 1: get documents
221+ seeds = {}
222+ for doc_id in inputs :
223+ doc = await self .full_docs_storage .get_by_id (doc_id )
224+ if doc :
225+ seeds [doc_id ] = doc
226+
218227 search_results = await search_all (
219228 seed_data = seeds ,
220229 search_config = search_config ,
@@ -223,16 +232,15 @@ async def search(self, search_config: Dict, input_stream: Iterator):
223232 _add_search_keys = await self .search_storage .filter_keys (
224233 list (search_results .keys ())
225234 )
235+
226236 search_results = {
227237 k : v for k , v in search_results .items () if k in _add_search_keys
228238 }
229239 if len (search_results ) == 0 :
230- logger .warning ("All search results are already in the storage" )
231- return
240+ logger .warning ("[Search] No new search results to add to storage" )
241+
232242 await self .search_storage .upsert (search_results )
233243 await self .search_storage .index_done_callback ()
234- await self .meta_storage .mark_done (self .full_docs_storage )
235- await self .meta_storage .index_done_callback ()
236244
237245 @op ("quiz_and_judge" , deps = ["build_kg" ], op_type = OpType .BARRIER )
238246 @async_to_sync_method
@@ -276,6 +284,7 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
276284 @op ("partition" , deps = ["build_kg" ], op_type = OpType .BARRIER )
277285 @async_to_sync_method
278286 async def partition (self , partition_config : Dict ):
287+ # TODO: partition 可以yield batches
279288 batches = await partition_kg (
280289 self .graph_storage ,
281290 self .chunks_storage ,
@@ -308,10 +317,10 @@ async def extract(self, extract_config: Dict, input_stream: Iterator):
308317 await self .meta_storage .mark_done (self .chunks_storage )
309318 await self .meta_storage .index_done_callback ()
310319
311- @op ("generate" , deps = ["partition" ], op_type = OpType .BARRIER )
320+ @op ("generate" , deps = ["partition" ], op_type = OpType .STREAMING )
312321 @async_to_sync_method
313- async def generate (self , generate_config : Dict , inputs : None ):
314-
322+ async def generate (self , generate_config : Dict , input_stream : Iterator ):
323+ # TODO:
315324 batches = self .partition_storage .data
316325 if not batches :
317326 logger .warning ("No partitions found for QA generation" )
0 commit comments