Skip to content

Commit 40a04d6

Browse files
author
beanbun
committed
feat: support partitioning the graph into quintuples
1 parent 41d5327 commit 40a04d6

File tree

6 files changed

+109
-15
lines changed

6 files changed

+109
-15
lines changed

examples/generate/generate_masked_fill_in_blank_qa/masked_fill_in_blank_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ nodes:
3838
dependencies:
3939
- build_kg
4040
params:
41-
method: triple
41+
method: quintuple
4242

4343
- id: generate
4444
op_name: generate

graphgen/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
AtomicGenerator,
1616
CoTGenerator,
1717
FillInBlankGenerator,
18+
MaskedFillInBlankGenerator,
1819
MultiAnswerGenerator,
1920
MultiChoiceGenerator,
2021
MultiHopGenerator,
2122
QuizGenerator,
2223
TrueFalseGenerator,
2324
VQAGenerator,
24-
MaskedFillInBlankGenerator,
2525
)
2626
from .kg_builder import LightRAGKGBuilder, MMKGBuilder
2727
from .llm import HTTPClient, OllamaClient, OpenAIClient
@@ -31,6 +31,7 @@
3131
DFSPartitioner,
3232
ECEPartitioner,
3333
LeidenPartitioner,
34+
QuintuplePartitioner,
3435
TriplePartitioner,
3536
)
3637
from .reader import (
@@ -90,6 +91,7 @@
9091
"ECEPartitioner": ".partitioner",
9192
"LeidenPartitioner": ".partitioner",
9293
"TriplePartitioner": ".partitioner",
94+
"QuintuplePartitioner": ".partitioner",
9395
# Reader
9496
"CSVReader": ".reader",
9597
"JSONReader": ".reader",

graphgen/models/generator/masked_fill_in_blank_generator.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -96,22 +96,35 @@ async def generate(
9696
if not context:
9797
return []
9898

99-
nodes, edge = batch
100-
assert (
101-
len(nodes) == 2
102-
), "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes."
103-
assert (
104-
len(edge) == 1
105-
), "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge."
99+
nodes, edges = batch
106100

107-
node1, node2 = nodes
108-
mask_node = random.choice([node1, node2])
109-
mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
101+
assert len(nodes) == 3, (
102+
"MaskedFillInBlankGenerator currently only supports quintuples that has 3 nodes, "
103+
f"but got {len(nodes)} nodes."
104+
)
105+
assert len(edges) == 2, (
106+
"MaskedFillInBlankGenerator currently only supports quintuples that has 2 edges, "
107+
f"but got {len(edges)} edges."
108+
)
110109

110+
node1, node2, node3 = nodes
111+
mask_node = random.choice([node1, node2, node3])
112+
mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
111113
mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
112-
masked_context = mask_pattern.sub("___", context)
113-
# For accuracy, extract the actual replaced text from the context as the ground truth
114-
gth = re.search(mask_pattern, context).group(0)
114+
115+
match = re.search(mask_pattern, context)
116+
if match:
117+
gth = match.group(0)
118+
masked_context = mask_pattern.sub("___", context)
119+
else:
120+
logger.debug(
121+
"Regex Match Failed!\n"
122+
"Expected name of node: %s\n"
123+
"Actual context: %s\n",
124+
mask_node_name,
125+
context,
126+
)
127+
return []
115128

116129
logger.debug("masked_context: %s", masked_context)
117130
qa_pairs = {

graphgen/models/partitioner/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
from .dfs_partitioner import DFSPartitioner
44
from .ece_partitioner import ECEPartitioner
55
from .leiden_partitioner import LeidenPartitioner
6+
from .quintuple_partitioner import QuintuplePartitioner
67
from .triple_partitioner import TriplePartitioner
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import random
2+
from collections import deque
3+
from typing import Any, Iterable, Set
4+
5+
from graphgen.bases import BaseGraphStorage, BasePartitioner
6+
from graphgen.bases.datatypes import Community
7+
8+
random.seed(42)
9+
10+
11+
class QuintuplePartitioner(BasePartitioner):
12+
"""
13+
quintuple Partitioner that partitions the graph into multiple distinct quintuple (node, edge, node, edge, node).
14+
1. Automatically ignore isolated points.
15+
2. In each connected component, yield quintuples in the order of BFS.
16+
"""
17+
18+
def partition(
19+
self,
20+
g: BaseGraphStorage,
21+
**kwargs: Any,
22+
) -> Iterable[Community]:
23+
nodes = [n[0] for n in g.get_all_nodes()]
24+
random.shuffle(nodes)
25+
26+
visited_nodes: Set[str] = set()
27+
used_edges: Set[frozenset[str]] = set()
28+
29+
for seed in nodes:
30+
if seed in visited_nodes:
31+
continue
32+
33+
# start BFS in a connected component
34+
queue = deque([seed])
35+
visited_nodes.add(seed)
36+
37+
while queue:
38+
u = queue.popleft()
39+
40+
# collect all neighbors connected to node u via unused edges
41+
available_neighbors = []
42+
for v in g.get_neighbors(u):
43+
edge_key = frozenset((u, v))
44+
if edge_key not in used_edges:
45+
available_neighbors.append(v)
46+
47+
# standard BFS queue maintenance
48+
if v not in visited_nodes:
49+
visited_nodes.add(v)
50+
queue.append(v)
51+
52+
random.shuffle(available_neighbors)
53+
54+
# every two neighbors paired with the center node u creates one quintuple
55+
# Note: If available_neighbors has an odd length, the remaining edge
56+
# stays unused for now. It may be matched into a quintuple later
57+
# when its other endpoint is processed as a center node.
58+
for i in range(0, len(available_neighbors) // 2 * 2, 2):
59+
v1 = available_neighbors[i]
60+
v2 = available_neighbors[i + 1]
61+
62+
edge1 = frozenset((u, v1))
63+
edge2 = frozenset((u, v2))
64+
65+
used_edges.add(edge1)
66+
used_edges.add(edge2)
67+
68+
v1_s, v2_s = sorted((v1, v2))
69+
70+
yield Community(
71+
id=f"{v1_s}-{u}-{v2_s}",
72+
nodes=[v1_s, u, v2_s],
73+
edges=[tuple(sorted((v1_s, u))), tuple(sorted((u, v2_s)))],
74+
)

graphgen/operators/partition/partition_service.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ def __init__(
6161
from graphgen.models import TriplePartitioner
6262

6363
self.partitioner = TriplePartitioner()
64+
elif method == "quintuple":
65+
from graphgen.models import QuintuplePartitioner
66+
67+
self.partitioner = QuintuplePartitioner()
6468
else:
6569
raise ValueError(f"Unsupported partition method: {method}")
6670

0 commit comments

Comments
 (0)