Skip to content

Commit d849e23

Browse files
feat: add multi choice qa generation
1 parent 02dcafe commit d849e23

15 files changed

Lines changed: 371 additions & 76 deletions
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Generate Multi-Choice QAs
2+
3+
Multiple-choice question answering (QA) tasks present a question together with several answer options; the goal is to select the correct answer from the given choices.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/generate/generate_multi_choice_qa/multi_choice_config.yaml
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
global_params:
2+
working_dir: cache
3+
graph_backend: kuzu # graph database backend, support: kuzu, networkx
4+
kv_backend: rocksdb # key-value store backend, support: rocksdb, json_kv
5+
6+
nodes:
7+
- id: read_files # id is unique in the pipeline, and can be referenced by other steps
8+
op_name: read
9+
type: source
10+
dependencies: []
11+
params:
12+
input_path:
13+
- examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples
14+
15+
- id: chunk_documents
16+
op_name: chunk
17+
type: map_batch
18+
dependencies:
19+
- read_files
20+
execution_params:
21+
replicas: 4
22+
params:
23+
chunk_size: 1024 # chunk size for text splitting
24+
chunk_overlap: 100 # chunk overlap for text splitting
25+
26+
- id: build_kg
27+
op_name: build_kg
28+
type: map_batch
29+
dependencies:
30+
- chunk_documents
31+
execution_params:
32+
replicas: 1
33+
batch_size: 128
34+
35+
- id: quiz
36+
op_name: quiz
37+
type: map_batch
38+
dependencies:
39+
- build_kg
40+
execution_params:
41+
replicas: 1
42+
batch_size: 128
43+
params:
44+
quiz_samples: 2 # number of quiz samples to generate
45+
46+
- id: judge
47+
op_name: judge
48+
type: map_batch
49+
dependencies:
50+
- quiz
51+
execution_params:
52+
replicas: 1
53+
batch_size: 128
54+
55+
- id: partition
56+
op_name: partition
57+
type: aggregate
58+
dependencies:
59+
- judge
60+
params:
61+
method: ece # ece is a custom partition method based on comprehension loss
62+
method_params:
63+
max_units_per_community: 20 # max nodes and edges per community
64+
min_units_per_community: 5 # min nodes and edges per community
65+
max_tokens_per_community: 10240 # max tokens per community
66+
unit_sampling: max_loss # unit sampling strategy, support: random, max_loss, min_loss
67+
68+
- id: generate
69+
op_name: generate
70+
type: map_batch
71+
dependencies:
72+
- partition
73+
execution_params:
74+
replicas: 1
75+
batch_size: 128
76+
save_output: true # save output
77+
params:
78+
method: multi_choice
79+
num_of_questions: 5
80+
data_format: Alpaca # Alpaca, Sharegpt, ChatML

