Skip to content

Commit c0c9226

Browse files
committed
Improved type-hints
1 parent 7fc2eeb commit c0c9226

10 files changed

Lines changed: 87 additions & 58 deletions

File tree

Makefile

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,10 @@ all: install
88
install:
99
$(PIP) install -e .[dev]
1010

11+
.PHONY: typecheck
12+
typecheck:
13+
mypy src --ignore-missing-imports --install-types
14+
1115
.PHONY: test
1216
test:
1317
coverage run --source=src -m pytest tests/test_unit_*.py

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@ dev = [
5555
"nbformat>=4.2.0",
5656
"umap-learn<0.6.0",
5757
"nicegui>=2.18.0,<3.0.0",
58+
"mypy",
5859
]
5960
app = [
6061
"pandas<3.0.0",
@@ -113,3 +114,6 @@ filterwarnings = [
113114
markers = [
114115
"module_under_test",
115116
]
117+
118+
[tool.mypy]
119+
ignore_missing_imports = true

src/tdamapper/app.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -176,7 +176,7 @@ def _umap(X: NDArray[np.float64]) -> NDArray[np.float64]:
176176

177177
def run_mapper(
178178
df: pd.DataFrame, **kwargs: dict[str, Any]
179-
) -> Optional[tuple[MapperAlgorithm, pd.DataFrame]]:
179+
) -> Optional[tuple[nx.Graph, pd.DataFrame]]:
180180
"""
181181
Run the Mapper algorithm on the provided DataFrame.
182182

src/tdamapper/core.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -213,14 +213,13 @@ def mapper_graph(
213213
:rtype: :class:`networkx.Graph`
214214
"""
215215
itm_lbls = mapper_labels(X, y, cover, clustering, n_jobs=n_jobs)
216-
graph = nx.Graph()
216+
graph: nx.Graph = nx.Graph()
217217
for n, lbls in enumerate(itm_lbls):
218218
for lbl in lbls:
219219
if not graph.has_node(lbl):
220220
graph.add_node(lbl, **{ATTR_SIZE: 0, ATTR_IDS: []})
221-
nodes = graph.nodes()
222-
nodes[lbl][ATTR_SIZE] += 1
223-
nodes[lbl][ATTR_IDS].append(n)
221+
graph.nodes[lbl][ATTR_SIZE] += 1
222+
graph.nodes[lbl][ATTR_IDS].append(n)
224223
for lbls in itm_lbls:
225224
lbls_len = len(lbls)
226225
for i in range(lbls_len):
@@ -256,9 +255,8 @@ def aggregate_graph(X: ArrayLike, graph: nx.Graph, agg: Callable) -> dict[Any, A
256255
:rtype: dict
257256
"""
258257
agg_values = {}
259-
nodes = graph.nodes()
260-
for node_id in nodes:
261-
node_values = [X[i] for i in nodes[node_id][ATTR_IDS]]
258+
for node_id in graph.nodes:
259+
node_values = [X[i] for i in graph.nodes[node_id][ATTR_IDS]]
262260
agg_value = agg(node_values)
263261
agg_values[node_id] = agg_value
264262
return agg_values

src/tdamapper/cover.py

Lines changed: 22 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -261,9 +261,9 @@ class BaseCubicalCover:
261261

262262
_n_intervals: int
263263
_overlap_frac: float
264-
_min: NDArray
265-
_max: NDArray
266-
_delta: NDArray
264+
_min: NDArray[np.float64]
265+
_max: NDArray[np.float64]
266+
_delta: NDArray[np.float64]
267267
_cover: BallCover
268268

269269
def __init__(
@@ -282,7 +282,7 @@ def __init__(
282282
self.leaf_radius = leaf_radius
283283
self.pivoting = pivoting
284284

285-
def _get_center(self, x: NDArray) -> tuple[tuple, NDArray]:
285+
def _get_center(self, x: NDArray[np.float64]) -> tuple[tuple, NDArray]:
286286
offset = self._offset(x)
287287
center = self._phi(x)
288288
return tuple(offset), center
@@ -291,20 +291,22 @@ def _get_overlap_frac(self, dim: int, overlap_vol_frac: float) -> float:
291291
beta = math.pow(1.0 - overlap_vol_frac, 1.0 / dim)
292292
return 1.0 - 1.0 / (2.0 - beta)
293293

294-
def _offset(self, x: NDArray) -> NDArray:
294+
def _offset(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
295295
return np.minimum(self._n_intervals - 1, np.floor(self._gamma_n(x)))
296296

297-
def _phi(self, x: NDArray) -> NDArray:
297+
def _phi(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
298298
offset = self._offset(x)
299299
return self._gamma_n_inv(0.5 + offset)
300300

301-
def _gamma_n(self, x: NDArray) -> NDArray:
301+
def _gamma_n(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
302302
return self._n_intervals * (x - self._min) / self._delta
303303

304-
def _gamma_n_inv(self, x: NDArray) -> NDArray:
304+
def _gamma_n_inv(self, x: NDArray[np.float64]) -> NDArray[np.float64]:
305305
return self._min + self._delta * x / self._n_intervals
306306

307-
def _get_bounds(self, X: ArrayLike) -> tuple[NDArray, NDArray, NDArray]:
307+
def _get_bounds(
308+
self, X: ArrayLike
309+
) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
308310
if (X is None) or len(X) == 0:
309311
raise ValueError("The dataset is empty or None.")
310312
_min, _max = X[0], X[0]
@@ -408,12 +410,12 @@ class ProximityCubicalCover(BaseCubicalCover, ParamsMixin, Proximity):
408410

409411
def __init__(
410412
self,
411-
n_intervals=1,
412-
overlap_frac=None,
413-
kind="flat",
414-
leaf_capacity=1,
415-
leaf_radius=None,
416-
pivoting=None,
413+
n_intervals: int = 1,
414+
overlap_frac: Optional[float] = None,
415+
kind: str = "flat",
416+
leaf_capacity: int = 1,
417+
leaf_radius: Optional[float] = None,
418+
pivoting: Optional[str] = None,
417419
):
418420
super().__init__(
419421
n_intervals=n_intervals,
@@ -479,15 +481,15 @@ def __init__(
479481
pivoting=pivoting,
480482
)
481483

482-
def _landmarks(self, X: ArrayLike) -> dict[tuple, NDArray]:
484+
def _landmarks(self, X: NDArray[np.float64]) -> dict[tuple, NDArray[np.float64]]:
483485
lmrks = {}
484486
for x in X:
485487
lmrk, _ = self._get_center(x)
486488
if lmrk not in lmrks:
487489
lmrks[lmrk] = x
488490
return lmrks
489491

490-
def apply(self, X: ArrayLike) -> Generator[list[int]]:
492+
def apply(self, X: NDArray[np.float64]) -> Generator[list[int]]:
491493
"""
492494
Covers the dataset using landmarks.
493495
@@ -575,7 +577,7 @@ def __init__(
575577
self.pivoting = pivoting
576578

577579
def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCover]:
578-
params = dict(
580+
params: dict[str, Any] = dict(
579581
n_intervals=self.n_intervals,
580582
overlap_frac=self.overlap_frac,
581583
kind=self.kind,
@@ -593,7 +595,7 @@ def _get_cubical_cover(self) -> Union[ProximityCubicalCover, StandardCubicalCove
593595
"'proximity'."
594596
)
595597

596-
def fit(self, X: ArrayLike) -> CubicalCover:
598+
def fit(self, X: NDArray[np.float64]) -> CubicalCover:
597599
"""
598600
Train internal parameters.
599601
@@ -623,7 +625,7 @@ def search(self, x: Any) -> list[int]:
623625
"""
624626
return self._cubical_cover.search(x)
625627

626-
def apply(self, X: ArrayLike) -> Generator[list[int]]:
628+
def apply(self, X: NDArray[np.float64]) -> Generator[list[int]]:
627629
"""
628630
Covers the dataset using hypercubes.
629631

src/tdamapper/learn.py

Lines changed: 31 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -9,8 +9,16 @@
99
scikit-learn's conventions for estimators.
1010
"""
1111

12-
from tdamapper._common import EstimatorMixin, ParamsMixin, clone
12+
from __future__ import annotations
13+
14+
from typing import Optional
15+
16+
import networkx as nx
17+
18+
from tdamapper._common import ArrayLike, EstimatorMixin, ParamsMixin, clone
1319
from tdamapper.core import (
20+
Clustering,
21+
Cover,
1422
FailSafeClustering,
1523
TrivialClustering,
1624
TrivialCover,
@@ -45,17 +53,19 @@ class MapperClustering(EstimatorMixin, ParamsMixin):
4553
:type n_jobs: int
4654
"""
4755

56+
labels_: list[int]
57+
4858
def __init__(
4959
self,
50-
cover=None,
51-
clustering=None,
52-
n_jobs=1,
60+
cover: Optional[Cover] = None,
61+
clustering: Optional[Clustering] = None,
62+
n_jobs: int = 1,
5363
):
5464
self.cover = cover
5565
self.clustering = clustering
5666
self.n_jobs = n_jobs
5767

58-
def fit(self, X, y=None):
68+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> MapperClustering:
5969
"""
6070
Fit the clustering algorithm to the data.
6171
@@ -123,21 +133,28 @@ class MapperAlgorithm(EstimatorMixin, ParamsMixin):
123133
:type n_jobs: int
124134
"""
125135

136+
_cover: Cover
137+
_clustering: Clustering
138+
_verbose: bool
139+
_failsafe: bool
140+
_n_jobs: int
141+
graph_: nx.Graph
142+
126143
def __init__(
127144
self,
128-
cover=None,
129-
clustering=None,
130-
failsafe=True,
131-
verbose=True,
132-
n_jobs=1,
145+
cover: Optional[Cover] = None,
146+
clustering: Optional[Clustering] = None,
147+
failsafe: bool = True,
148+
verbose: bool = True,
149+
n_jobs: int = 1,
133150
):
134151
self.cover = cover
135152
self.clustering = clustering
136153
self.failsafe = failsafe
137154
self.verbose = verbose
138155
self.n_jobs = n_jobs
139156

140-
def fit(self, X, y=None):
157+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> MapperAlgorithm:
141158
"""
142159
Create the Mapper graph and store it for later use.
143160
@@ -150,6 +167,8 @@ def fit(self, X, y=None):
150167
:type y: array-like of shape (n, k) or list-like of length n
151168
:return: The object itself.
152169
"""
170+
if y is None:
171+
y = X
153172
X, y = self._validate_X_y(X, y)
154173
self._cover = TrivialCover() if self.cover is None else self.cover
155174
self._clustering = (
@@ -176,7 +195,7 @@ def fit(self, X, y=None):
176195
self._set_n_features_in(X)
177196
return self
178197

179-
def fit_transform(self, X, y):
198+
def fit_transform(self, X: ArrayLike, y: ArrayLike) -> nx.Graph:
180199
"""
181200
Create the Mapper graph.
182201

src/tdamapper/plot.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@
77
import igraph as ig
88
import networkx as nx
99
import numpy as np
10+
from numpy.typing import NDArray
1011
from plotly import graph_objects as go
1112

12-
from tdamapper._common import ArrayLike
1313
from tdamapper.plot_backends.plot_matplotlib import plot_matplotlib
1414
from tdamapper.plot_backends.plot_plotly import plot_plotly, plot_plotly_update
1515
from tdamapper.plot_backends.plot_pyvis import plot_pyvis
@@ -155,7 +155,7 @@ def plot_matplotlib(
155155

156156
def plot_plotly(
157157
self,
158-
colors: Union[np.ndarray, list[float]],
158+
colors: Union[NDArray[np.float64], list[float]],
159159
node_size: Union[int, float, list[Union[int, float]]] = 1,
160160
agg: Callable = np.nanmean,
161161
title: Optional[Union[str, list[str]]] = None,
@@ -210,7 +210,7 @@ def plot_plotly(
210210
def plot_plotly_update(
211211
self,
212212
fig: go.Figure,
213-
colors: Optional[Union[np.ndarray, list[float]]] = None,
213+
colors: Optional[Union[NDArray[np.float64], list[float]]] = None,
214214
node_size: Optional[Union[int, float, list[Union[int, float]]]] = None,
215215
agg: Optional[Callable] = None,
216216
title: Optional[Union[str, list[str]]] = None,
@@ -271,7 +271,7 @@ def plot_plotly_update(
271271
def plot_pyvis(
272272
self,
273273
output_file: str,
274-
colors: ArrayLike,
274+
colors: NDArray[np.float64],
275275
node_size: int = 1,
276276
agg: Callable = np.nanmean,
277277
title: Optional[str] = None,

src/tdamapper/plot_backends/plot_matplotlib.py

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -4,10 +4,11 @@
44
"""
55

66
import math
7-
from typing import Any, Callable
7+
from typing import Any, Callable, Optional
88

99
import matplotlib.pyplot as plt
1010
import networkx as nx
11+
import numpy as np
1112
from matplotlib.collections import LineCollection
1213
from numpy.typing import NDArray
1314

@@ -26,8 +27,8 @@ def plot_matplotlib(
2627
mapper_plot,
2728
width: int,
2829
height: int,
29-
title: str,
30-
colors: NDArray,
30+
title: Optional[str],
31+
colors: NDArray[np.float64],
3132
node_size: float,
3233
agg: Callable,
3334
cmap: str,
@@ -56,8 +57,8 @@ def plot_matplotlib(
5657

5758
def _plot_nodes(
5859
mapper_plot,
59-
title: str,
60-
colors: NDArray,
60+
title: Optional[str],
61+
colors: NDArray[np.float64],
6162
node_size: float,
6263
agg: Callable,
6364
cmap: str,
@@ -95,9 +96,10 @@ def _plot_nodes(
9596
ax=ax,
9697
format="%.2g",
9798
)
98-
colorbar.set_label(title, color=_NODE_OUTER_COLOR)
99+
if title is not None:
100+
colorbar.set_label(title, color=_NODE_OUTER_COLOR)
99101
colorbar.set_alpha(1.0)
100-
colorbar.outline.set_color(_NODE_OUTER_COLOR)
102+
# colorbar.outline.set_color(_NODE_OUTER_COLOR)
101103
colorbar.ax.yaxis.set_tick_params(
102104
color=_NODE_OUTER_COLOR, labelcolor=_NODE_OUTER_COLOR
103105
)

src/tdamapper/plot_backends/plot_plotly.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -94,7 +94,7 @@ def _to_cmaps(cmap: Optional[Union[str, list[str]]]) -> list[str]:
9494
raise ValueError(f"Invalid cmap type: {type(cmap)}. Expected str or list[str].")
9595

9696

97-
def _to_colors(colors: Union[np.ndarray, list[float]]) -> np.ndarray:
97+
def _to_colors(colors: Union[NDArray[np.float64], list[float]]) -> np.ndarray:
9898
"""Convert colors to a numpy array."""
9999
colors_arr = np.array(colors)
100100
if colors_arr.ndim == 1:
@@ -146,7 +146,7 @@ def _get_cmap_rgb(cmap: str):
146146

147147
def plot_plotly(
148148
mapper_plot,
149-
colors: Union[np.ndarray, list[float]],
149+
colors: Union[NDArray[np.float64], list[float]],
150150
node_size: Optional[Union[int, float, list[Union[int, float]]]] = None,
151151
title: Optional[Union[str, list[str]]] = None,
152152
agg: Callable = np.nanmean,
@@ -192,7 +192,7 @@ def plot_plotly_update(
192192
width: Optional[int] = None,
193193
height: Optional[int] = None,
194194
node_size: Optional[Union[int, float, list[Union[int, float]]]] = None,
195-
colors: Optional[Union[np.ndarray, list[float]]] = None,
195+
colors: Optional[Union[NDArray[np.float64], list[float]]] = None,
196196
title: Optional[Union[str, list[str]]] = None,
197197
agg: Optional[Callable] = None,
198198
cmap: Optional[Union[str, list[str]]] = None,
@@ -266,7 +266,7 @@ def __init__(self, mapper_plot, fig: Optional[go.Figure] = None):
266266

267267
def plot(
268268
self,
269-
colors: np.ndarray,
269+
colors: NDArray[np.float64],
270270
node_sizes: list[float],
271271
titles: list[str],
272272
agg: Callable,

0 commit comments

Comments
 (0)