Skip to content

Commit 4877d91

Browse files
committed
Added first implementation of types
1 parent c8c1332 commit 4877d91

2 files changed

Lines changed: 111 additions & 122 deletions

File tree

src/tdamapper/core.py

Lines changed: 63 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,15 @@
2828
this module is a NetworkX graph object.
2929
"""
3030

31+
from __future__ import annotations
32+
3133
import logging
34+
from typing import Any, Callable, Dict, Generator, List, Optional, Protocol, Union
3235

3336
import networkx as nx
37+
import numpy as np
3438
from joblib import Parallel, delayed
39+
from numpy.typing import NDArray
3540

3641
from tdamapper._common import EstimatorMixin, ParamsMixin, clone, deprecated
3742
from tdamapper.utils.unionfind import UnionFind
@@ -50,8 +55,16 @@
5055
handlers=[logging.StreamHandler()],
5156
)
5257

58+
ArrayLike = Union[List[Any], NDArray[np.float64]]
59+
5360

54-
def mapper_labels(X, y, cover, clustering, n_jobs=1):
61+
def mapper_labels(
62+
X: ArrayLike,
63+
y: ArrayLike,
64+
cover: Cover,
65+
clustering: Clustering,
66+
n_jobs: int = 1,
67+
) -> List[List[int]]:
5568
"""
5669
Identify the nodes of the Mapper graph.
5770
@@ -94,9 +107,9 @@ def _run_clustering(local_ids, X_local, clust):
94107
delayed(_run_clustering)(
95108
local_ids, [X[j] for j in local_ids], clone(clustering)
96109
)
97-
for local_ids in cover.apply(y)
110+
for local_ids in cover.transform(y)
98111
)
99-
itm_lbls = [[] for _ in X]
112+
itm_lbls: List[List[int]] = [[] for _ in X]
100113
max_lbl = 0
101114
for local_ids, local_lbls in _lbls:
102115
max_local_lbl = 0
@@ -109,7 +122,13 @@ def _run_clustering(local_ids, X_local, clust):
109122
return itm_lbls
110123

111124

112-
def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
125+
def mapper_connected_components(
126+
X: ArrayLike,
127+
y: ArrayLike,
128+
cover: Cover,
129+
clustering: Clustering,
130+
n_jobs: int = 1,
131+
) -> List[int]:
113132
"""
114133
Identify the connected components of the Mapper graph.
115134
@@ -159,7 +178,13 @@ def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
159178
return labels
160179

161180

162-
def mapper_graph(X, y, cover, clustering, n_jobs=1):
181+
def mapper_graph(
182+
X: ArrayLike,
183+
y: ArrayLike,
184+
cover: Cover,
185+
clustering: Clustering,
186+
n_jobs: int = 1,
187+
) -> nx.Graph:
163188
"""
164189
Create the Mapper graph.
165190
@@ -211,7 +236,7 @@ def mapper_graph(X, y, cover, clustering, n_jobs=1):
211236
return graph
212237

213238

214-
def aggregate_graph(X, graph, agg):
239+
def aggregate_graph(X: ArrayLike, graph: nx.Graph, agg: Callable) -> Dict:
215240
"""
216241
Apply an aggregation function to the nodes of a graph.
217242
@@ -243,101 +268,29 @@ def aggregate_graph(X, graph, agg):
243268
return agg_values
244269

245270

246-
class Cover(ParamsMixin):
271+
class Cover(Protocol):
247272
"""
248273
Abstract interface for cover algorithms.
249274
250-
This is a naive implementation. Subclasses should override the methods of
275+
Subclasses should override the methods of
251276
this class to implement more meaningful cover algorithms.
252277
"""
253278

254-
def apply(self, X):
255-
"""
256-
Covers the dataset with a single open set.
257-
258-
This is a naive implementation that returns a generator producing a
259-
single list containing all the ids of the original dataset. This
260-
method should be overridden by subclasses to implement more meaningful
261-
cover algorithms.
262-
263-
:param X: A dataset of n points.
264-
:type X: array-like of shape (n, m) or list-like of length n
265-
:return: A generator of lists of ids.
266-
:rtype: generator of lists of ints
267-
"""
268-
yield list(range(0, len(X)))
269-
270-
271-
class Proximity(Cover):
272-
"""
273-
Abstract interface for proximity functions. A proximity function is a
274-
function that maps each point into a subset of the dataset that contains
275-
the point itself. Every proximity function defines also a covering
276-
algorithm based on proximity-net that is implemented in this class.
277-
278-
Proximity functions, implemented as subclasses of this class, are a
279-
convenient way to implement open cover algorithms by using the
280-
proximity-net construction. Proximity-net is implemented by function
281-
:func:`tdamapper.core.Proximity.apply`.
282-
283-
Subclasses should override the methods :func:`tdamapper.core.Proximity.fit`
284-
and :func:`tdamapper.core.Proximity.search` of this class to implement
285-
more meaningful proximity functions.
286-
"""
287-
288-
def fit(self, X):
289-
"""
290-
Train internal parameters.
291-
292-
This is a naive implementation that should be overridden by subclasses
293-
to implement more meaningful proximity functions.
294-
295-
:param X: A dataset of n points.
296-
:type X: array-like of shape (n, m) or list-like of length n
297-
:return: The object itself.
298-
:rtype: self
299-
"""
300-
self._X = X
301-
return self
302-
303-
def search(self, x):
304-
"""
305-
Return a list of neighbors for the query point.
279+
def fit(self, X: ArrayLike) -> Cover: ...
306280

