Skip to content

Commit 61794a4

Browse files
committed
Added parametric types for protocols
1 parent 5ba1c50 commit 61794a4

19 files changed

Lines changed: 269 additions & 241 deletions

File tree

src/tdamapper/_common.py

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,37 +8,14 @@
88
import io
99
import pstats
1010
import warnings
11-
from typing import Any, Callable, Iterator, Protocol, TypeVar
11+
from typing import Any, Callable
1212

1313
import numpy as np
1414
from numpy.typing import NDArray
1515

16-
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
17-
18-
T = TypeVar("T")
19-
20-
21-
class Array(Protocol[T]):
22-
23-
def __getitem__(self, index: int) -> T:
24-
"""
25-
Get an item from the array.
26-
"""
27-
28-
def __len__(self) -> int:
29-
"""
30-
Get the length of the array.
31-
"""
32-
33-
def __setitem__(self, index: int, value: T) -> None:
34-
"""
35-
Set an item in the array.
36-
"""
16+
from tdamapper.protocols import Array
3717

38-
def __iter__(self) -> Iterator[T]:
39-
"""
40-
Iterate over the array.
41-
"""
18+
warnings.filterwarnings("default", category=DeprecationWarning, module=r"^tdamapper\.")
4219

4320

4421
def deprecated(msg: str) -> Callable[..., Any]:

src/tdamapper/app.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from sklearn.preprocessing import StandardScaler
1616
from umap import UMAP
1717

18-
from tdamapper.core import Cover, TrivialClustering
18+
from tdamapper.core import TrivialClustering
1919
from tdamapper.cover import BallCover, CubicalCover, KNNCover
2020
from tdamapper.learn import MapperAlgorithm
2121
from tdamapper.plot import MapperPlot
22+
from tdamapper.protocols import Clustering, Cover
2223

