@@ -299,117 +299,6 @@ async def _process_single_batch(
299299 return results
300300
301301
# pylint: disable=too-many-branches, too-many-statements
async def traverse_graph_for_atomic(
    llm_client: OpenAIClient,
    tokenizer: Tokenizer,
    graph_storage: NetworkXStorage,
    traverse_strategy: Dict,
    text_chunks_storage: JsonKVStorage,
    progress_bar: gr.Progress = None,
    max_concurrent: int = 1000,
) -> dict:
    """
    Traverse the graph atomically: generate one QA pair per node/edge description.

    :param llm_client: client used to generate the question/answer text
    :param tokenizer: tokenizer passed through to ``_pre_tokenize``
    :param graph_storage: storage providing ``get_all_nodes`` / ``get_all_edges``
    :param traverse_strategy: traversal configuration (unused here; kept for
        interface parity with the other ``traverse_graph_for_*`` functions)
    :param text_chunks_storage: text chunk storage (interface parity; unused here)
    :param progress_bar: optional Gradio progress bar for UI feedback
    :param max_concurrent: maximum number of concurrent LLM requests
    :return: mapping of content hash -> {"question", "answer", "loss"}
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    def _parse_qa(qa: str) -> tuple:
        """Extract (question, answer) from LLM output; (None, None) on failure."""
        # Support both the English and the Chinese prompt output formats.
        if "Question:" in qa and "Answer:" in qa:
            question = qa.split("Question:")[1].split("Answer:")[0].strip()
            answer = qa.split("Answer:")[1].strip()
        elif "问题:" in qa and "答案:" in qa:
            question = qa.split("问题:")[1].split("答案:")[0].strip()
            answer = qa.split("答案:")[1].strip()
        else:
            return None, None
        return question.strip('"'), answer.strip('"')

    async def _generate_question(node_or_edge: tuple):
        """Generate one QA dict for a (node_id, data) or (src, dst, data) tuple."""
        # A 2-tuple is (node_id, node_data); a 3-tuple is (src, dst, edge_data).
        if len(node_or_edge) == 2:
            des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
            loss = node_or_edge[1].get("loss", -1.0)
        else:
            des = node_or_edge[2]["description"]
            loss = node_or_edge[2].get("loss", -1.0)

        async with semaphore:
            try:
                language = "Chinese" if detect_main_language(des) == "zh" else "English"

                qa = await llm_client.generate_answer(
                    QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
                        doc=des
                    )
                )

                question, answer = _parse_qa(qa)
                if question is None or answer is None:
                    return {}
                # NOTE: _parse_qa already strips surrounding quotes; no need to
                # strip again here (the original double-strip was redundant).

                logger.info("Question: %s", question)
                logger.info("Answer: %s", answer)
                return {
                    compute_content_hash(question): {
                        "question": question,
                        "answer": answer,
                        "loss": loss,
                    }
                }
            except Exception as e:  # pylint: disable=broad-except
                # Best-effort: a single failed generation must not abort the run.
                logger.error("Error occurred while generating question: %s", e)
                return {}

    results = {}
    edges = list(await graph_storage.get_all_edges())
    nodes = list(await graph_storage.get_all_nodes())

    edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)

    tasks = []
    for node in nodes:
        # Split multi-description nodes (joined with "<SEP>") into one task per
        # description, propagating any node-level loss to each split task.
        if "<SEP>" in node[1]["description"]:
            for item in node[1]["description"].split("<SEP>"):
                data = {"description": item}
                if "loss" in node[1]:
                    data["loss"] = node[1]["loss"]
                tasks.append((node[0], data))
        else:
            tasks.append((node[0], node[1]))
    for edge in edges:
        # Same splitting for multi-description edges.
        if "<SEP>" in edge[2]["description"]:
            for item in edge[2]["description"].split("<SEP>"):
                data = {"description": item}
                if "loss" in edge[2]:
                    data["loss"] = edge[2]["loss"]
                tasks.append((edge[0], edge[1], data))
        else:
            tasks.append((edge[0], edge[1], edge[2]))

    results_list = await run_concurrent(
        _generate_question,
        tasks,
        progress_bar=progress_bar,
        desc="[4/4]Generating QAs",
    )
    for res in results_list:
        results.update(res)
    return results
411-
412-
413302async def traverse_graph_for_multi_hop (
414303 llm_client : OpenAIClient ,
415304 tokenizer : Tokenizer ,
0 commit comments