1+ from collections import defaultdict
12from dataclasses import dataclass
2- from typing import Any , Dict
3+ from typing import Any , Dict , List
34
45from graphgen .models .storage .networkx_storage import NetworkXStorage
56
@@ -13,45 +14,40 @@ class CommunityDetector:
1314 method_params : Dict [str , Any ] = None
1415
1516 async def detect_communities (self ) -> Dict [str , int ]:
16- """
17- Detect communities based on the chosen method.
18- """
1917 if self .method == "leiden" :
2018 return await self ._leiden_communities (** self .method_params or {})
2119 raise ValueError (f"Unknown community detection method: { self .method } " )
2220
2321 async def get_graph (self ):
24- """
25- Asynchronously get the graph from the storage.
26- """
2722 return await self .graph_storage .get_graph ()
2823
29- async def _leiden_communities (self , ** kwargs ) -> Dict [str , int ]:
24+ async def _leiden_communities (
25+ self , max_size : int = None , ** kwargs
26+ ) -> Dict [str , int ]:
3027 """
3128 Detect communities using the Leiden algorithm.
29+ If max_size is given, any community larger than max_size will be split
30+ into smaller sub-communities each having at most max_size nodes.
3231 """
3332 import igraph as ig
3433 import networkx as nx
3534 from leidenalg import ModularityVertexPartition , find_partition
3635
3736 graph = await self .get_graph ()
38- # Filter out isolated nodes
3937 graph .remove_nodes_from (list (nx .isolates (graph )))
4038
41- # Convert NetworkX graph to igraph graph
4239 ig_graph = ig .Graph .TupleList (graph .edges (), directed = False )
4340
4441 random_seed = kwargs .get ("random_seed" , 42 )
4542 use_lcc = kwargs .get ("use_lcc" , False )
4643
47- communities = {}
44+ communities : Dict [ str , int ] = {}
4845 if use_lcc :
49- # Use the largest connected component
5046 lcc = ig_graph .components ().giant ()
5147 partition = find_partition (lcc , ModularityVertexPartition , seed = random_seed )
5248 for part , cluster in enumerate (partition ):
5349 for v in cluster :
54- communities [v ] = part
50+ communities [lcc . vs [ v ][ "name" ] ] = part
5551 else :
5652 offset = 0
5753 for component in ig_graph .components ():
@@ -64,4 +60,36 @@ async def _leiden_communities(self, **kwargs) -> Dict[str, int]:
6460 original_node = subgraph .vs [v ]["name" ]
6561 communities [original_node ] = part + offset
6662 offset += len (partition )
67- return communities
63+
64+ # split large communities if max_size is specified
65+ if max_size is None or max_size <= 0 :
66+ return communities
67+
68+ return await self ._split_communities (communities , max_size )
69+
70+ @staticmethod
71+ async def _split_communities (
72+ communities : Dict [str , int ], max_size : int
73+ ) -> Dict [str , int ]:
74+ """
75+ Split communities larger than max_size into smaller sub-communities.
76+ """
77+ cid2nodes : Dict [int , List [str ]] = defaultdict (list )
78+ for node , cid in communities .items ():
79+ cid2nodes [cid ].append (node )
80+
81+ new_communities : Dict [str , int ] = {}
82+ new_cid = 0
83+ for cid , nodes in cid2nodes .items ():
84+ if len (nodes ) <= max_size :
85+ for n in nodes :
86+ new_communities [n ] = new_cid
87+ new_cid += 1
88+ else :
89+ for start in range (0 , len (nodes ), max_size ):
90+ sub = nodes [start : start + max_size ]
91+ for n in sub :
92+ new_communities [n ] = new_cid
93+ new_cid += 1
94+
95+ return new_communities
0 commit comments