1- from typing import List
1+ import asyncio
2+ import random
3+ from dataclasses import dataclass
4+ from typing import Any , Dict , List , Set , Tuple
5+
6+ from tqdm .asyncio import tqdm as tqdm_async
27
38from graphgen .bases import BaseGraphStorage
49from graphgen .bases .datatypes import Community
5- from graphgen .models import BFSPartitioner
10+ from graphgen .models . partitioner . bfs_partitioner import BFSPartitioner
611
712
@dataclass
class ECEPartitioner(BFSPartitioner):
    """
    ECE partitioner that groups KG units into communities based on
    Expected Calibration Error (ECE).

    We calculate ECE for edges in the KG (represented as 'comprehension loss')
    and group units with similar ECE values into the same community:

    1. Select a sampling strategy.
    2. Choose a seed unit based on the sampling strategy.
    3. Expand the community around the seed using BFS.
    4. Repeat until all units are assigned to communities.
    (A unit is a node or an edge.)
    """

    @staticmethod
    def _sort_units(units: list, edge_sampling: str) -> list:
        """
        Order units according to the edge sampling strategy.

        :param units: list of units; each unit is a tuple whose last element
            is the unit's data dict (edges carry a "loss" entry).
        :param edge_sampling: edge sampling strategy
            ("random", "min_loss", "max_loss").
        :return: a newly ordered list of units (the input list is not mutated).
        :raises ValueError: if edge_sampling is not a recognized strategy.
        """
        if edge_sampling == "random":
            # Copy before shuffling so the caller's list is left untouched.
            units = list(units)
            random.shuffle(units)
            return units
        if edge_sampling in ("min_loss", "max_loss"):
            # Units without a recorded loss (e.g. node units — loss is
            # computed for edges) sort as 0.0 instead of raising KeyError.
            return sorted(
                units,
                key=lambda unit: unit[-1].get("loss", 0.0),
                reverse=(edge_sampling == "max_loss"),
            )
        raise ValueError(f"Invalid edge sampling: {edge_sampling}")

    async def partition(
        self,
        g: BaseGraphStorage,
        max_units_per_community: int = 10,
        max_tokens_per_community: int = 10240,
        edge_sampling: str = "random",
        **kwargs: Any,
    ) -> List[Community]:
        """
        Partition the graph into communities by growing BFS regions from
        seed units ordered by the edge sampling strategy.

        :param g: graph storage providing all nodes and edges.
        :param max_units_per_community: cap on nodes + edges per community.
        :param max_tokens_per_community: cap on the summed edge "length"
            per community (node token lengths are not counted).
        :param edge_sampling: unit ordering strategy; see _sort_units.
        :return: communities covering every node and edge exactly once.
        """
        nodes: List[Tuple[str, dict]] = await g.get_all_nodes()
        edges: List[Tuple[str, str, dict]] = await g.get_all_edges()

        adj, _ = self._build_adjacency_list(nodes, edges)
        node_dict = dict(nodes)
        # Undirected edges are keyed by the frozenset of their endpoints so
        # (u, v) and (v, u) resolve to the same entry.
        edge_dict = {frozenset((u, v)): d for u, v, d in edges}

        # A unit is ("n", node_id, data) or ("e", frozenset({u, v}), data).
        all_units: List[Tuple[str, Any, dict]] = [
            ("n", nid, d) for nid, d in nodes
        ] + [("e", frozenset((u, v)), d) for u, v, d in edges]

        used_n: Set[str] = set()
        used_e: Set[frozenset] = set()
        communities: List[Community] = []

        all_units = self._sort_units(all_units, edge_sampling)

        async def _grow_community(seed_unit: Tuple[str, Any, dict]) -> Community:
            # Grow one community outward from seed_unit via BFS, bounded by
            # the unit-count and token budgets.
            community_nodes: Dict[str, dict] = {}
            community_edges: Dict[frozenset, dict] = {}
            queue: asyncio.Queue = asyncio.Queue()
            token_sum = 0

            async def _add_unit(unit) -> bool:
                # Claim the unit for this community; return False if it is
                # already assigned. Only edges contribute to the token budget.
                nonlocal token_sum
                utype, uid, data = unit
                if utype == "n":
                    if uid in used_n or uid in community_nodes:
                        return False
                    community_nodes[uid] = data
                    used_n.add(uid)
                else:  # edge
                    if uid in used_e or uid in community_edges:
                        return False
                    community_edges[uid] = data
                    used_e.add(uid)
                    token_sum += data.get("length", 0)
                return True

            await _add_unit(seed_unit)
            await queue.put(seed_unit)

            # BFS expansion.
            while not queue.empty():
                if (
                    len(community_nodes) + len(community_edges)
                    >= max_units_per_community
                    or token_sum >= max_tokens_per_community
                ):
                    break

                cur_type, cur_id, _ = await queue.get()

                # Frontier: incident edges of a node, endpoints of an edge.
                neighbors: List[Tuple[str, Any, dict]] = []
                if cur_type == "n":
                    for nb_id in adj.get(cur_id, []):
                        e_key = frozenset((cur_id, nb_id))
                        if e_key not in used_e and e_key not in community_edges:
                            neighbors.append(("e", e_key, edge_dict[e_key]))
                else:
                    for n_id in cur_id:
                        if n_id not in used_n and n_id not in community_nodes:
                            neighbors.append(("n", n_id, node_dict[n_id]))

                neighbors = self._sort_units(neighbors, edge_sampling)
                for nb in neighbors:
                    if (
                        len(community_nodes) + len(community_edges)
                        >= max_units_per_community
                        or token_sum >= max_tokens_per_community
                    ):
                        break
                    if await _add_unit(nb):
                        await queue.put(nb)

            # frozenset keys are unordered; materialize each edge as a
            # 2-tuple. A self-loop yields a 1-element frozenset, so duplicate
            # the endpoint instead of failing to unpack.
            edge_pairs: List[Tuple[str, str]] = []
            for key in community_edges:
                endpoints = tuple(key)
                if len(endpoints) == 1:
                    endpoints = (endpoints[0], endpoints[0])
                edge_pairs.append(endpoints)

            return Community(
                id=len(communities),
                nodes=list(community_nodes.keys()),
                edges=edge_pairs,
            )

        async for unit in tqdm_async(all_units, desc="ECE partition"):
            utype, uid, _ = unit
            if (utype == "n" and uid in used_n) or (utype == "e" and uid in used_e):
                continue
            communities.append(await _grow_community(unit))

        return communities
0 commit comments