Skip to content

Commit fafbb25

Browse files
committed
Added types and docs
1 parent b9f8eee commit fafbb25

8 files changed

Lines changed: 269 additions & 61 deletions

File tree

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import Any, Protocol
2+
3+
import networkx as nx
4+
import numpy as np
5+
from numpy.typing import NDArray
6+
7+
8+
class MapperPlotType(Protocol):
9+
10+
dim: int
11+
graph: nx.Graph
12+
positions: dict[Any, NDArray[np.float64]]

src/tdamapper/plot_backends/plot_matplotlib.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numpy.typing import NDArray
1414

1515
from tdamapper.core import ATTR_SIZE, aggregate_graph
16+
from tdamapper.plot_backends.plot_common import MapperPlotType
1617

1718
_NODE_OUTER_WIDTH = 0.75
1819

@@ -24,7 +25,7 @@
2425

2526

2627
def plot_matplotlib(
27-
mapper_plot,
28+
mapper_plot: MapperPlotType,
2829
width: int,
2930
height: int,
3031
title: Optional[str],
@@ -56,7 +57,7 @@ def plot_matplotlib(
5657

5758

5859
def _plot_nodes(
59-
mapper_plot,
60+
mapper_plot: MapperPlotType,
6061
title: Optional[str],
6162
colors: NDArray[np.float64],
6263
node_size: float,
@@ -107,7 +108,7 @@ def _plot_nodes(
107108
colorbar.ax.locator_params(nbins=10)
108109

109110

110-
def _plot_edges(mapper_plot, ax: plt.Axes) -> None:
111+
def _plot_edges(mapper_plot: MapperPlotType, ax: plt.Axes) -> None:
111112
segments = [
112113
(mapper_plot.positions[e[0]], mapper_plot.positions[e[1]])
113114
for e in mapper_plot.graph.edges()
@@ -124,6 +125,6 @@ def _plot_edges(mapper_plot, ax: plt.Axes) -> None:
124125

125126

126127
def _node_pos_array(
127-
graph: nx.Graph, dim: int, positions: dict[Any, NDArray]
128-
) -> tuple[list[NDArray], ...]:
128+
graph: nx.Graph, dim: int, positions: dict[Any, NDArray[np.float64]]
129+
) -> tuple[list[NDArray[np.float64]], ...]:
129130
return tuple([positions[n][i] for n in graph.nodes()] for i in range(dim))

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from numpy.typing import NDArray
1515

1616
from tdamapper.core import ATTR_SIZE, aggregate_graph
17+
from tdamapper.plot_backends.plot_common import MapperPlotType
1718

1819
_NODE_OUTER_WIDTH = 0.75
1920

@@ -145,7 +146,7 @@ def _get_cmap_rgb(cmap: str):
145146

146147

147148
def plot_plotly(
148-
mapper_plot,
149+
mapper_plot: MapperPlotType,
149150
colors: Union[NDArray[np.float64], list[float]],
150151
node_size: Optional[Union[int, float, list[Union[int, float]]]] = None,
151152
title: Optional[Union[str, list[str]]] = None,
@@ -187,7 +188,7 @@ def plot_plotly(
187188

188189

189190
def plot_plotly_update(
190-
mapper_plot,
191+
mapper_plot: MapperPlotType,
191192
fig: go.Figure,
192193
width: Optional[int] = None,
193194
height: Optional[int] = None,
@@ -257,7 +258,7 @@ class PlotlyPlot:
257258
:param fig: Optional existing Plotly figure to update.
258259
"""
259260

260-
def __init__(self, mapper_plot, fig: Optional[go.Figure] = None):
261+
def __init__(self, mapper_plot: MapperPlotType, fig: Optional[go.Figure] = None):
261262
self.mapper_plot = mapper_plot
262263
self.fig = fig
263264
self.graph = mapper_plot.graph
@@ -370,10 +371,12 @@ def _edge_colors_from_node_colors(
370371
edge_col.append(c1)
371372
return edge_col
372373

373-
def _set_colors(self, colors, agg):
374+
def _set_colors(self, colors: NDArray[np.float64], agg: Callable) -> None:
374375
node_col_agg = aggregate_graph(colors, self.graph, agg)
375376
node_col_arr = list(node_col_agg.values())
376377
scatter_text = self._text(node_col_agg)
378+
if self.fig is None:
379+
return
377380
self.fig.update_traces(
378381
patch=dict(
379382
text=scatter_text,
@@ -533,7 +536,9 @@ def update_figure(
533536
if cmaps is not None:
534537
self.set_cmap(cmaps[0])
535538

536-
def _nodes_trace(self, node_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
539+
def _nodes_trace(
540+
self, node_pos_arr: tuple[list[float], ...]
541+
) -> Union[go.Scatter, go.Scatter3d]:
537542
scatter = dict(
538543
name=_NODES_TRACE,
539544
x=node_pos_arr[0],
@@ -559,7 +564,9 @@ def _nodes_trace(self, node_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
559564
else:
560565
return go.Scatter(scatter)
561566

562-
def _edges_trace(self, edge_pos_arr) -> Union[go.Scatter, go.Scatter3d]:
567+
def _edges_trace(
568+
self, edge_pos_arr: tuple[list[Optional[float]], ...]
569+
) -> Union[go.Scatter, go.Scatter3d]:
563570
scatter = dict(
564571
name=_EDGES_TRACE,
565572
x=edge_pos_arr[0],
@@ -722,7 +729,7 @@ def set_ui(
722729
sliders=sliders,
723730
)
724731

725-
def ui_menu_dark_mode(self) -> dict:
732+
def ui_menu_dark_mode(self) -> dict[str, Any]:
726733
"""
727734
Create a dropdown menu for toggling dark mode in the Plotly figure.
728735
@@ -804,7 +811,7 @@ def _update_cmap(cmap: str) -> dict:
804811

805812
def _ui_menu_color(
806813
self, colors: NDArray[np.float64], titles: list[str], agg: Callable
807-
) -> dict:
814+
) -> dict[str, Any]:
808815
colors_arr = np.array(colors)
809816
colors_num = colors_arr.shape[1] if colors_arr.ndim == 2 else 1
810817

src/tdamapper/plot_backends/plot_pyvis.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from pyvis.network import Network
1515

1616
from tdamapper.core import aggregate_graph
17+
from tdamapper.plot_backends.plot_common import MapperPlotType
1718

1819
_EDGE_WIDTH = 0.75
1920

@@ -138,7 +139,7 @@ def _combine(network: Network, colorbar: go.Figure) -> str:
138139

139140

140141
def plot_pyvis(
141-
mapper_plot,
142+
mapper_plot: MapperPlotType,
142143
output_file: str,
143144
colors: NDArray[np.float64],
144145
node_size: float,
@@ -177,7 +178,7 @@ def plot_pyvis(
177178

178179

179180
def _compute_net(
180-
mapper_plot,
181+
mapper_plot: MapperPlotType,
181182
colors: NDArray[np.float64],
182183
node_size: float,
183184
agg: Callable,

src/tdamapper/utils/heap.py

Lines changed: 120 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,73 +1,168 @@
1-
def _left(i):
1+
"""
2+
This module implements a max-heap data structure. It provides methods to add
3+
elements, pop the maximum element, and retrieve the top element without
4+
removing it. It is designed to be used in scenarios where you need to maintain
5+
a collection of elements and frequently access the maximum element
6+
efficiently.
7+
"""
8+
9+
from __future__ import annotations
10+
11+
from typing import Any, Generic, Iterator, Optional, Protocol, TypeVar
12+
13+
14+
class Comparable(Protocol):
15+
"""
16+
A protocol that defines the methods required for an object to be
17+
orderable. This is used to ensure that the keys in the heap can be
18+
compared with each other.
19+
"""
20+
21+
def __lt__(self: K, other: K) -> bool: ...
22+
23+
def __le__(self: K, other: K) -> bool: ...
24+
25+
def __gt__(self: K, other: K) -> bool: ...
26+
27+
def __ge__(self: K, other: K) -> bool: ...
28+
29+
30+
K = TypeVar("K", bound=Comparable)
31+
32+
V = TypeVar("V")
33+
34+
35+
def _left(i: int) -> int:
36+
"""
37+
Returns the index of the left child of the node at index i in a binary heap.
38+
39+
:param i: The index of the parent node.
40+
:return: The index of the left child.
41+
"""
242
return 2 * i + 1
343

444

5-
def _right(i):
45+
def _right(i: int) -> int:
46+
"""
47+
Returns the index of the right child of the node at index i in a binary heap.
48+
49+
:param i: The index of the parent node.
50+
:return: The index of the right child.
51+
"""
652
return 2 * i + 2
753

854

9-
def _parent(i):
55+
def _parent(i: int) -> int:
56+
"""
57+
Returns the index of the parent node of the node at index i in a binary heap.
58+
59+
:param i: The index of the child node.
60+
:return: The index of the parent node.
61+
"""
1062
return max(0, (i - 1) // 2)
1163

1264

13-
class _HeapNode:
65+
class _HeapNode(Generic[K, V]):
66+
"""
67+
A private class representing a node in the max-heap. Each node contains a
68+
key and a value. The key is used to determine the order of the nodes in
69+
the heap, with larger keys being prioritized over smaller keys. The value
70+
can be any associated data that you want to store with the key.
71+
72+
:param key: The key of the node, used for ordering in the heap.
73+
:param value: The value associated with the key.
74+
"""
1475

15-
def __init__(self, key, value):
76+
def __init__(self, key: K, value: V) -> None:
1677
self._key = key
1778
self._value = value
1879

19-
def get(self):
80+
def get(self) -> tuple[K, V]:
81+
"""
82+
Returns the key and value of the node as a tuple.
83+
84+
:return: A tuple containing the key and value of the node.
85+
"""
2086
return self._key, self._value
2187

22-
def __lt__(self, other):
88+
def __lt__(self, other: _HeapNode[K, Any]) -> bool:
2389
return self._key < other._key
2490

25-
def __le__(self, other):
91+
def __le__(self, other: _HeapNode[K, Any]) -> bool:
2692
return self._key <= other._key
2793

28-
def __gt__(self, other):
94+
def __gt__(self, other: _HeapNode[K, Any]) -> bool:
2995
return self._key > other._key
3096

31-
def __ge__(self, other):
97+
def __ge__(self, other: _HeapNode[K, Any]) -> bool:
3298
return self._key >= other._key
3399

34100

35-
class MaxHeap:
101+
class MaxHeap(Generic[K, V]):
102+
"""
103+
A max-heap implementation that allows for efficient retrieval and removal
104+
of the maximum element. The heap is implemented as a list of _HeapNode
105+
objects, where each node contains a key and a value. The key is used to
106+
determine the order of the elements in the heap, with larger keys being
107+
prioritized over smaller keys.
108+
"""
36109

37-
def __init__(self):
110+
_heap: list[_HeapNode[K, V]]
111+
_iter: Iterator[_HeapNode[K, V]]
112+
113+
def __init__(self) -> None:
38114
self._heap = []
39-
self._iter = None
40115

41-
def __iter__(self):
116+
def __iter__(self) -> MaxHeap[K, V]:
42117
self._iter = iter(self._heap)
43118
return self
44119

45-
def __next__(self):
120+
def __next__(self) -> tuple[K, V]:
46121
node = next(self._iter)
47122
return node.get()
48123

49-
def __len__(self):
124+
def __len__(self) -> int:
50125
return len(self._heap)
51126

52-
def top(self):
127+
def top(self) -> tuple[Optional[K], Optional[V]]:
128+
"""
129+
Returns the maximum element of the heap without removing it.
130+
131+
:return: A tuple containing the key and value of the maximum element,
132+
or (None, None) if the heap is empty.
133+
"""
53134
if not self._heap:
54135
return (None, None)
55136
return self._heap[0].get()
56137

57-
def pop(self):
138+
def pop(self) -> tuple[Optional[K], Optional[V]]:
139+
"""
140+
Removes and returns the maximum element from the heap.
141+
142+
:return: A tuple containing the key and value of the maximum element,
143+
or (None, None) if the heap is empty.
144+
"""
58145
if not self._heap:
59-
return
146+
return (None, None)
60147
max_val = self._heap[0]
61148
self._heap[0] = self._heap[-1]
62149
self._heap.pop()
63150
self._bubble_down()
64151
return max_val.get()
65152

66-
def add(self, key, val):
153+
def add(self, key: K, val: V) -> None:
154+
"""
155+
Adds a new element to the heap.
156+
157+
:param key: The key of the element to be added, which determines its
158+
position in the heap.
159+
:param val: The value associated with the key.
160+
:return: None
161+
"""
67162
self._heap.append(_HeapNode(key, val))
68163
self._bubble_up()
69164

70-
def _get_local_max(self, i):
165+
def _get_local_max(self, i: int) -> int:
71166
heap_len = len(self._heap)
72167
left = _left(i)
73168
right = _right(i)
@@ -84,7 +179,7 @@ def _get_local_max(self, i):
84179
return max_child
85180
return i
86181

87-
def _fix_down(self, i):
182+
def _fix_down(self, i: int) -> int:
88183
local_max = self._get_local_max(i)
89184
if i < local_max:
90185
self._heap[i], self._heap[local_max] = (
@@ -94,22 +189,22 @@ def _fix_down(self, i):
94189
return local_max
95190
return i
96191

97-
def _fix_up(self, i):
192+
def _fix_up(self, i: int) -> int:
98193
parent = _parent(i)
99194
if self._heap[parent] < self._heap[i]:
100195
self._heap[i], self._heap[parent] = self._heap[parent], self._heap[i]
101196
return parent
102197
return i
103198

104-
def _bubble_down(self):
199+
def _bubble_down(self) -> None:
105200
current = 0
106201
done = False
107202
while not done:
108203
local_max = self._fix_down(current)
109204
done = current == local_max
110205
current = local_max
111206

112-
def _bubble_up(self):
207+
def _bubble_up(self) -> None:
113208
current = len(self._heap) - 1
114209
done = False
115210
while not done:

0 commit comments

Comments
 (0)