graphgen/bases/base_generator.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,38 +46,47 @@ async def generate(
4646
def format_generation_results(
    results: list[dict], output_data_format: str
) -> list[dict[str, Any]]:
    """Flatten generated QA pairs into a training-data format.

    :param results: list of dicts mapping question hashes to QA data; each
        QA dict has "question" and "answer" keys and, for multiple-choice
        items, an "options" dict such as {"A": "...", "B": "..."}.
    :param output_data_format: one of "Alpaca", "Sharegpt", "ChatML".
    :return: flat list of records in the requested format.
    :raises ValueError: if output_data_format is not one of the supported
        formats (raised on the first QA pair encountered).
    """
    flat_results: list[dict[str, Any]] = []
    for item in results:
        for _, qa_data in item.items():
            question = qa_data.get("question", "")
            answer = qa_data.get("answer", "")
            # For multiple-choice QAs, append the options (sorted by their
            # letter key) to the question text so the record is self-contained.
            if "options" in qa_data and qa_data["options"]:
                options = qa_data["options"]
                options_str = "\n".join(
                    [f"{key}. {options[key]}" for key in sorted(options.keys())]
                )
                question += f"\nOptions:\n{options_str}"

            if output_data_format == "Alpaca":
                flat_results.append(
                    {
                        "instruction": question,
                        "input": "",
                        "output": answer,
                    }
                )
            elif output_data_format == "Sharegpt":
                flat_results.append(
                    {
                        "conversations": [
                            {"from": "human", "value": question},
                            {"from": "gpt", "value": answer},
                        ]
                    }
                )
            elif output_data_format == "ChatML":
                # BUG FIX: this branch previously appended to `results` (the
                # input list being iterated) instead of `flat_results`, which
                # dropped all ChatML output from the return value and mutated
                # the input during iteration.
                flat_results.append(
                    {
                        "messages": [
                            {"role": "user", "content": question},
                            {"role": "assistant", "content": answer},
                        ]
                    }
                )
            else:
                raise ValueError(
                    f"Unknown output data format: {output_data_format}"
                )
    return flat_results

graphgen/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AggregatedGenerator,
1212
AtomicGenerator,
1313
CoTGenerator,
14+
MultiChoiceGenerator,
1415
MultiHopGenerator,
1516
QuizGenerator,
1617
VQAGenerator,
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from .aggregated_generator import AggregatedGenerator
22
from .atomic_generator import AtomicGenerator
33
from .cot_generator import CoTGenerator
4+
from .multi_choice_generator import MultiChoiceGenerator
45
from .multi_hop_generator import MultiHopGenerator
56
from .quiz_generator import QuizGenerator
67
from .vqa_generator import VQAGenerator
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import re
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import MCQ_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


class MultiChoiceGenerator(BaseGenerator):
    """Generate multiple-choice QA pairs from a knowledge-graph community.

    Builds an LLM prompt out of the community's entity and relationship
    descriptions, then parses the XML-tagged LLM response into
    question / options / answer records keyed by question hash.
    """

    def __init__(self, llm_client, num_of_questions) -> None:
        super().__init__(llm_client)
        # Number of questions the prompt asks the LLM to produce per batch.
        self.num_of_questions = num_of_questions

    @staticmethod
    def parse_response(response: str) -> Any:
        """
        Parse multiple choice QA pairs from the LLM response.
        Each QA pair contains question text, four options, and the correct answer.

        :param response: The LLM response containing XML-formatted QA pairs
        :return: Dictionary mapping question hash to question data, where each
                 value is a dict with "question", "options", "answer", and
                 "correct_answer_text" keys
        """
        qa_pairs = {}

        # Extract all QA pair blocks
        qa_blocks = re.findall(r"<qa_pair>(.*?)</qa_pair>", response, re.DOTALL)

        if not qa_blocks:
            logger.warning("No QA pairs found in response: %s", response)
            return {}

        # Malformed blocks are skipped with a warning rather than aborting the
        # whole parse, so one bad QA pair does not discard the rest.
        for block in qa_blocks:
            # Extract and clean question text (strip stray surrounding quotes)
            q_match = re.search(r"<question>(.*?)</question>", block, re.DOTALL)
            if not q_match:
                logger.warning("Failed to parse question from block: %s", block)
                continue
            question = q_match.group(1).strip().strip('"').strip("'")

            # Extract and parse options (A, B, C, D)
            opt_match = re.search(r"<options>(.*?)</options>", block, re.DOTALL)
            if not opt_match:
                logger.warning("Failed to parse options from block: %s", block)
                continue

            options = {}
            options_text = opt_match.group(1).strip()
            for line in options_text.split("\n"):
                line = line.strip()
                if not line:
                    continue
                # Match patterns like "A. text" or "A text" (letter then
                # a dot or whitespace separator)
                if m := re.match(r"^([A-D])[.\s]\s*(.*)$", line):
                    letter, text = m.groups()
                    options[letter] = text.strip()

            # Validate options count: exactly the four letters A-D are required
            if len(options) != 4:
                logger.warning(
                    "Expected 4 options, found %d: %s", len(options), options_text
                )
                continue

            # Extract and validate answer
            ans_match = re.search(r"<answer>(.*?)</answer>", block, re.DOTALL)
            if not ans_match:
                logger.warning("Failed to parse answer from block: %s", block)
                continue
            answer = ans_match.group(1).strip().strip('"').strip("'")

            # Ensure answer exists in options
            if answer not in options:
                logger.warning(
                    "Answer '%s' not found in options: %s", answer, list(options.keys())
                )
                continue

            # Build result entry with question hash as key; duplicate questions
            # therefore collapse to a single entry.
            question_hash = compute_content_hash(question)
            qa_pairs[question_hash] = {
                "question": question,
                "options": options,  # Dict like {"A": "text", "B": "text", ...}
                "answer": answer,  # Single letter: "A", "B", "C", or "D"
                "correct_answer_text": options[
                    answer
                ],  # The actual text of correct answer
            }

            logger.debug("Successfully parsed MCQ: %s", question[:50])

        if not qa_pairs:
            logger.error("Failed to parse any valid MCQ pairs from response")

        return qa_pairs

    # pylint: disable=W0221
    def build_prompt(
        self, batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """Build the MCQ-generation prompt for one (nodes, edges) community.

        :param batch: tuple of (nodes, edges); each node is (name, data) and
            each edge is (src, dst, data), where data has a "description" key
        :return: the filled prompt template, selected by the main language
            detected from the entity/relationship descriptions
        """
        nodes, edges = batch
        # Numbered entity list: "1. name: description"
        entities_str = "\n".join(
            [
                f"{index + 1}. {node[0]}: {node[1]['description']}"
                for index, node in enumerate(nodes)
            ]
        )

        # Numbered relationship list: "1. src -- dst: description"
        relationships_str = "\n".join(
            [
                f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
                for index, edge in enumerate(edges)
            ]
        )
        context = entities_str + "\n" + relationships_str
        # Pick the prompt template matching the KG content's language.
        language = detect_main_language(entities_str + relationships_str)
        prompt = MCQ_GENERATION_PROMPT[language].format(
            context=context,
            num_of_questions=self.num_of_questions,
        )
        return prompt

graphgen/operators/generate/generate_service.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,6 @@
22

33
from graphgen.bases import BaseLLMWrapper, BaseOperator
44
from graphgen.common import init_llm
5-
from graphgen.models import (
6-
AggregatedGenerator,
7-
AtomicGenerator,
8-
CoTGenerator,
9-
MultiHopGenerator,
10-
VQAGenerator,
11-
)
125
from graphgen.utils import logger, run_concurrent
136

147

@@ -22,6 +15,7 @@ def __init__(
2215
working_dir: str = "cache",
2316
method: str = "aggregated",
2417
data_format: str = "ChatML",
18+
**generate_kwargs,
2519
):
2620
super().__init__(working_dir=working_dir, op_name="generate_service")
2721
self.llm_client: BaseLLMWrapper = init_llm("synthesizer")
@@ -30,15 +24,32 @@ def __init__(
3024
self.data_format = data_format
3125

3226
if self.method == "atomic":
27+
from graphgen.models import AtomicGenerator
28+
3329
self.generator = AtomicGenerator(self.llm_client)
3430
elif self.method == "aggregated":
31+
from graphgen.models import AggregatedGenerator
32+
3533
self.generator = AggregatedGenerator(self.llm_client)
3634
elif self.method == "multi_hop":
35+
from graphgen.models import MultiHopGenerator
36+
3737
self.generator = MultiHopGenerator(self.llm_client)
3838
elif self.method == "cot":
39+
from graphgen.models import CoTGenerator
40+
3941
self.generator = CoTGenerator(self.llm_client)
40-
elif self.method in ["vqa"]:
42+
elif self.method == "vqa":
43+
from graphgen.models import VQAGenerator
44+
4145
self.generator = VQAGenerator(self.llm_client)
46+
elif self.method == "multi_choice":
47+
from graphgen.models import MultiChoiceGenerator
48+
49+
self.generator = MultiChoiceGenerator(
50+
self.llm_client,
51+
num_of_questions=generate_kwargs.get("num_of_questions", 5),
52+
)
4253
else:
4354
raise ValueError(f"Unsupported generation mode: {method}")
4455

graphgen/templates/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
AGGREGATED_GENERATION_PROMPT,
77
ATOMIC_GENERATION_PROMPT,
88
COT_GENERATION_PROMPT,
9+
MCQ_GENERATION_PROMPT,
910
MULTI_HOP_GENERATION_PROMPT,
1011
VQA_GENERATION_PROMPT,
1112
)
1213
from .kg import KG_EXTRACTION_PROMPT, KG_SUMMARIZATION_PROMPT, MMKG_EXTRACTION_PROMPT
13-
from .question_generation import QUESTION_GENERATION_PROMPT
1414
from .search_judgement import SEARCH_JUDGEMENT_PROMPT
1515
from .statement_judgement import STATEMENT_JUDGEMENT_PROMPT
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .aggregated_generation import AGGREGATED_GENERATION_PROMPT
22
from .atomic_generation import ATOMIC_GENERATION_PROMPT
33
from .cot_generation import COT_GENERATION_PROMPT
4+
from .multi_choice_generation import MCQ_GENERATION_PROMPT
45
from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
56
from .vqa_generation import VQA_GENERATION_PROMPT

0 commit comments

Comments
 (0)