2324
logging.basicConfig(level=logging.INFO)
2425
logger = logging.getLogger(__name__)
@@ -164,7 +165,7 @@ def run_mapper(
164165
elif lens_type == LENS_UMAP:
165166
lens = lens_umap(n_components=lens_umap_n_components)
166167

167-
cover: Cover
168+
cover: Cover[NDArray[np.float_]]
168169
if cover_type == COVER_CUBICAL:
169170
cover = CubicalCover(
170171
n_intervals=cover_cubical_n_intervals,
@@ -178,6 +179,7 @@ def run_mapper(
178179
logger.error(f"Unknown cover type: {cover_type}")
179180
return None
180181

182+
clustering: Clustering[NDArray[np.float_]]
181183
if clustering_type == CLUSTERING_TRIVIAL:
182184
clustering = TrivialClustering()
183185
elif clustering_type == CLUSTERING_KMEANS:

src/tdamapper/core.py

Lines changed: 33 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,13 @@
3131
from __future__ import annotations
3232

3333
import logging
34-
from typing import Any, Callable, Iterator, Optional, Protocol
34+
from typing import Any, Callable, Generic, Iterator, Optional, TypeVar
3535

3636
import networkx as nx
3737
from joblib import Parallel, delayed
3838

39-
from tdamapper._common import Array, ParamsMixin, clone
39+
from tdamapper._common import ParamsMixin, clone
40+
from tdamapper.protocols import Array, Clustering, Cover, SpatialSearch
4041
from tdamapper.utils.unionfind import UnionFind
4142

4243
ATTR_IDS = "ids"
@@ -53,9 +54,17 @@
5354
handlers=[logging.StreamHandler()],
5455
)
5556

57+
S = TypeVar("S")
58+
59+
T = TypeVar("T")
60+
5661

5762
def mapper_labels(
58-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
63+
X: Array[S],
64+
y: Array[T],
65+
cover: Cover[T],
66+
clustering: Clustering[S],
67+
n_jobs: int = 1,
5968
) -> list[list[int]]:
6069
"""
6170
Identify the nodes of the Mapper graph.
@@ -85,7 +94,7 @@ def mapper_labels(
8594
"""
8695

8796
def _run_clustering(
88-
local_ids: list[int], X_local: Array[Any], clust: Clustering
97+
local_ids: list[int], X_local: Array[S], clust: Clustering[S]
8998
) -> tuple[list[int], list[int]]:
9099
local_lbls = clust.fit(X_local).labels_
91100
return local_ids, local_lbls
@@ -110,7 +119,11 @@ def _run_clustering(
110119

111120

112121
def mapper_connected_components(
113-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
122+
X: Array[S],
123+
y: Array[T],
124+
cover: Cover[T],
125+
clustering: Clustering[S],
126+
n_jobs: int = 1,
114127
) -> list[int]:
115128
"""
116129
Identify the connected components of the Mapper graph.
@@ -155,7 +168,11 @@ def mapper_connected_components(
155168

156169

157170
def mapper_graph(
158-
X: Array[Any], y: Array[Any], cover: Cover, clustering: Clustering, n_jobs: int = 1
171+
X: Array[S],
172+
y: Array[T],
173+
cover: Cover[T],
174+
clustering: Clustering[S],
175+
n_jobs: int = 1,
159176
) -> nx.Graph:
160177
"""
161178
Create the Mapper graph.
@@ -201,7 +218,7 @@ def mapper_graph(
201218

202219

203220
def aggregate_graph(
204-
X: Array[Any], graph: nx.Graph, agg: Callable[..., Any]
221+
X: Array[S], graph: nx.Graph, agg: Callable[..., Any]
205222
) -> dict[int, Any]:
206223
"""
207224
Apply an aggregation function to the nodes of a graph.
@@ -229,81 +246,7 @@ def aggregate_graph(
229246
return agg_values
230247

231248

232-
class Cover(Protocol):
233-
"""
234-
Abstract interface for cover algorithms.
235-
236-
This is a naive implementation. Subclasses should override the methods of
237-
this class to implement more meaningful cover algorithms.
238-
"""
239-
240-
def apply(self, X: Array[Any]) -> Iterator[list[int]]:
241-
"""
242-
Covers the dataset with a single open set.
243-
244-
This is a naive implementation that returns a generator producing a
245-
single list containing all the ids of the original dataset. This
246-
method should be overridden by subclasses to implement more meaningful
247-
cover algorithms.
248-
249-
:param X: A dataset of n points.
250-
:return: A generator of lists of ids.
251-
"""
252-
253-
254-
class Clustering(Protocol):
255-
"""
256-
Abstract interface for clustering algorithms.
257-
258-
A clustering algorithm is a method for grouping data points into clusters.
259-
Each cluster is represented by a unique integer label, and the labels are
260-
assigned to the points in the dataset. The labels are typically non-negative
261-
integers, starting from zero. The labels are assigned such that the points
262-
in the same cluster have the same label, and the points in different clusters
263-
have different labels. The labels are not necessarily contiguous, and there
264-
may be gaps in the sequence of labels.
265-
"""
266-
267-
labels_: list[int]
268-
269-
def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> Clustering:
270-
"""
271-
Fit the clustering algorithm to the data.
272-
273-
:param X: A dataset of n points.
274-
:param y: A dataset of targets. Typically ignored and present for
275-
compatibility with scikit-learn's clustering interface.
276-
:return: The fitted clustering object.
277-
"""
278-
279-
280-
class SpatialSearch(Protocol):
281-
"""
282-
Abstract interface for search algorithms.
283-
284-
A spatial search algorithm is a method for finding neighbors of a
285-
query point in a dataset.
286-
"""
287-
288-
def fit(self, X: Array[Any]) -> SpatialSearch:
289-
"""
290-
Train internal parameters.
291-
292-
:param X: A dataset of n points.
293-
:return: The object itself.
294-
"""
295-
296-
def search(self, x: Any) -> list[int]:
297-
"""
298-
Return a list of neighbors for the query point.
299-
300-
:param x: A query point for which we want to find neighbors.
301-
:return: A list containing the indices of the points in the dataset
302-
that are neighbors of the query point.
303-
"""
304-
305-
306-
def proximity_net(search: SpatialSearch, X: Array[Any]) -> Iterator[list[int]]:
249+
def proximity_net(search: SpatialSearch[S], X: Array[S]) -> Iterator[list[int]]:
307250
"""
308251
Covers the dataset using proximity-net.
309252
@@ -331,7 +274,7 @@ def proximity_net(search: SpatialSearch, X: Array[Any]) -> Iterator[list[int]]:
331274
yield neigh_ids
332275

333276

334-
class TrivialCover(ParamsMixin):
277+
class TrivialCover(ParamsMixin, Generic[T]):
335278
"""
336279
Cover algorithm that covers data with a single subset containing the whole
337280
dataset.
@@ -340,7 +283,7 @@ class TrivialCover(ParamsMixin):
340283
dataset.
341284
"""
342285

343-
def apply(self, X: Array[Any]) -> Iterator[list[int]]:
286+
def apply(self, X: Array[T]) -> Iterator[list[int]]:
344287
"""
345288
Covers the dataset with a single open set.
346289
@@ -350,7 +293,7 @@ def apply(self, X: Array[Any]) -> Iterator[list[int]]:
350293
yield list(range(0, len(X)))
351294

352295

353-
class FailSafeClustering(ParamsMixin):
296+
class FailSafeClustering(ParamsMixin, Generic[T]):
354297
"""
355298
A delegating clustering algorithm that prevents failure.
356299
@@ -364,17 +307,17 @@ class FailSafeClustering(ParamsMixin):
364307
enable logging, or False to suppress it. Defaults to True.
365308
"""
366309

367-
_clustering: Optional[Clustering]
310+
_clustering: Optional[Clustering[T]]
368311
_verbose: bool
369312
labels_: list[int]
370313

371314
def __init__(
372-
self, clustering: Optional[Clustering] = None, verbose: bool = True
315+
self, clustering: Optional[Clustering[T]] = None, verbose: bool = True
373316
) -> None:
374317
self.clustering = clustering
375318
self.verbose = verbose
376319

377-
def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> FailSafeClustering:
320+
def fit(self, X: Array[T], y: Optional[Array[T]] = None) -> FailSafeClustering[T]:
378321
self._clustering = (
379322
TrivialClustering() if self.clustering is None else self.clustering
380323
)
@@ -389,7 +332,7 @@ def fit(self, X: Array[Any], y: Optional[Array[Any]] = None) -> FailSafeClusteri
389332
return self
390333

391334

392-
class TrivialClustering(ParamsMixin):
335+
class TrivialClustering(ParamsMixin, Generic[T]):
393336
"""
394337
A clustering algorithm that returns a single cluster.
395338
@@ -404,7 +347,7 @@ class TrivialClustering(ParamsMixin):
404347
def __init__(self) -> None:
405348
pass
406349

407-
def fit(self, X: Array[Any], _y: Optional[Array[Any]] = None) -> TrivialClustering:
350+
def fit(self, X: Array[T], _y: Optional[Array[T]] = None) -> TrivialClustering[T]:
408351
"""
409352
Fit the clustering algorithm to the data.
410353

0 commit comments

Comments
 (0)