Skip to content

Commit 8919dc9

Browse files
author
beanbun
committed
feat: support synthesizing masked fill_in_blank QA pairs
1 parent d2a4df7 commit 8919dc9

File tree

11 files changed

+253
-1
lines changed

11 files changed

+253
-1
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Generate Masked Fill-in-blank QAs
2+
# TODO
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
python3 -m graphgen.run \
2+
--config_file examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Pipeline configuration for synthesizing masked fill-in-blank QA pairs.
# Flow: read input files -> chunk -> build knowledge graph -> partition the
# graph into triples -> generate one masked QA pair per triple.
global_params:
  working_dir: cache
  graph_backend: networkx # graph database backend, support: kuzu, networkx
  kv_backend: json_kv # key-value store backend, support: rocksdb, json_kv

nodes:
  - id: read_files # id is unique in the pipeline, and can be referenced by other steps
    op_name: read
    type: source
    dependencies: []
    params:
      input_path:
        - examples/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt, pdf. See examples/input_examples for examples

  - id: chunk_documents
    op_name: chunk
    type: map_batch
    dependencies:
      - read_files
    execution_params:
      replicas: 4
    params:
      chunk_size: 1024 # chunk size for text splitting
      chunk_overlap: 100 # chunk overlap for text splitting

  - id: build_kg
    op_name: build_kg
    type: map_batch
    dependencies:
      - chunk_documents
    execution_params:
      replicas: 1
      batch_size: 128

  - id: partition
    op_name: partition
    type: aggregate
    dependencies:
      - build_kg
    params:
      method: triple # split the KG into (node, edge, node) triples for masking

  - id: generate
    op_name: generate
    type: map_batch
    dependencies:
      - partition
    execution_params:
      replicas: 1
      batch_size: 128
      save_output: true # save output
    params:
      method: masked_fill_in_blank # atomic, aggregated, multi_hop, cot, vqa
      data_format: QA_pairs # Alpaca, Sharegpt, ChatML, QA_pairs

graphgen/bases/base_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,10 @@ def format_generation_results(
7474
{"role": "assistant", "content": answer},
7575
]
7676
}
77+
78+
if output_data_format == "QA_pairs":
79+
return {
80+
"question": question,
81+
"answer": answer,
82+
}
7783
raise ValueError(f"Unknown output data format: {output_data_format}")

graphgen/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
QuizGenerator,
2222
TrueFalseGenerator,
2323
VQAGenerator,
24+
MaskedFillInBlankGenerator,
2425
)
2526
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
2627
from .llm import HTTPClient, OllamaClient, OpenAIClient
@@ -30,6 +31,7 @@
3031
DFSPartitioner,
3132
ECEPartitioner,
3233
LeidenPartitioner,
34+
TriplePartitioner,
3335
)
3436
from .reader import (
3537
CSVReader,
@@ -73,6 +75,7 @@
7375
"QuizGenerator": ".generator",
7476
"TrueFalseGenerator": ".generator",
7577
"VQAGenerator": ".generator",
78+
"MaskedFillInBlankGenerator": ".generator",
7679
# KG Builder
7780
"LightRAGKGBuilder": ".kg_builder",
7881
"MMKGBuilder": ".kg_builder",
@@ -86,6 +89,7 @@
8689
"DFSPartitioner": ".partitioner",
8790
"ECEPartitioner": ".partitioner",
8891
"LeidenPartitioner": ".partitioner",
92+
"TriplePartitioner": ".partitioner",
8993
# Reader
9094
"CSVReader": ".reader",
9195
"JSONReader": ".reader",

graphgen/models/generator/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from .quiz_generator import QuizGenerator
99
from .true_false_generator import TrueFalseGenerator
1010
from .vqa_generator import VQAGenerator
11+
from .masked_fill_in_blank_generator import MaskedFillInBlankGenerator
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import re
2+
import random
3+
from typing import Any, Optional
4+
5+
from graphgen.bases import BaseGenerator
6+
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
7+
from graphgen.utils import detect_main_language, logger
8+
9+
random.seed(42)
10+
11+
class MaskedFillInBlankGenerator(BaseGenerator):
    """
    Generate masked fill-in-blank QA pairs from a (nodes, edges) triple.

    Masked Fill-in-blank Generator follows a TWO-STEP process:
    1. rephrase: Rephrase the input nodes and edges into a coherent text that
       maintains the original meaning.
    2. mask: Randomly select a node from the input nodes, then mask every
       occurrence of that node's name in the rephrased text with "___".
    """

    @staticmethod
    def build_prompt(
        batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]]
    ) -> str:
        """
        Build the prompt for the REPHRASE step.

        :param batch: (nodes, edges); each node is (name, data) and each edge
            is (src, dst, data), where data carries a "description" field.
        :return: prompt string asking the LLM to rephrase the subgraph
        """
        nodes, edges = batch
        entities_str = "\n".join(
            f"{index + 1}. {node[0]}: {node[1]['description']}"
            for index, node in enumerate(nodes)
        )
        relations_str = "\n".join(
            f"{index + 1}. {edge[0]} -- {edge[1]}: {edge[2]['description']}"
            for index, edge in enumerate(edges)
        )
        # Select the prompt template matching the dominant language of the input.
        language = detect_main_language(entities_str + relations_str)

        # TODO: configure add_context
        # if add_context:
        #     original_ids = [
        #         node["source_id"].split("<SEP>")[0] for node in _process_nodes
        #     ] + [edge[2]["source_id"].split("<SEP>")[0] for edge in _process_edges]
        #     original_ids = list(set(original_ids))
        #     original_text = await text_chunks_storage.get_by_ids(original_ids)
        #     original_text = "\n".join(
        #         [
        #             f"{index + 1}. {text['content']}"
        #             for index, text in enumerate(original_text)
        #         ]
        #     )
        prompt = AGGREGATED_GENERATION_PROMPT[language]["ANSWER_REPHRASING"].format(
            entities=entities_str, relationships=relations_str
        )
        return prompt

    @staticmethod
    def parse_rephrased_text(response: str) -> Optional[str]:
        """
        Parse the rephrased text from the LLM response.

        :param response: raw LLM response expected to contain a
            <rephrased_text>...</rephrased_text> block
        :return: rephrased text with surrounding quotes removed, or None when
            the expected tags are missing
        """
        rephrased_match = re.search(
            r"<rephrased_text>(.*?)</rephrased_text>", response, re.DOTALL
        )
        if not rephrased_match:
            logger.warning("Failed to parse rephrased text from response: %s", response)
            return None
        rephrased_text = rephrased_match.group(1).strip()
        return rephrased_text.strip('"').strip("'")

    @staticmethod
    def parse_response(response: str) -> dict:
        # Not used by this generator: generate() builds the QA pair directly
        # from the rephrased text, so there is no per-response parsing step.
        # Kept to satisfy the BaseGenerator interface — presumably never
        # invoked by the pipeline; confirm before relying on it.
        pass

    async def generate(
        self,
        batch: tuple[
            list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]
        ],
    ) -> list[dict]:
        """
        Generate masked fill-in-blank QA pairs for a single triple batch.

        :param batch: (nodes, edges) — must contain exactly 2 nodes and 1 edge
            (the output of the "triple" partition method)
        :return: a list with one {"question", "answer"} dict, or [] when
            rephrasing fails or the chosen entity never appears in the text
        """
        rephrasing_prompt = self.build_prompt(batch)
        response = await self.llm_client.generate_answer(rephrasing_prompt)
        context = self.parse_rephrased_text(response)
        if not context:
            return []

        nodes, edges = batch
        assert (
            len(nodes) == 2
        ), "MaskedFillInBlankGenerator currently only supports triples, which should have 2 nodes."
        assert (
            len(edges) == 1
        ), "MaskedFillInBlankGenerator currently only supports triples, which should have 1 edge."

        node1, node2 = nodes
        mask_node = random.choice([node1, node2])
        mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")

        mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
        # The LLM may paraphrase the entity name away entirely. In that case
        # the substitution would be a no-op (a "question" with no blank) and
        # re.search(...) would return None, so skip this batch instead of
        # crashing with AttributeError on .group(0).
        match = mask_pattern.search(context)
        if match is None:
            logger.warning(
                "Entity %r not found in rephrased text; skipping batch.",
                mask_node_name,
            )
            return []
        masked_context = mask_pattern.sub("___", context)
        # For accuracy, extract the actual replaced text from the context as
        # the ground truth (keeping the original case).
        gth = match.group(0)

        logger.debug("masked_context: %s", masked_context)
        qa_pairs = {
            "question": masked_context,
            "answer": gth,
        }
        return [qa_pairs]
