Skip to content

Commit 9ed56a6

Browse files
feat: add ECEPartitioner
1 parent dcae889 commit 9ed56a6

5 files changed

Lines changed: 192 additions & 23 deletions

File tree

graphgen/graphgen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,9 @@ async def quiz_and_judge(self, quiz_and_judge_config: Dict):
231231
@async_to_sync_method
232232
async def generate(self, partition_config: Dict, generate_config: Dict):
233233
# Step 1: partition the graph
234-
batches = await partition_kg(self.graph_storage, partition_config)
234+
batches = await partition_kg(
235+
self.graph_storage, self.tokenizer_instance, partition_config
236+
)
235237

236238
# Step 2: generate QA pairs
237239
results = await generate_qas(
Lines changed: 131 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1-
from typing import List
1+
import asyncio
2+
import random
3+
from dataclasses import dataclass
4+
from typing import Any, Dict, List, Set, Tuple
5+
6+
from tqdm.asyncio import tqdm as tqdm_async
27

38
from graphgen.bases import BaseGraphStorage
49
from graphgen.bases.datatypes import Community
5-
from graphgen.models import BFSPartitioner
10+
from graphgen.models.partitioner.bfs_partitioner import BFSPartitioner
611

712

13+
@dataclass
class ECEPartitioner(BFSPartitioner):
    """
    Partition the graph into communities guided by Expected Calibration Error (ECE).

    ECE is computed for edges in the KG (stored on the edge data as a
    'comprehension loss') and units with similar ECE values are grouped into
    the same community:
    1. Select a sampling strategy ("random", "min_loss", "max_loss").
    2. Choose a seed unit according to the sampling strategy.
    3. Expand the community from the seed using BFS, respecting the
       per-community unit and token budgets.
    (A unit is either a node or an edge.)
    """

    @staticmethod
    def _sort_units(units: list, edge_sampling: str) -> list:
        """
        Order units according to the edge sampling strategy.

        :param units: unit tuples whose last element is the data dict
        :param edge_sampling: one of "random", "min_loss", "max_loss"
        :return: the ordered units
        :raises ValueError: on an unknown strategy
        """
        if edge_sampling == "random":
            random.shuffle(units)
        elif edge_sampling in ("min_loss", "max_loss"):
            # Units without a 'loss' key (e.g. nodes that were never judged)
            # sort as 0.0 instead of raising KeyError.
            units = sorted(
                units,
                key=lambda x: x[-1].get("loss", 0.0),
                reverse=edge_sampling == "max_loss",
            )
        else:
            raise ValueError(f"Invalid edge sampling: {edge_sampling}")
        return units

    async def partition(
        self,
        g: BaseGraphStorage,
        max_units_per_community: int = 10,
        max_tokens_per_community: int = 10240,
        edge_sampling: str = "random",
        **kwargs: Any,
    ) -> List[Community]:
        """
        Partition ``g`` into communities bounded by unit count and token budget.

        :param g: graph storage to partition
        :param max_units_per_community: cap on nodes + edges per community
        :param max_tokens_per_community: cap on the summed 'length' of claimed units
        :param edge_sampling: ordering of seeds/neighbors ("random", "min_loss", "max_loss")
        :return: communities that together cover every node and edge exactly once
        """
        nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
        edges: List[Tuple[str, str, dict]] = await g.get_all_edges()

        adj, _ = self._build_adjacency_list(nodes, edges)
        node_dict = dict(nodes)
        # frozenset keys make (u, v) and (v, u) the same undirected edge.
        edge_dict = {frozenset((u, v)): d for u, v, d in edges}

        # A unit is ("n", node_id, data) or ("e", frozenset({u, v}), data).
        all_units: List[Tuple[str, Any, dict]] = [("n", nid, d) for nid, d in nodes] + [
            ("e", frozenset((u, v)), d) for u, v, d in edges
        ]

        used_n: Set[str] = set()
        used_e: Set[frozenset] = set()
        communities: List[Community] = []

        all_units = self._sort_units(all_units, edge_sampling)

        async def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Community:
            """BFS-expand a single community from ``seed_unit`` until a cap is hit."""
            community_nodes: Dict[str, dict] = {}
            community_edges: Dict[frozenset, dict] = {}
            queue: asyncio.Queue = asyncio.Queue()
            token_sum = 0

            async def _add_unit(unit) -> bool:
                """Claim a unit for this community; False if it is already taken."""
                nonlocal token_sum
                utype, uid, data = unit
                if utype == "n":
                    if uid in used_n or uid in community_nodes:
                        return False
                    community_nodes[uid] = data
                    used_n.add(uid)
                else:  # edge
                    if uid in used_e or uid in community_edges:
                        return False
                    community_edges[uid] = data
                    used_e.add(uid)
                # 'length' is the pre-tokenized token count (see pre_tokenize).
                token_sum += data.get("length", 0)
                return True

            def _is_full() -> bool:
                """True once either the unit or the token budget is exhausted."""
                return (
                    len(community_nodes) + len(community_edges)
                    >= max_units_per_community
                    or token_sum >= max_tokens_per_community
                )

            await _add_unit(seed_unit)
            await queue.put(seed_unit)

            # BFS over alternating node/edge layers.
            while not queue.empty():
                if _is_full():
                    break

                cur_type, cur_id, _ = await queue.get()

                neighbors: List[Tuple[str, Any, dict]] = []
                if cur_type == "n":
                    for nb_id in adj.get(cur_id, []):
                        e_key = frozenset((cur_id, nb_id))
                        if e_key not in used_e and e_key not in community_edges:
                            neighbors.append(("e", e_key, edge_dict[e_key]))
                else:
                    # An edge's neighbors are its (one or two) endpoint nodes.
                    for n_id in cur_id:
                        if n_id not in used_n and n_id not in community_nodes:
                            neighbors.append(("n", n_id, node_dict[n_id]))

                neighbors = self._sort_units(neighbors, edge_sampling)
                for nb in neighbors:
                    if _is_full():
                        break
                    if await _add_unit(nb):
                        await queue.put(nb)

            # A self-loop collapses to a 1-element frozenset; expand it back to
            # (u, u) instead of crashing on a two-name tuple unpack.
            edge_pairs: List[Tuple[str, str]] = []
            for key in community_edges:
                members = tuple(key)
                if len(members) == 2:
                    edge_pairs.append((members[0], members[1]))
                else:
                    edge_pairs.append((members[0], members[0]))

            return Community(
                id=len(communities),
                nodes=list(community_nodes.keys()),
                edges=edge_pairs,
            )

        async for unit in tqdm_async(all_units, desc="ECE partition"):
            utype, uid, _ = unit
            if (utype == "n" and uid in used_n) or (utype == "e" and uid in used_e):
                continue
            communities.append(await _grow_community(unit))

        return communities

graphgen/operators/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from graphgen.operators.partition.traverse_graph import (
22
traverse_graph_for_aggregated,
3-
traverse_graph_for_atomic,
43
traverse_graph_for_multi_hop,
54
)
65

graphgen/operators/partition/partition_kg.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import Any, List, Tuple
1+
from typing import Any
22

3-
from graphgen.bases import BaseGraphStorage
4-
from graphgen.bases.datatypes import Community
3+
from graphgen.bases import BaseGraphStorage, BaseTokenizer
54
from graphgen.models import (
65
BFSPartitioner,
76
DFSPartitioner,
@@ -10,9 +9,12 @@
109
)
1110
from graphgen.utils import logger
1211

12+
from .pre_tokenize import pre_tokenize
13+
1314

1415
async def partition_kg(
1516
kg_instance: BaseGraphStorage,
17+
tokenizer: Any = BaseTokenizer,
1618
partition_config: dict = None,
1719
) -> list[
1820
tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict] | tuple[Any, Any, Any]]]
@@ -27,6 +29,12 @@ async def partition_kg(
2729
partitioner = DFSPartitioner()
2830
elif method == "ece":
2931
logger.info("Partitioning knowledge graph using ECE method.")
32+
# TODO: before ECE partitioning, we need to:
33+
# 1. 'quiz and judge' to get the comprehension loss
34+
# 2. pre-tokenize nodes and edges to get the token length
35+
edges = await kg_instance.get_all_edges()
36+
nodes = await kg_instance.get_all_nodes()
37+
await pre_tokenize(kg_instance, tokenizer, edges, nodes)
3038
partitioner = ECEPartitioner()
3139
elif method == "leiden":
3240
logger.info("Partitioning knowledge graph using Leiden method.")
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import asyncio
2+
from typing import List, Tuple
3+
4+
from graphgen.bases import BaseGraphStorage, BaseTokenizer
5+
from graphgen.utils import run_concurrent
6+
7+
8+
async def pre_tokenize(
    graph_storage: BaseGraphStorage,
    tokenizer: BaseTokenizer,
    edges: List[Tuple],
    nodes: List[Tuple],
) -> Tuple[List, List]:
    """
    Backfill the token length ('length') of every edge/node and persist it.

    Entries that already carry a 'length' are written back unchanged. Encoding
    runs in the default executor (tokenization is CPU-bound), at most 1000
    tasks in flight, with a progress bar per collection.

    :param graph_storage: storage the patched nodes/edges are written back to
    :param tokenizer: tokenizer whose ``encode`` is used to count tokens
    :param edges: (src, dst, data) tuples
    :param nodes: (node_id, data) tuples
    :return: the (edges, nodes) lists with 'length' filled in
    """
    sem = asyncio.Semaphore(1000)

    async def _patch_and_write(obj: Tuple, *, is_node: bool) -> Tuple:
        """Ensure obj's data has a 'length', then persist it to storage."""
        async with sem:
            data = obj[1] if is_node else obj[2]
            if "length" not in data:
                # Entries without a description count as zero tokens instead
                # of raising KeyError.
                text = data.get("description", "")
                # get_event_loop() is deprecated inside coroutines; use the
                # running loop explicitly.
                loop = asyncio.get_running_loop()
                tokens = await loop.run_in_executor(None, tokenizer.encode, text)
                data["length"] = len(tokens)
            if is_node:
                await graph_storage.update_node(obj[0], obj[1])
            else:
                await graph_storage.update_edge(obj[0], obj[1], obj[2])
            return obj

    new_edges, new_nodes = await asyncio.gather(
        run_concurrent(
            lambda e: _patch_and_write(e, is_node=False),
            edges,
            desc="Pre-tokenizing edges",
        ),
        run_concurrent(
            lambda n: _patch_and_write(n, is_node=True),
            nodes,
            desc="Pre-tokenizing nodes",
        ),
    )

    # Flush/commit the storage after all updates have been applied.
    await graph_storage.index_done_callback()
    return new_edges, new_nodes

0 commit comments

Comments
 (0)