Skip to content

Commit 4c1c32a

Browse files
author
beanbun
committed
style: fix formatting issues
1 parent 64547af commit 4c1c32a

File tree

2 files changed

+20
-15
lines changed

2 files changed

+20
-15
lines changed

graphgen/models/generator/masked_fill_in_blank_generator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import re
21
import random
2+
import re
33
from typing import Any, Optional
44

55
from graphgen.bases import BaseGenerator
@@ -8,6 +8,7 @@
88

99
random.seed(42)
1010

11+
1112
class MaskedFillInBlankGenerator(BaseGenerator):
1213
"""
1314
Masked Fill-in-blank Generator follows a TWO-STEP process:
@@ -94,18 +95,22 @@ async def generate(
9495
context = self.parse_rephrased_text(response)
9596
if not context:
9697
return []
97-
98+
9899
nodes, edge = batch
99-
assert len(nodes) == 2, "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes."
100-
assert len(edge) == 1, "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge."
100+
assert (
101+
len(nodes) == 2
102+
), "MaskedFillInBlankGenerator currently only supports triples, which should has 2 nodes."
103+
assert (
104+
len(edge) == 1
105+
), "MaskedFillInBlankGenerator currently only supports triples, which should has 1 edge."
101106

102107
node1, node2 = nodes
103108
mask_node = random.choice([node1, node2])
104-
mask_node_name = mask_node[1]["entity_name"].strip('\'" \n\r\t')
109+
mask_node_name = mask_node[1]["entity_name"].strip("'\" \n\r\t")
105110

106111
mask_pattern = re.compile(re.escape(mask_node_name), re.IGNORECASE)
107112
masked_context = mask_pattern.sub("___", context)
108-
# For accuracy, extract the actual replaced text from the context as the ground truth (keeping the original case)
113+
# For accuracy, extract the actual replaced text from the context as the ground truth
109114
gth = re.search(mask_pattern, context).group(0)
110115

111116
logger.debug("masked_context: %s", masked_context)
@@ -114,4 +119,3 @@ async def generate(
114119
"answer": gth,
115120
}
116121
return [qa_pairs]
117-

graphgen/models/partitioner/triple_partitioner.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
random.seed(42)
99

10+
1011
class TriplePartitioner(BasePartitioner):
1112
"""
1213
Triple Partitioner that partitions the graph into multiple distinct triples (node, edge, node).
@@ -28,30 +29,30 @@ def partition(
2829
for seed in nodes:
2930
if seed in visited_nodes:
3031
continue
31-
32+
3233
# start BFS in a connected component
3334
queue = deque([seed])
3435
visited_nodes.add(seed)
35-
36+
3637
while queue:
3738
u = queue.popleft()
38-
39+
3940
for v in g.get_neighbors(u):
4041
edge_key = frozenset((u, v))
41-
42+
4243
# if this edge has not been used, a new triple has been found
4344
if edge_key not in used_edges:
4445
used_edges.add(edge_key)
45-
46+
4647
# use the edge name to ensure the uniqueness of the ID
4748
u_sorted, v_sorted = sorted((u, v))
4849
yield Community(
4950
id=f"{u_sorted}-{v_sorted}",
5051
nodes=[u_sorted, v_sorted],
51-
edges=[(u_sorted, v_sorted)]
52+
edges=[(u_sorted, v_sorted)],
5253
)
53-
54+
5455
# continue to BFS
5556
if v not in visited_nodes:
5657
visited_nodes.add(v)
57-
queue.append(v)
58+
queue.append(v)

0 commit comments

Comments
 (0)