Skip to content

Commit e7b332f

Browse files
feat: split community with max_size
1 parent f5793a7 commit e7b332f

2 files changed

Lines changed: 42 additions & 15 deletions

File tree

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
from collections import defaultdict
12
from dataclasses import dataclass
2-
from typing import Any, Dict
3+
from typing import Any, Dict, List
34

45
from graphgen.models.storage.networkx_storage import NetworkXStorage
56

@@ -13,45 +14,40 @@ class CommunityDetector:
1314
method_params: Dict[str, Any] = None
1415

1516
async def detect_communities(self) -> Dict[str, int]:
    """
    Run the configured community detection method and return its result.

    Only the "leiden" method is handled; ``method_params`` (when set) is
    forwarded to it as keyword arguments. Any other method name raises
    ValueError.
    """
    if self.method != "leiden":
        raise ValueError(f"Unknown community detection method: {self.method}")

    # method_params may be None, in which case no extra keywords are passed.
    params = self.method_params or {}
    return await self._leiden_communities(**params)
2220

2321
async def get_graph(self):
    """
    Fetch the graph from the configured graph storage.
    """
    storage = self.graph_storage
    return await storage.get_graph()
2823

29-
async def _leiden_communities(self, **kwargs) -> Dict[str, int]:
24+
async def _leiden_communities(
25+
self, max_size: int = None, **kwargs
26+
) -> Dict[str, int]:
3027
"""
3128
Detect communities using the Leiden algorithm.
29+
If max_size is given, any community larger than max_size will be split
30+
into smaller sub-communities each having at most max_size nodes.
3231
"""
3332
import igraph as ig
3433
import networkx as nx
3534
from leidenalg import ModularityVertexPartition, find_partition
3635

3736
graph = await self.get_graph()
38-
# Filter out isolated nodes
3937
graph.remove_nodes_from(list(nx.isolates(graph)))
4038

41-
# Convert NetworkX graph to igraph graph
4239
ig_graph = ig.Graph.TupleList(graph.edges(), directed=False)
4340

4441
random_seed = kwargs.get("random_seed", 42)
4542
use_lcc = kwargs.get("use_lcc", False)
4643

47-
communities = {}
44+
communities: Dict[str, int] = {}
4845
if use_lcc:
49-
# Use the largest connected component
5046
lcc = ig_graph.components().giant()
5147
partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
5248
for part, cluster in enumerate(partition):
5349
for v in cluster:
54-
communities[v] = part
50+
communities[lcc.vs[v]["name"]] = part
5551
else:
5652
offset = 0
5753
for component in ig_graph.components():
@@ -64,4 +60,36 @@ async def _leiden_communities(self, **kwargs) -> Dict[str, int]:
6460
original_node = subgraph.vs[v]["name"]
6561
communities[original_node] = part + offset
6662
offset += len(partition)
67-
return communities
63+
64+
# split large communities if max_size is specified
65+
if max_size is None or max_size <= 0:
66+
return communities
67+
68+
return await self._split_communities(communities, max_size)
69+
70+
@staticmethod
async def _split_communities(
    communities: Dict[str, int], max_size: int
) -> Dict[str, int]:
    """
    Split communities larger than max_size into smaller sub-communities.

    Nodes are grouped by their current community id and each group is cut
    into consecutive chunks of at most ``max_size`` nodes, following the
    order in which the nodes appear in ``communities``. Graph structure is
    not consulted here, so the nodes of a chunk are not guaranteed to be
    connected to each other. Every resulting chunk, including communities
    that already fit, receives a fresh id numbered consecutively from 0.

    :param communities: mapping of node name to community id.
    :param max_size: maximum number of nodes per community; must be positive.
    :return: mapping of node name to the renumbered community id.
    :raises ValueError: if ``max_size`` is not a positive integer.
    """
    if max_size <= 0:
        # A zero step makes range() fail with an opaque message, and a
        # negative step yields no chunks at all, which would silently drop
        # every node from the result.
        raise ValueError(f"max_size must be a positive integer, got {max_size}")

    cid2nodes: Dict[int, List[str]] = defaultdict(list)
    for node, cid in communities.items():
        cid2nodes[cid].append(node)

    new_communities: Dict[str, int] = {}
    new_cid = 0
    for nodes in cid2nodes.values():
        # A community that already fits produces exactly one chunk, so the
        # small and large cases share the same loop.
        for start in range(0, len(nodes), max_size):
            for n in nodes[start : start + max_size]:
                new_communities[n] = new_cid
            new_cid += 1

    return new_communities

graphgen/models/vis/community_visualizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def visualize(self, save_path: str = None):
3131
plt.figure(figsize=(10, 10))
3232

3333
node_colors = [self.communities.get(node, 0) for node in self.graph.nodes()]
34-
print(node_colors)
3534

3635
nx.draw_networkx_nodes(
3736
self.graph,

0 commit comments

Comments
 (0)