2828this module is a NetworkX graph object.
2929"""
3030
31+ from __future__ import annotations
32+
3133import logging
34+ from typing import Any , Callable , Dict , Generator , List , Optional , Protocol , Union
3235
3336import networkx as nx
37+ import numpy as np
3438from joblib import Parallel , delayed
39+ from numpy .typing import NDArray
3540
3641from tdamapper ._common import EstimatorMixin , ParamsMixin , clone , deprecated
3742from tdamapper .utils .unionfind import UnionFind
5055 handlers = [logging .StreamHandler ()],
5156)
5257
58+ ArrayLike = Union [List [Any ], NDArray [np .float64 ]]
59+
5360
54- def mapper_labels (X , y , cover , clustering , n_jobs = 1 ):
61+ def mapper_labels (
62+ X : ArrayLike ,
63+ y : ArrayLike ,
64+ cover : Cover ,
65+ clustering : Clustering ,
66+ n_jobs : int = 1 ,
67+ ) -> List [List [int ]]:
5568 """
5669 Identify the nodes of the Mapper graph.
5770
@@ -94,9 +107,9 @@ def _run_clustering(local_ids, X_local, clust):
94107 delayed (_run_clustering )(
95108 local_ids , [X [j ] for j in local_ids ], clone (clustering )
96109 )
97- for local_ids in cover .apply (y )
110+ for local_ids in cover .transform (y )
98111 )
99- itm_lbls = [[] for _ in X ]
112+ itm_lbls : List [ List [ int ]] = [[] for _ in X ]
100113 max_lbl = 0
101114 for local_ids , local_lbls in _lbls :
102115 max_local_lbl = 0
@@ -109,7 +122,13 @@ def _run_clustering(local_ids, X_local, clust):
109122 return itm_lbls
110123
111124
112- def mapper_connected_components (X , y , cover , clustering , n_jobs = 1 ):
125+ def mapper_connected_components (
126+ X : ArrayLike ,
127+ y : ArrayLike ,
128+ cover : Cover ,
129+ clustering : Clustering ,
130+ n_jobs : int = 1 ,
131+ ) -> List [int ]:
113132 """
114133 Identify the connected components of the Mapper graph.
115134
@@ -159,7 +178,13 @@ def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
159178 return labels
160179
161180
162- def mapper_graph (X , y , cover , clustering , n_jobs = 1 ):
181+ def mapper_graph (
182+ X : ArrayLike ,
183+ y : ArrayLike ,
184+ cover : Cover ,
185+ clustering : Clustering ,
186+ n_jobs : int = 1 ,
187+ ) -> nx .Graph :
163188 """
164189 Create the Mapper graph.
165190
@@ -211,7 +236,7 @@ def mapper_graph(X, y, cover, clustering, n_jobs=1):
211236 return graph
212237
213238
214- def aggregate_graph (X , graph , agg ) :
239+ def aggregate_graph (X : ArrayLike , graph : nx . Graph , agg : Callable ) -> Dict :
215240 """
216241 Apply an aggregation function to the nodes of a graph.
217242
@@ -243,101 +268,29 @@ def aggregate_graph(X, graph, agg):
243268 return agg_values
244269
245270
246- class Cover (ParamsMixin ):
271+ class Cover (Protocol ):
247272 """
248273 Abstract interface for cover algorithms.
249274
250- This is a naive implementation. Subclasses should override the methods of
275+ Subclasses should override the methods of
251276 this class to implement more meaningful cover algorithms.
252277 """
253278
254- def apply (self , X ):
255- """
256- Covers the dataset with a single open set.
257-
258- This is a naive implementation that returns a generator producing a
259- single list containing all the ids of the original dataset. This
260- method should be overridden by subclasses to implement more meaningful
261- cover algorithms.
262-
263- :param X: A dataset of n points.
264- :type X: array-like of shape (n, m) or list-like of length n
265- :return: A generator of lists of ids.
266- :rtype: generator of lists of ints
267- """
268- yield list (range (0 , len (X )))
269-
270-
271- class Proximity (Cover ):
272- """
273- Abstract interface for proximity functions. A proximity function is a
274- function that maps each point into a subset of the dataset that contains
275- the point itself. Every proximity function defines also a covering
276- algorithm based on proximity-net that is implemented in this class.
277-
278- Proximity functions, implemented as subclasses of this class, are a
279- convenient way to implement open cover algorithms by using the
280- proximity-net construction. Proximity-net is implemented by function
281- :func:`tdamapper.core.Proximity.apply`.
282-
283- Subclasses should override the methods :func:`tdamapper.core.Proximity.fit`
284- and :func:`tdamapper.core.Proximity.search` of this class to implement
285- more meaningful proximity functions.
286- """
287-
288- def fit (self , X ):
289- """
290- Train internal parameters.
291-
292- This is a naive implementation that should be overridden by subclasses
293- to implement more meaningful proximity functions.
294-
295- :param X: A dataset of n points.
296- :type X: array-like of shape (n, m) or list-like of length n
297- :return: The object itself.
298- :rtype: self
299- """
300- self ._X = X
301- return self
302-
303- def search (self , x ):
304- """
305- Return a list of neighbors for the query point.
279+ def fit (self , X : ArrayLike ) -> Cover : ...
306280
307- This is a naive implementation that returns all the points in the
308- dataset as neighbors. This method should be overridden by subclasses
309- to implement more meaningful proximity functions.
281+ def transform (self , X : ArrayLike ) -> Generator [List [int ], None , None ]: ...
310282
311- :param x: A query point for which we want to find neighbors.
312- :type x: Any
313- :return: A list containing all the indices of the points in the
314- dataset.
315- :rtype: list[int]
316- """
317- return list (range (0 , len (self ._X )))
318283
319- def apply (self , X ):
320- """
321- Covers the dataset using proximity-net.
284+ class ProximityNetCover :
322285
323- This function applies an iterative algorithm to create the
324- proximity-net. It picks an arbitrary point and forms an open cover
325- calling the proximity function on the chosen point. The points
326- contained in the open cover are then marked as covered, and discarded
327- in the following steps. The procedure is repeated on the leftover
328- points until every point is eventually covered.
286+ def fit (self , X : ArrayLike ) -> Cover :
287+ raise NotImplementedError ()
329288
330- This function returns a generator that yields each element of the
331- proximity-net as a list of ids. The ids are the indices of the points
332- in the original dataset.
289+ def search (self , x : Any ) -> List [int ]:
290+ raise NotImplementedError ()
333291
334- :param X: A dataset of n points.
335- :type X: array-like of shape (n, m) or list-like of length n
336- :return: A generator of lists of ids.
337- :rtype: generator of lists of ints
338- """
292+ def transform (self , X : ArrayLike ) -> Generator [List [int ], None , None ]:
339293 covered_ids = set ()
340- self .fit (X )
341294 for i , xi in enumerate (X ):
342295 if i not in covered_ids :
343296 neigh_ids = self .search (xi )
@@ -346,7 +299,14 @@ def apply(self, X):
346299 yield neigh_ids
347300
348301
349- class TrivialCover (Cover ):
302+ class Clustering (Protocol ):
303+
304+ labels_ : List [int ]
305+
306+ def fit (self , X : ArrayLike , y : Any = None ) -> Clustering : ...
307+
308+
309+ class TrivialCover :
350310 """
351311 Cover algorithm that covers data with a single subset containing the whole
352312 dataset.
@@ -355,35 +315,30 @@ class TrivialCover(Cover):
355315 dataset.
356316 """
357317
358- def apply (self , X ):
359- """
360- Covers the dataset with a single open set.
318+ def fit (self , X : ArrayLike ) -> TrivialCover :
319+ return self
361320
362- :param X: A dataset of n points.
363- :type X: array-like of shape (n, m) or list-like of length n
364- :return: A generator of lists of ids.
365- :rtype: generator of lists of ints
366- """
367- yield list (range (0 , len (X )))
321+ def transform (self , X : ArrayLike ) -> Generator [List [int ], None , None ]:
322+ yield list (range (len (X )))
368323
369324
370325class _MapperAlgorithm (EstimatorMixin , ParamsMixin ):
371326
372327 def __init__ (
373328 self ,
374- cover = None ,
375- clustering = None ,
376- failsafe = True ,
377- verbose = True ,
378- n_jobs = 1 ,
329+ cover : Optional [ Cover ] = None ,
330+ clustering : Optional [ Clustering ] = None ,
331+ failsafe : bool = True ,
332+ verbose : bool = True ,
333+ n_jobs : int = 1 ,
379334 ):
380335 self .cover = cover
381336 self .clustering = clustering
382337 self .failsafe = failsafe
383338 self .verbose = verbose
384339 self .n_jobs = n_jobs
385340
386- def fit (self , X , y = None ):
341+ def fit (self , X : ArrayLike , y : Optional [ ArrayLike ] = None ):
387342 X , y = self ._validate_X_y (X , y )
388343 self ._cover = TrivialCover () if self .cover is None else self .cover
389344 self ._clustering = (
@@ -410,7 +365,7 @@ def fit(self, X, y=None):
410365 self ._set_n_features_in (X )
411366 return self
412367
413- def fit_transform (self , X , y ) :
368+ def fit_transform (self , X : ArrayLike , y : ArrayLike ) -> nx . Graph :
414369 self .fit (X , y )
415370 return self .graph_
416371
@@ -446,11 +401,11 @@ class FailSafeClustering(ParamsMixin):
446401 :type verbose: bool, optional.
447402 """
448403
449- def __init__ (self , clustering = None , verbose = True ):
404+ def __init__ (self , clustering : Optional [ Clustering ] = None , verbose : bool = True ):
450405 self .clustering = clustering
451406 self .verbose = verbose
452407
453- def fit (self , X , y = None ):
408+ def fit (self , X : ArrayLike , y : Optional [ ArrayLike ] = None ):
454409 self ._clustering = (
455410 TrivialClustering () if self .clustering is None else self .clustering
456411 )
@@ -479,7 +434,7 @@ class TrivialClustering(ParamsMixin):
479434 def __init__ (self ):
480435 pass
481436
482- def fit (self , X , y = None ):
437+ def fit (self , X : ArrayLike , y : Optional [ ArrayLike ] = None ) -> TrivialClustering :
483438 """
484439 Fit the clustering algorithm to the data.
485440
0 commit comments