3333from joblib import Parallel , delayed
3434
3535from tdamapper .utils .unionfind import UnionFind
36- from tdamapper ._common import ParamsMixin , clone
36+ from tdamapper ._common import ParamsMixin , EstimatorMixin , clone
3737
3838
3939ATTR_IDS = 'ids'
@@ -364,7 +364,53 @@ def apply(self, X):
364364 yield list (range (0 , len (X )))
365365
366366
367- class MapperAlgorithm (ParamsMixin ):
367+ class _MapperAlgorithm :
368+
369+ def __init__ (
370+ self ,
371+ cover = None ,
372+ clustering = None ,
373+ failsafe = True ,
374+ verbose = True ,
375+ n_jobs = 1 ,
376+ ):
377+ self .cover = cover
378+ self .clustering = clustering
379+ self .failsafe = failsafe
380+ self .verbose = verbose
381+ self .n_jobs = n_jobs
382+
383+ def fit (self , X , y = None ):
384+ self .__cover = TrivialCover () if self .cover is None \
385+ else self .cover
386+ self .__clustering = TrivialClustering () if self .clustering is None \
387+ else self .clustering
388+ self .__verbose = self .verbose
389+ self .__failsafe = self .failsafe
390+ if self .__failsafe :
391+ self .__clustering = FailSafeClustering (
392+ clustering = self .__clustering ,
393+ verbose = self .__verbose ,
394+ )
395+ self .__cover = clone (self .__cover )
396+ self .__clustering = clone (self .__clustering )
397+ self .__n_jobs = self .n_jobs
398+ y = X if y is None else y
399+ self .graph_ = mapper_graph (
400+ X ,
401+ y ,
402+ self .__cover ,
403+ self .__clustering ,
404+ n_jobs = self .__n_jobs ,
405+ )
406+ return self
407+
408+ def fit_transform (self , X , y ):
409+ self .fit (X , y )
410+ return self .graph_
411+
412+
413+ class MapperAlgorithm (EstimatorMixin , _MapperAlgorithm , ParamsMixin ):
368414 """
369415 A class for creating and analyzing Mapper graphs.
370416
@@ -412,11 +458,13 @@ def __init__(
412458 verbose = True ,
413459 n_jobs = 1 ,
414460 ):
415- self .cover = cover
416- self .clustering = clustering
417- self .failsafe = failsafe
418- self .verbose = verbose
419- self .n_jobs = n_jobs
461+ super ().__init__ (
462+ cover = cover ,
463+ clustering = clustering ,
464+ failsafe = failsafe ,
465+ verbose = verbose ,
466+ n_jobs = n_jobs ,
467+ )
420468
421469 def fit (self , X , y = None ):
422470 """
@@ -431,29 +479,7 @@ def fit(self, X, y=None):
431479 :type y: array-like of shape (n, k) or list-like of length n
432480 :return: The object itself.
433481 """
434- self .__cover = TrivialCover () if self .cover is None \
435- else self .cover
436- self .__clustering = TrivialClustering () if self .clustering is None \
437- else self .clustering
438- self .__verbose = self .verbose
439- self .__failsafe = self .failsafe
440- if self .__failsafe :
441- self .__clustering = FailSafeClustering (
442- clustering = self .__clustering ,
443- verbose = self .__verbose ,
444- )
445- self .__cover = clone (self .__cover )
446- self .__clustering = clone (self .__clustering )
447- self .__n_jobs = self .n_jobs
448- y = X if y is None else y
449- self .graph_ = mapper_graph (
450- X ,
451- y ,
452- self .__cover ,
453- self .__clustering ,
454- n_jobs = self .__n_jobs ,
455- )
456- return self
482+ return super ().fit (X , y )
457483
458484 def fit_transform (self , X , y ):
459485 """
@@ -469,8 +495,7 @@ def fit_transform(self, X, y):
469495 :return: The Mapper graph.
470496 :rtype: :class:`networkx.Graph`
471497 """
472- self .fit (X , y )
473- return self .graph_
498+ return super ().fit_transform (X , y )
474499
475500
476501class FailSafeClustering (ParamsMixin ):
0 commit comments