11import asyncio
2- import gradio as gr
32
43from tqdm .asyncio import tqdm as tqdm_async
54
@@ -53,6 +52,7 @@ async def handle_node(node: dict) -> dict:
5352
5453async def _construct_rephrasing_prompt (_process_nodes : list ,
5554 _process_edges : list ,
55+ _difficulty : str ,
5656 text_chunks_storage : JsonKVStorage ,
5757 add_context : bool = False
5858 ) -> str :
@@ -76,15 +76,15 @@ async def _construct_rephrasing_prompt(_process_nodes: list,
7676 original_text = await text_chunks_storage .get_by_ids (original_ids )
7777 original_text = "\n " .join ([f"{ index + 1 } . { text ['content' ]} " for index , text in enumerate (original_text )])
7878
79- prompt = ANSWER_REPHRASING_PROMPT [language ]['CONTEXT_TEMPLATE' ].format (
79+ prompt = ANSWER_REPHRASING_PROMPT [_difficulty ][ language ]['CONTEXT_TEMPLATE' ].format (
8080 language = language ,
8181 original_text = original_text ,
8282 entities = entities_str ,
8383 relationships = relations_str
8484 )
8585 return prompt
8686
87- prompt = ANSWER_REPHRASING_PROMPT [language ]['TEMPLATE' ].format (
87+ prompt = ANSWER_REPHRASING_PROMPT [_difficulty ][ language ]['TEMPLATE' ].format (
8888 language = language ,
8989 entities = entities_str ,
9090 relationships = relations_str
@@ -98,6 +98,34 @@ def get_loss_tercile(losses: list) -> (float, float):
9898
9999 return losses [q1_index ], losses [q2_index ]
100100
def assign_difficulty(subgraphs: list, difficulty_order: list, loss_strategy: str) -> list:
    """
    Assign a difficulty label to each subgraph based on its average loss.

    Losses are split at the two tercile boundaries returned by
    ``get_loss_tercile``: a subgraph below the first boundary gets
    ``difficulty_order[0]`` (easy), below the second boundary
    ``difficulty_order[1]`` (medium), and everything else
    ``difficulty_order[2]`` (hard).

    :param subgraphs: list of (nodes, edges) tuples
    :param difficulty_order: three labels ordered easy -> medium -> hard
    :param loss_strategy: strategy name forwarded to ``get_average_loss``
    :return: the same list, each entry extended to (nodes, edges, difficulty)
    """
    # Compute each subgraph's loss exactly once and reuse it in the
    # labelling loop; the original recomputed get_average_loss per item.
    losses = [get_average_loss(subgraph, loss_strategy) for subgraph in subgraphs]
    q1, q2 = get_loss_tercile(losses)

    for i, loss in enumerate(losses):
        if loss < q1:
            difficulty = difficulty_order[0]  # easy
        elif loss < q2:
            difficulty = difficulty_order[1]  # medium
        else:
            difficulty = difficulty_order[2]  # hard
        subgraphs[i] = (subgraphs[i][0], subgraphs[i][1], difficulty)
    return subgraphs
128+
101129def get_average_loss (batch : tuple , loss_strategy : str ) -> float :
102130 if loss_strategy == "only_edge" :
103131 return sum (edge [2 ]['loss' ] for edge in batch [1 ]) / len (batch [1 ])
@@ -139,7 +167,6 @@ async def traverse_graph_by_edge(
139167 graph_storage : NetworkXStorage ,
140168 traverse_strategy : TraverseStrategy ,
141169 text_chunks_storage : JsonKVStorage ,
142- progress_bar : gr .Progress = None ,
143170 max_concurrent : int = 1000
144171) -> dict :
145172 """
@@ -150,7 +177,6 @@ async def traverse_graph_by_edge(
150177 :param graph_storage
151178 :param traverse_strategy
152179 :param text_chunks_storage
153- :param progress_bar
154180 :param max_concurrent
155181 :return: question and answer
156182 """
@@ -160,10 +186,12 @@ async def traverse_graph_by_edge(
160186 async def _process_nodes_and_edges (
161187 _process_nodes : list ,
162188 _process_edges : list ,
189+ _difficulty : str ,
163190 ) -> str :
164191 prompt = await _construct_rephrasing_prompt (
165192 _process_nodes ,
166193 _process_edges ,
194+ _difficulty ,
167195 text_chunks_storage ,
168196 add_context = False
169197 )
@@ -185,48 +213,68 @@ async def _process_single_batch(
185213 context = await _process_nodes_and_edges (
186214 _process_batch [0 ],
187215 _process_batch [1 ],
216+ _process_batch [2 ]
188217 )
218+ # 一般第一行就是Question
219+ # 后面的都是Answer
220+ question = context .split ("\n " )[0 ]
221+ for prefix in ["Question:" , "问题:" , "问题:" ]:
222+ if question .startswith (prefix ):
223+ question = question [len (prefix ):].strip ()
224+ break
225+ answer = "\n " .join (context .split ("\n " )[1 :]).strip ()
226+ for prefix in ["Answer:" , "答案:" ,"答案:" , "回答:" , "回答:" ]:
227+ if answer .startswith (prefix ):
228+ answer = answer [len (prefix ):].strip ()
229+ break
230+ qas = [
231+ {
232+ "question" : question ,
233+ "answer" : answer
234+ }
235+ ]
189236
190237 language = "Chinese" if detect_main_language (context ) == "zh" else "English"
191238 pre_length = sum (node ['length' ] for node in _process_batch [0 ]) \
192239 + sum (edge [2 ]['length' ] for edge in _process_batch [1 ])
193240
194- if question_type == "single" :
195- question = await llm_client .generate_answer (
196- QUESTION_GENERATION_PROMPT [language ]['SINGLE_TEMPLATE' ].format (
197- answer = context
198- )
199- )
200- if question .startswith ("Question:" ):
201- question = question [len ("Question:" ):].strip ()
202- elif question .startswith ("问题:" ):
203- question = question [len ("问题:" ):].strip ()
204-
205- logger .info ("%d nodes and %d edges processed" , len (_process_batch [0 ]), len (_process_batch [1 ]))
206- logger .info ("Pre-length: %s" , pre_length )
207- logger .info ("Question: %s" , question )
208- logger .info ("Answer: %s" , context )
209-
210- return {
211- compute_content_hash (context ): {
212- "question" : question ,
213- "answer" : context ,
214- "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy )
215- }
216- }
217-
218- content = await llm_client .generate_answer (
219- QUESTION_GENERATION_PROMPT [language ]['MULTI_TEMPLATE' ].format (
220- doc = context
221- )
222- )
223- qas = _post_process_synthetic_data (content )
224-
225- if len (qas ) == 0 :
226- print (content )
227- logger .error ("Error occurred while processing batch, question or answer is None" )
228- return {}
229-
241+ # if question_type == "single":
242+ # question = await llm_client.generate_answer(
243+ # QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
244+ # answer=context
245+ # )
246+ # )
247+ # if question.startswith("Question:"):
248+ # question = question[len("Question:"):].strip()
249+ # elif question.startswith("问题:"):
250+ # question = question[len("问题:"):].strip()
251+ #
252+ # logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
253+ # logger.info("Pre-length: %s", pre_length)
254+ # logger.info("Question: %s", question)
255+ # logger.info("Answer: %s", context)
256+ #
257+ # return {
258+ # compute_content_hash(context): {
259+ # "question": question,
260+ # "answer": context,
261+ # "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
262+ # "difficulty": _process_batch[2],
263+ # }
264+ # }
265+ #
266+ # content = await llm_client.generate_answer(
267+ # QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
268+ # doc=context
269+ # )
270+ # )
271+ # qas = _post_process_synthetic_data(content)
272+ #
273+ # if len(qas) == 0:
274+ # print(content)
275+ # logger.error("Error occurred while processing batch, question or answer is None")
276+ # return {}
277+ #
230278 final_results = {}
231279 logger .info ("%d nodes and %d edges processed" , len (_process_batch [0 ]), len (_process_batch [1 ]))
232280 logger .info ("Pre-length: %s" , pre_length )
@@ -236,7 +284,8 @@ async def _process_single_batch(
236284 final_results [compute_content_hash (qa ['question' ])] = {
237285 "question" : qa ['question' ],
238286 "answer" : qa ['answer' ],
239- "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy )
287+ "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy ),
288+ "difficulty" : _process_batch [2 ],
240289 }
241290 return final_results
242291
@@ -253,17 +302,16 @@ async def _process_single_batch(
253302 traverse_strategy
254303 )
255304
305+ processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
306+ traverse_strategy .loss_strategy )
307+
256308 for result in tqdm_async (asyncio .as_completed (
257309 [_process_single_batch (batch ) for batch in processing_batches ]
258- ), total = len (processing_batches ), desc = "[4/4]Generating QAs " ):
310+ ), total = len (processing_batches ), desc = "Processing batches " ):
259311 try :
260- if progress_bar is not None :
261- progress_bar (len (results ) / len (processing_batches ), desc = "[4/4]Generating QAs" )
262312 results .update (await result )
263- if progress_bar is not None and len (results ) == len (processing_batches ):
264- progress_bar (1 , desc = "[4/4]Generating QAs" )
265313 except Exception as e : # pylint: disable=broad-except
266- logger .error ("Error occurred while generating QA : %s" , e )
314+ logger .error ("Error occurred while processing batches : %s" , e )
267315
268316 return results
269317
@@ -274,7 +322,6 @@ async def traverse_graph_atomically(
274322 graph_storage : NetworkXStorage ,
275323 traverse_strategy : TraverseStrategy ,
276324 text_chunks_storage : JsonKVStorage ,
277- progress_bar : gr .Progress = None ,
278325 max_concurrent : int = 1000
279326) -> dict :
280327 """
@@ -285,7 +332,6 @@ async def traverse_graph_atomically(
285332 :param graph_storage
286333 :param traverse_strategy
287334 :param text_chunks_storage
288- :param progress_bar
289335 :param max_concurrent
290336 :return: question and answer
291337 """
@@ -330,7 +376,8 @@ async def _generate_question(
330376 compute_content_hash (question ): {
331377 "question" : question ,
332378 "answer" : answer ,
333- "loss" : loss
379+ "loss" : loss ,
380+ "difficulty" : "medium"
334381 }
335382 }
336383 except Exception as e : # pylint: disable=broad-except
@@ -362,16 +409,12 @@ async def _generate_question(
362409 for result in tqdm_async (
363410 asyncio .as_completed ([_generate_question (task ) for task in tasks ]),
364411 total = len (tasks ),
365- desc = "[4/4] Generating QAs "
412+ desc = "Generating questions "
366413 ):
367414 try :
368- if progress_bar is not None :
369- progress_bar (len (results ) / len (tasks ), desc = "[4/4]Generating QAs" )
370415 results .update (await result )
371- if progress_bar is not None and len (results ) == len (tasks ):
372- progress_bar (1 , desc = "[4/4]Generating QAs" )
373416 except Exception as e : # pylint: disable=broad-except
374- logger .error ("Error occurred while generating QA : %s" , e )
417+ logger .error ("Error occurred while generating questions : %s" , e )
375418 return results
376419
377420async def traverse_graph_for_multi_hop (
@@ -380,7 +423,6 @@ async def traverse_graph_for_multi_hop(
380423 graph_storage : NetworkXStorage ,
381424 traverse_strategy : TraverseStrategy ,
382425 text_chunks_storage : JsonKVStorage ,
383- progress_bar : gr .Progress = None ,
384426 max_concurrent : int = 1000
385427) -> dict :
386428 """
@@ -391,7 +433,6 @@ async def traverse_graph_for_multi_hop(
391433 :param graph_storage
392434 :param traverse_strategy
393435 :param text_chunks_storage
394- :param progress_bar
395436 :param max_concurrent
396437 :return: question and answer
397438 """
@@ -412,6 +453,9 @@ async def traverse_graph_for_multi_hop(
412453 traverse_strategy
413454 )
414455
456+ processing_batches = assign_difficulty (processing_batches , traverse_strategy .difficulty_order ,
457+ traverse_strategy .loss_strategy )
458+
415459 async def _process_single_batch (
416460 _process_batch : tuple
417461 ) -> dict :
@@ -462,24 +506,21 @@ async def _process_single_batch(
462506 "question" : question ,
463507 "answer" : answer ,
464508 "loss" : get_average_loss (_process_batch , traverse_strategy .loss_strategy ),
509+ "difficulty" : _process_batch [2 ],
465510 }
466511 }
467512
468513 except Exception as e : # pylint: disable=broad-except
469514 logger .error ("Error occurred while processing batch: %s" , e )
470515 return {}
471516
472- async for result in tqdm_async (
517+ for result in tqdm_async (
473518 asyncio .as_completed ([_process_single_batch (batch ) for batch in processing_batches ]),
474519 total = len (processing_batches ),
475- desc = "[4/4]Generating QAs "
520+ desc = "Processing batches "
476521 ):
477522 try :
478- if progress_bar is not None :
479- progress_bar (len (results ) / len (processing_batches ), desc = "[4/4]Generating QAs" )
480523 results .update (await result )
481- if progress_bar is not None and len (results ) == len (processing_batches ):
482- progress_bar (1 , desc = "[4/4]Generating QAs" )
483524 except Exception as e : # pylint: disable=broad-except
484- logger .error ("Error occurred while generating QA : %s" , e )
525+ logger .error ("Error occurred while processing batches : %s" , e )
485526 return results
0 commit comments