99scikit-learn's conventions for estimators.
1010"""
1111
12- from tdamapper .clustering import _MapperClustering
13- from tdamapper .core import _MapperAlgorithm
12+ from tdamapper ._common import EstimatorMixin , ParamsMixin , clone
13+ from tdamapper .core import (
14+ FailSafeClustering ,
15+ TrivialClustering ,
16+ TrivialCover ,
17+ mapper_connected_components ,
18+ mapper_graph ,
19+ )
1420
1521
16- class MapperClustering (_MapperClustering ):
22+ class MapperClustering (EstimatorMixin , ParamsMixin ):
1723 """
1824 A clustering algorithm based on the Mapper graph.
1925
@@ -45,11 +51,9 @@ def __init__(
4551 clustering = None ,
4652 n_jobs = 1 ,
4753 ):
48- super ().__init__ (
49- cover = cover ,
50- clustering = clustering ,
51- n_jobs = n_jobs ,
52- )
54+ self .cover = cover
55+ self .clustering = clustering
56+ self .n_jobs = n_jobs
5357
5458 def fit (self , X , y = None ):
5559 """
@@ -60,10 +64,26 @@ def fit(self, X, y=None):
6064 :param y: Ignored.
6165 :return: self
6266 """
63- return super ().fit (X , y )
67+ y = X if y is None else y
68+ X , y = self ._validate_X_y (X , y )
69+ cover = TrivialCover () if self .cover is None else self .cover
70+ cover = clone (cover )
71+ clustering = TrivialClustering () if self .clustering is None else self .clustering
72+ clustering = clone (clustering )
73+ n_jobs = self .n_jobs
74+ itm_lbls = mapper_connected_components (
75+ X ,
76+ y ,
77+ cover ,
78+ clustering ,
79+ n_jobs = n_jobs ,
80+ )
81+ self .labels_ = [itm_lbls [i ] for i , _ in enumerate (X )]
82+ self ._set_n_features_in (X )
83+ return self
6484
6585
66- class MapperAlgorithm (_MapperAlgorithm ):
86+ class MapperAlgorithm (EstimatorMixin , ParamsMixin ):
6787 """
6888 A class for creating and analyzing Mapper graphs.
6989
@@ -111,13 +131,11 @@ def __init__(
111131 verbose = True ,
112132 n_jobs = 1 ,
113133 ):
114- super ().__init__ (
115- cover = cover ,
116- clustering = clustering ,
117- failsafe = failsafe ,
118- verbose = verbose ,
119- n_jobs = n_jobs ,
120- )
134+ self .cover = cover
135+ self .clustering = clustering
136+ self .failsafe = failsafe
137+ self .verbose = verbose
138+ self .n_jobs = n_jobs
121139
122140 def fit (self , X , y = None ):
123141 """
@@ -132,7 +150,31 @@ def fit(self, X, y=None):
132150 :type y: array-like of shape (n, k) or list-like of length n
133151 :return: The object itself.
134152 """
135- return super ().fit (X , y )
153+ X , y = self ._validate_X_y (X , y )
154+ self ._cover = TrivialCover () if self .cover is None else self .cover
155+ self ._clustering = (
156+ TrivialClustering () if self .clustering is None else self .clustering
157+ )
158+ self ._verbose = self .verbose
159+ self ._failsafe = self .failsafe
160+ if self ._failsafe :
161+ self ._clustering = FailSafeClustering (
162+ clustering = self ._clustering ,
163+ verbose = self ._verbose ,
164+ )
165+ self ._cover = clone (self ._cover )
166+ self ._clustering = clone (self ._clustering )
167+ self ._n_jobs = self .n_jobs
168+ y = X if y is None else y
169+ self .graph_ = mapper_graph (
170+ X ,
171+ y ,
172+ self ._cover ,
173+ self ._clustering ,
174+ n_jobs = self ._n_jobs ,
175+ )
176+ self ._set_n_features_in (X )
177+ return self
136178
137179 def fit_transform (self , X , y ):
138180 """
@@ -148,4 +190,5 @@ def fit_transform(self, X, y):
148190 :return: The Mapper graph.
149191 :rtype: :class:`networkx.Graph`
150192 """
151- return super ().fit_transform (X , y )
193+ self .fit (X , y )
194+ return self .graph_
0 commit comments