307-
This is a naive implementation that returns all the points in the
308-
dataset as neighbors. This method should be overridden by subclasses
309-
to implement more meaningful proximity functions.
281+
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]: ...
310282

311-
:param x: A query point for which we want to find neighbors.
312-
:type x: Any
313-
:return: A list containing all the indices of the points in the
314-
dataset.
315-
:rtype: list[int]
316-
"""
317-
return list(range(0, len(self._X)))
318283

319-
def apply(self, X):
320-
"""
321-
Covers the dataset using proximity-net.
284+
class ProximityNetCover:
322285

323-
This function applies an iterative algorithm to create the
324-
proximity-net. It picks an arbitrary point and forms an open cover
325-
calling the proximity function on the chosen point. The points
326-
contained in the open cover are then marked as covered, and discarded
327-
in the following steps. The procedure is repeated on the leftover
328-
points until every point is eventually covered.
286+
def fit(self, X: ArrayLike) -> Cover:
287+
raise NotImplementedError()
329288

330-
This function returns a generator that yields each element of the
331-
proximity-net as a list of ids. The ids are the indices of the points
332-
in the original dataset.
289+
def search(self, x: Any) -> List[int]:
290+
raise NotImplementedError()
333291

334-
:param X: A dataset of n points.
335-
:type X: array-like of shape (n, m) or list-like of length n
336-
:return: A generator of lists of ids.
337-
:rtype: generator of lists of ints
338-
"""
292+
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
339293
covered_ids = set()
340-
self.fit(X)
341294
for i, xi in enumerate(X):
342295
if i not in covered_ids:
343296
neigh_ids = self.search(xi)
@@ -346,7 +299,14 @@ def apply(self, X):
346299
yield neigh_ids
347300

348301

349-
class TrivialCover(Cover):
302+
class Clustering(Protocol):
303+
304+
labels_: List[int]
305+
306+
def fit(self, X: ArrayLike, y: Any = None) -> Clustering: ...
307+
308+
309+
class TrivialCover:
350310
"""
351311
Cover algorithm that covers data with a single subset containing the whole
352312
dataset.
@@ -355,35 +315,30 @@ class TrivialCover(Cover):
355315
dataset.
356316
"""
357317

358-
def apply(self, X):
359-
"""
360-
Covers the dataset with a single open set.
318+
def fit(self, X: ArrayLike) -> TrivialCover:
319+
return self
361320

362-
:param X: A dataset of n points.
363-
:type X: array-like of shape (n, m) or list-like of length n
364-
:return: A generator of lists of ids.
365-
:rtype: generator of lists of ints
366-
"""
367-
yield list(range(0, len(X)))
321+
def transform(self, X: ArrayLike) -> Generator[List[int], None, None]:
322+
yield list(range(len(X)))
368323

369324

370325
class _MapperAlgorithm(EstimatorMixin, ParamsMixin):
371326

372327
def __init__(
373328
self,
374-
cover=None,
375-
clustering=None,
376-
failsafe=True,
377-
verbose=True,
378-
n_jobs=1,
329+
cover: Optional[Cover] = None,
330+
clustering: Optional[Clustering] = None,
331+
failsafe: bool = True,
332+
verbose: bool = True,
333+
n_jobs: int = 1,
379334
):
380335
self.cover = cover
381336
self.clustering = clustering
382337
self.failsafe = failsafe
383338
self.verbose = verbose
384339
self.n_jobs = n_jobs
385340

386-
def fit(self, X, y=None):
341+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None):
387342
X, y = self._validate_X_y(X, y)
388343
self._cover = TrivialCover() if self.cover is None else self.cover
389344
self._clustering = (
@@ -410,7 +365,7 @@ def fit(self, X, y=None):
410365
self._set_n_features_in(X)
411366
return self
412367

413-
def fit_transform(self, X, y):
368+
def fit_transform(self, X: ArrayLike, y: ArrayLike) -> nx.Graph:
414369
self.fit(X, y)
415370
return self.graph_
416371

@@ -446,11 +401,11 @@ class FailSafeClustering(ParamsMixin):
446401
:type verbose: bool, optional.
447402
"""
448403

449-
def __init__(self, clustering=None, verbose=True):
404+
def __init__(self, clustering: Optional[Clustering] = None, verbose: bool = True):
450405
self.clustering = clustering
451406
self.verbose = verbose
452407

453-
def fit(self, X, y=None):
408+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None):
454409
self._clustering = (
455410
TrivialClustering() if self.clustering is None else self.clustering
456411
)
@@ -479,7 +434,7 @@ class TrivialClustering(ParamsMixin):
479434
def __init__(self):
480435
pass
481436

482-
def fit(self, X, y=None):
437+
def fit(self, X: ArrayLike, y: Optional[ArrayLike] = None) -> TrivialClustering:
483438
"""
484439
Fit the clustering algorithm to the data.
485440

0 commit comments

Comments
 (0)