|
1 | 1 | from collections import defaultdict |
2 | 2 | from dataclasses import dataclass |
3 | | -from typing import Any, Dict, List |
| 3 | +from typing import Any, Dict, List, Set, Tuple |
4 | 4 |
|
5 | | -from graphgen.models.storage.networkx_storage import NetworkXStorage |
| 5 | +import igraph as ig |
| 6 | +from leidenalg import ModularityVertexPartition, find_partition |
6 | 7 |
|
| 8 | +from graphgen.bases import BaseGraphStorage, BasePartitioner |
| 9 | +from graphgen.bases.datatypes import Community |
7 | 10 |
|
@dataclass
class LeidenPartitioner(BasePartitioner):
    """
    Leiden partitioner that partitions the graph into communities using the Leiden algorithm.
    """

    async def partition(
        self,
        g: BaseGraphStorage,
        max_size: int = 20,
        use_lcc: bool = False,
        random_seed: int = 42,
        **kwargs: Any,
    ) -> List[Community]:
        """
        Partition the graph into communities with the Leiden algorithm.

        Leiden Partition follows these steps:
        1. export the graph from graph storage
        2. use the leiden algorithm to detect communities, get {node: community_id}
        3. split large communities if max_size is given
        4. convert {node: community_id} to List[Community]

        :param g: graph storage backend to read nodes and edges from
        :param max_size: maximum size of each community, if None or <=0, no limit
        :param use_lcc: whether to use the largest connected component only
        :param random_seed: seed forwarded to the Leiden algorithm for reproducibility
        :param kwargs: other parameters for the leiden algorithm (currently unused)
        :return: list of communities, each with its member nodes and
            intra-community edges
        """
        nodes = await g.get_all_nodes()  # List[Tuple[str, dict]]
        edges = await g.get_all_edges()  # List[Tuple[str, str, dict]]

        node2cid: Dict[str, int] = await self._run_leiden(
            nodes, edges, use_lcc, random_seed
        )

        if max_size is not None and max_size > 0:
            node2cid = await self._split_communities(node2cid, max_size)

        # Group member nodes by community id.
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for n, cid in node2cid.items():
            cid2nodes[cid].append(n)

        # Collect intra-community edges in a single O(E) pass instead of
        # rescanning the whole edge list once per community (O(C * E)).
        # Communities partition the clustered nodes, so an edge is
        # intra-community iff both endpoints map to the same community id;
        # endpoints without a community (e.g. outside the LCC when
        # use_lcc=True) are skipped by .get(), matching the previous
        # set-membership filter.
        cid2edges: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
        for u, v, _ in edges:
            cu = node2cid.get(u)
            if cu is not None and cu == node2cid.get(v):
                cid2edges[cu].append((u, v))

        # NOTE: loop variable renamed from `nodes` to `members` so it no
        # longer shadows the node list fetched from storage above.
        return [
            Community(id=cid, nodes=members, edges=cid2edges.get(cid, []))
            for cid, members in cid2nodes.items()
        ]

    @staticmethod
    async def _run_leiden(
        nodes: List[Tuple[str, dict]],
        edges: List[Tuple[str, str, dict]],
        use_lcc: bool = False,
        random_seed: int = 42,
    ) -> Dict[str, int]:
        """
        Run Leiden community detection and return a {node_name: community_id} map.

        :param nodes: all nodes from storage. NOTE(review): currently unused —
            the igraph graph is built from the edge list only, so isolated
            nodes never receive a community id; confirm this is intended.
        :param edges: all edges from storage as (u, v, attrs) tuples
        :param use_lcc: whether to use the largest connected component only
        :param random_seed: seed for reproducible partitions
        :return: mapping from node name to globally unique community id
        """
        # Build an undirected igraph graph from the edge endpoints only.
        ig_graph = ig.Graph.TupleList(((u, v) for u, v, _ in edges), directed=False)

        # Remove isolated nodes. NOTE(review): every vertex here comes from
        # the edge list above, so all have degree >= 1 and this is a no-op;
        # kept as defensive code.
        ig_graph.delete_vertices(ig_graph.vs.select(_degree_eq=0))

        node2cid: Dict[str, int] = {}
        if use_lcc:
            # Cluster only the giant (largest) connected component; nodes
            # outside it get no community id.
            lcc = ig_graph.components().giant()
            partition = find_partition(lcc, ModularityVertexPartition, seed=random_seed)
            for part_id, cluster in enumerate(partition):
                for v in cluster:
                    node2cid[lcc.vs[v]["name"]] = part_id
        else:
            # Cluster each connected component separately; offset community
            # ids so they remain globally unique across components.
            offset = 0
            for component in ig_graph.components():
                subgraph = ig_graph.induced_subgraph(component)
                partition = find_partition(
                    subgraph, ModularityVertexPartition, seed=random_seed
                )
                for part_id, cluster in enumerate(partition):
                    for v in cluster:
                        original_node = subgraph.vs[v]["name"]
                        node2cid[original_node] = part_id + offset
                offset += len(partition)
        return node2cid

    @staticmethod
    async def _split_communities(
        node2cid: Dict[str, int], max_size: int
    ) -> Dict[str, int]:
        """
        Split communities larger than max_size into smaller sub-communities.

        All communities are renumbered consecutively from 0; oversized
        communities are chunked into consecutive slices of at most
        ``max_size`` nodes each.

        :param node2cid: mapping from node name to community id
        :param max_size: maximum number of nodes allowed per community (> 0)
        :return: new mapping from node name to renumbered community id
        """
        # Invert the mapping: community id -> member nodes.
        cid2nodes: Dict[int, List[str]] = defaultdict(list)
        for n, cid in node2cid.items():
            cid2nodes[cid].append(n)

        new_mapping: Dict[str, int] = {}
        new_cid = 0
        for members in cid2nodes.values():
            if len(members) <= max_size:
                # Small enough: keep the community intact under a new id.
                for n in members:
                    new_mapping[n] = new_cid
                new_cid += 1
            else:
                # Chunk the oversized community into max_size-node slices.
                for start in range(0, len(members), max_size):
                    for n in members[start : start + max_size]:
                        new_mapping[n] = new_cid
                    new_cid += 1
        return new_mapping
0 commit comments