Skip to content

Commit b5b3404

Browse files
feat(graphgen): add prompts for cot
1 parent 1319629 commit b5b3404

2 files changed

Lines changed: 272 additions & 92 deletions

File tree

graphgen/operators/traverse_graph.py

Lines changed: 107 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import asyncio
2-
import gradio as gr
32

43
from tqdm.asyncio import tqdm as tqdm_async
54

@@ -53,6 +52,7 @@ async def handle_node(node: dict) -> dict:
5352

5453
async def _construct_rephrasing_prompt(_process_nodes: list,
5554
_process_edges: list,
55+
_difficulty: str,
5656
text_chunks_storage: JsonKVStorage,
5757
add_context: bool = False
5858
) -> str:
@@ -76,15 +76,15 @@ async def _construct_rephrasing_prompt(_process_nodes: list,
7676
original_text = await text_chunks_storage.get_by_ids(original_ids)
7777
original_text = "\n".join([f"{index + 1}. {text['content']}" for index, text in enumerate(original_text)])
7878

79-
prompt = ANSWER_REPHRASING_PROMPT[language]['CONTEXT_TEMPLATE'].format(
79+
prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['CONTEXT_TEMPLATE'].format(
8080
language=language,
8181
original_text=original_text,
8282
entities=entities_str,
8383
relationships=relations_str
8484
)
8585
return prompt
8686

87-
prompt = ANSWER_REPHRASING_PROMPT[language]['TEMPLATE'].format(
87+
prompt = ANSWER_REPHRASING_PROMPT[_difficulty][language]['TEMPLATE'].format(
8888
language=language,
8989
entities=entities_str,
9090
relationships=relations_str
@@ -98,6 +98,34 @@ def get_loss_tercile(losses: list) -> (float, float):
9898

9999
return losses[q1_index], losses[q2_index]
100100

101+
def assign_difficulty(subgraphs: list, difficulty_order: list, loss_strategy: str) -> list:
    """Assign a difficulty label to each subgraph based on its average loss.

    Subgraphs whose loss is below the first tercile get difficulty_order[0]
    (easy), below the second tercile difficulty_order[1] (medium), and the
    remainder difficulty_order[2] (hard).

    :param subgraphs: list of (nodes, edges) tuples
    :param difficulty_order: three labels ordered easy -> medium -> hard
    :param loss_strategy: strategy name forwarded to get_average_loss
    :return: the same list, with each entry replaced by
             (nodes, edges, difficulty); mutated in place and returned
    """
    # Compute each subgraph's loss exactly once and reuse it below; the
    # original recomputed get_average_loss per subgraph in the second loop.
    losses = [get_average_loss(subgraph, loss_strategy) for subgraph in subgraphs]

    # Pass a copy so that if get_loss_tercile mutates its argument
    # (e.g. sorts in place — not visible from here), our per-subgraph
    # ordering in `losses` is preserved.
    q1, q2 = get_loss_tercile(list(losses))

    for i, (subgraph, loss) in enumerate(zip(subgraphs, losses)):
        if loss < q1:
            difficulty = difficulty_order[0]  # easy
        elif loss < q2:
            difficulty = difficulty_order[1]  # medium
        else:
            difficulty = difficulty_order[2]  # hard
        subgraphs[i] = (subgraph[0], subgraph[1], difficulty)
    return subgraphs
128+
101129
def get_average_loss(batch: tuple, loss_strategy: str) -> float:
102130
if loss_strategy == "only_edge":
103131
return sum(edge[2]['loss'] for edge in batch[1]) / len(batch[1])
@@ -139,7 +167,6 @@ async def traverse_graph_by_edge(
139167
graph_storage: NetworkXStorage,
140168
traverse_strategy: TraverseStrategy,
141169
text_chunks_storage: JsonKVStorage,
142-
progress_bar: gr.Progress = None,
143170
max_concurrent: int = 1000
144171
) -> dict:
145172
"""
@@ -150,7 +177,6 @@ async def traverse_graph_by_edge(
150177
:param graph_storage
151178
:param traverse_strategy
152179
:param text_chunks_storage
153-
:param progress_bar
154180
:param max_concurrent
155181
:return: question and answer
156182
"""
@@ -160,10 +186,12 @@ async def traverse_graph_by_edge(
160186
async def _process_nodes_and_edges(
161187
_process_nodes: list,
162188
_process_edges: list,
189+
_difficulty: str,
163190
) -> str:
164191
prompt = await _construct_rephrasing_prompt(
165192
_process_nodes,
166193
_process_edges,
194+
_difficulty,
167195
text_chunks_storage,
168196
add_context = False
169197
)
@@ -185,48 +213,68 @@ async def _process_single_batch(
185213
context = await _process_nodes_and_edges(
186214
_process_batch[0],
187215
_process_batch[1],
216+
_process_batch[2]
188217
)
218+
# 一般第一行就是Question
219+
# 后面的都是Answer
220+
question = context.split("\n")[0]
221+
for prefix in ["Question:", "问题:", "问题:"]:
222+
if question.startswith(prefix):
223+
question = question[len(prefix):].strip()
224+
break
225+
answer = "\n".join(context.split("\n")[1:]).strip()
226+
for prefix in ["Answer:", "答案:","答案:", "回答:", "回答:"]:
227+
if answer.startswith(prefix):
228+
answer = answer[len(prefix):].strip()
229+
break
230+
qas = [
231+
{
232+
"question": question,
233+
"answer": answer
234+
}
235+
]
189236

190237
language = "Chinese" if detect_main_language(context) == "zh" else "English"
191238
pre_length = sum(node['length'] for node in _process_batch[0]) \
192239
+ sum(edge[2]['length'] for edge in _process_batch[1])
193240

194-
if question_type == "single":
195-
question = await llm_client.generate_answer(
196-
QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
197-
answer=context
198-
)
199-
)
200-
if question.startswith("Question:"):
201-
question = question[len("Question:"):].strip()
202-
elif question.startswith("问题:"):
203-
question = question[len("问题:"):].strip()
204-
205-
logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
206-
logger.info("Pre-length: %s", pre_length)
207-
logger.info("Question: %s", question)
208-
logger.info("Answer: %s", context)
209-
210-
return {
211-
compute_content_hash(context): {
212-
"question": question,
213-
"answer": context,
214-
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
215-
}
216-
}
217-
218-
content = await llm_client.generate_answer(
219-
QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
220-
doc=context
221-
)
222-
)
223-
qas = _post_process_synthetic_data(content)
224-
225-
if len(qas) == 0:
226-
print(content)
227-
logger.error("Error occurred while processing batch, question or answer is None")
228-
return {}
229-
241+
# if question_type == "single":
242+
# question = await llm_client.generate_answer(
243+
# QUESTION_GENERATION_PROMPT[language]['SINGLE_TEMPLATE'].format(
244+
# answer=context
245+
# )
246+
# )
247+
# if question.startswith("Question:"):
248+
# question = question[len("Question:"):].strip()
249+
# elif question.startswith("问题:"):
250+
# question = question[len("问题:"):].strip()
251+
#
252+
# logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
253+
# logger.info("Pre-length: %s", pre_length)
254+
# logger.info("Question: %s", question)
255+
# logger.info("Answer: %s", context)
256+
#
257+
# return {
258+
# compute_content_hash(context): {
259+
# "question": question,
260+
# "answer": context,
261+
# "loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
262+
# "difficulty": _process_batch[2],
263+
# }
264+
# }
265+
#
266+
# content = await llm_client.generate_answer(
267+
# QUESTION_GENERATION_PROMPT[language]['MULTI_TEMPLATE'].format(
268+
# doc=context
269+
# )
270+
# )
271+
# qas = _post_process_synthetic_data(content)
272+
#
273+
# if len(qas) == 0:
274+
# print(content)
275+
# logger.error("Error occurred while processing batch, question or answer is None")
276+
# return {}
277+
#
230278
final_results = {}
231279
logger.info("%d nodes and %d edges processed", len(_process_batch[0]), len(_process_batch[1]))
232280
logger.info("Pre-length: %s", pre_length)
@@ -236,7 +284,8 @@ async def _process_single_batch(
236284
final_results[compute_content_hash(qa['question'])] = {
237285
"question": qa['question'],
238286
"answer": qa['answer'],
239-
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy)
287+
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
288+
"difficulty": _process_batch[2],
240289
}
241290
return final_results
242291

@@ -253,17 +302,16 @@ async def _process_single_batch(
253302
traverse_strategy
254303
)
255304

305+
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
306+
traverse_strategy.loss_strategy)
307+
256308
for result in tqdm_async(asyncio.as_completed(
257309
[_process_single_batch(batch) for batch in processing_batches]
258-
), total=len(processing_batches), desc="[4/4]Generating QAs"):
310+
), total=len(processing_batches), desc="Processing batches"):
259311
try:
260-
if progress_bar is not None:
261-
progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
262312
results.update(await result)
263-
if progress_bar is not None and len(results) == len(processing_batches):
264-
progress_bar(1, desc="[4/4]Generating QAs")
265313
except Exception as e: # pylint: disable=broad-except
266-
logger.error("Error occurred while generating QA: %s", e)
314+
logger.error("Error occurred while processing batches: %s", e)
267315

268316
return results
269317

@@ -274,7 +322,6 @@ async def traverse_graph_atomically(
274322
graph_storage: NetworkXStorage,
275323
traverse_strategy: TraverseStrategy,
276324
text_chunks_storage: JsonKVStorage,
277-
progress_bar: gr.Progress = None,
278325
max_concurrent: int = 1000
279326
) -> dict:
280327
"""
@@ -285,7 +332,6 @@ async def traverse_graph_atomically(
285332
:param graph_storage
286333
:param traverse_strategy
287334
:param text_chunks_storage
288-
:param progress_bar
289335
:param max_concurrent
290336
:return: question and answer
291337
"""
@@ -330,7 +376,8 @@ async def _generate_question(
330376
compute_content_hash(question): {
331377
"question": question,
332378
"answer": answer,
333-
"loss": loss
379+
"loss": loss,
380+
"difficulty": "medium"
334381
}
335382
}
336383
except Exception as e: # pylint: disable=broad-except
@@ -362,16 +409,12 @@ async def _generate_question(
362409
for result in tqdm_async(
363410
asyncio.as_completed([_generate_question(task) for task in tasks]),
364411
total=len(tasks),
365-
desc="[4/4]Generating QAs"
412+
desc="Generating questions"
366413
):
367414
try:
368-
if progress_bar is not None:
369-
progress_bar(len(results) / len(tasks), desc="[4/4]Generating QAs")
370415
results.update(await result)
371-
if progress_bar is not None and len(results) == len(tasks):
372-
progress_bar(1, desc="[4/4]Generating QAs")
373416
except Exception as e: # pylint: disable=broad-except
374-
logger.error("Error occurred while generating QA: %s", e)
417+
logger.error("Error occurred while generating questions: %s", e)
375418
return results
376419

377420
async def traverse_graph_for_multi_hop(
@@ -380,7 +423,6 @@ async def traverse_graph_for_multi_hop(
380423
graph_storage: NetworkXStorage,
381424
traverse_strategy: TraverseStrategy,
382425
text_chunks_storage: JsonKVStorage,
383-
progress_bar: gr.Progress = None,
384426
max_concurrent: int = 1000
385427
) -> dict:
386428
"""
@@ -391,7 +433,6 @@ async def traverse_graph_for_multi_hop(
391433
:param graph_storage
392434
:param traverse_strategy
393435
:param text_chunks_storage
394-
:param progress_bar
395436
:param max_concurrent
396437
:return: question and answer
397438
"""
@@ -412,6 +453,9 @@ async def traverse_graph_for_multi_hop(
412453
traverse_strategy
413454
)
414455

456+
processing_batches = assign_difficulty(processing_batches, traverse_strategy.difficulty_order,
457+
traverse_strategy.loss_strategy)
458+
415459
async def _process_single_batch(
416460
_process_batch: tuple
417461
) -> dict:
@@ -462,24 +506,21 @@ async def _process_single_batch(
462506
"question": question,
463507
"answer": answer,
464508
"loss": get_average_loss(_process_batch, traverse_strategy.loss_strategy),
509+
"difficulty": _process_batch[2],
465510
}
466511
}
467512

468513
except Exception as e: # pylint: disable=broad-except
469514
logger.error("Error occurred while processing batch: %s", e)
470515
return {}
471516

472-
async for result in tqdm_async(
517+
for result in tqdm_async(
473518
asyncio.as_completed([_process_single_batch(batch) for batch in processing_batches]),
474519
total=len(processing_batches),
475-
desc="[4/4]Generating QAs"
520+
desc="Processing batches"
476521
):
477522
try:
478-
if progress_bar is not None:
479-
progress_bar(len(results) / len(processing_batches), desc="[4/4]Generating QAs")
480523
results.update(await result)
481-
if progress_bar is not None and len(results) == len(processing_batches):
482-
progress_bar(1, desc="[4/4]Generating QAs")
483524
except Exception as e: # pylint: disable=broad-except
484-
logger.error("Error occurred while generating QA: %s", e)
525+
logger.error("Error occurred while processing batches: %s", e)
485526
return results

0 commit comments

Comments (0)