117+

graphgen/models/partitioner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .dfs_partitioner import DFSPartitioner
44
from .ece_partitioner import ECEPartitioner
55
from .leiden_partitioner import LeidenPartitioner
6+
from .triple_partitioner import TriplePartitioner
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import random
2+
from collections import deque
3+
from typing import Any, Iterable, Set
4+
5+
from graphgen.bases import BaseGraphStorage, BasePartitioner
6+
from graphgen.bases.datatypes import Community
7+
8+
random.seed(42)
9+
10+
class TriplePartitioner(BasePartitioner):
    """
    Partition the graph into distinct (node, edge, node) triples.

    1. Isolated nodes are skipped automatically (they have no edges, so no
       triple is ever produced for them).
    2. Within each connected component, triples are yielded in BFS discovery
       order, starting from a randomly shuffled list of seed nodes.
    """

    def partition(
        self,
        g: BaseGraphStorage,
        **kwargs: Any,
    ) -> Iterable[Community]:
        all_nodes = [entry[0] for entry in g.get_all_nodes()]
        random.shuffle(all_nodes)

        seen: Set[str] = set()
        emitted: Set[frozenset[str]] = set()

        for start in all_nodes:
            if start in seen:
                continue

            # Begin a BFS over the connected component containing `start`.
            frontier = deque([start])
            seen.add(start)

            while frontier:
                current = frontier.popleft()

                for neighbor in g.get_neighbors(current):
                    # frozenset collapses (u, v) and (v, u) into one key,
                    # so each undirected edge yields exactly one triple.
                    key = frozenset((current, neighbor))
                    if key not in emitted:
                        emitted.add(key)

                        # Sort the endpoints so the community ID is unique
                        # and independent of traversal direction.
                        lo, hi = sorted((current, neighbor))
                        yield Community(
                            id=f"{lo}-{hi}",
                            nodes=[lo, hi],
                            edges=[(lo, hi)]
                        )

                    # Keep expanding the BFS frontier.
                    if neighbor not in seen:
                        seen.add(neighbor)
                        frontier.append(neighbor)

graphgen/operators/generate/generate_service.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def __init__(
7171
self.llm_client,
7272
num_of_questions=generate_kwargs.get("num_of_questions", 5),
7373
)
74+
elif self.method == "masked_fill_in_blank":
75+
from graphgen.models import MaskedFillInBlankGenerator
76+
77+
self.generator = MaskedFillInBlankGenerator(self.llm_client)
7478
elif self.method == "true_false":
7579
from graphgen.models import TrueFalseGenerator
7680

0 commit comments

Comments
 (0)