Skip to content

Commit b228301

Browse files
authored
Merge pull request #163 from lucasimi/feature/add-standard-cover
Feature/add standard cover
2 parents 84eb186 + 86deac0 commit b228301

File tree

5 files changed

+337
-152
lines changed

5 files changed

+337
-152
lines changed

src/tdamapper/_common.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ def warn_user(msg):
2121

2222
class EstimatorMixin:
2323

24-
def _is_sparse(self, X):
24+
def __is_sparse(self, X):
2525
# simple duck-typed sparse check; scipy.sparse.issparse is the alternative
2626
return hasattr(X, 'toarray')
2727

2828
def _validate_X_y(self, X, y):
29-
if self._is_sparse(X):
29+
if self.__is_sparse(X):
3030
raise ValueError('Sparse data not supported.')
3131

3232
X = np.asarray(X)
@@ -58,11 +58,8 @@ def _validate_X_y(self, X, y):
5858

5959
return X, y
6060

61-
def fit(self, X, y=None):
62-
X, y = self._validate_X_y(X, y)
63-
res = super().fit(X, y)
61+
def _set_n_features_in(self, X):
6462
self.n_features_in_ = X.shape[1]
65-
return res
6663

6764

6865
class ParamsMixin:
@@ -71,7 +68,7 @@ class ParamsMixin:
7168
scikit-learn `get_params` and `set_params`.
7269
"""
7370

74-
def _is_param_internal(self, k):
71+
def __is_param_internal(self, k):
7572
return k.startswith('_') or k.endswith('_')
7673

7774
def get_params(self, deep=True):
@@ -82,14 +79,14 @@ def get_params(self, deep=True):
8279
:type deep: bool, optional.
8380
"""
8481
params = self.__dict__.items()
85-
return {k: v for k, v in params if not self._is_param_internal(k)}
82+
return {k: v for k, v in params if not self.__is_param_internal(k)}
8683

8784
def set_params(self, **params):
8885
"""
8986
Set public parameters. Only updates attributes that already exist.
9087
"""
9188
for k, v in params.items():
92-
if hasattr(self, k) and not self._is_param_internal(k):
89+
if hasattr(self, k) and not self.__is_param_internal(k):
9390
setattr(self, k, v)
9491
return self
9592

src/tdamapper/clustering.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -38,34 +38,7 @@ def __init__(self, clustering=None, verbose=True):
3838
super().__init__(clustering, verbose)
3939

4040

41-
class _MapperClustering:
42-
43-
def __init__(self, cover=None, clustering=None, n_jobs=1):
44-
self.cover = cover
45-
self.clustering = clustering
46-
self.n_jobs = n_jobs
47-
48-
def fit(self, X, y=None):
49-
cover = TrivialCover() if self.cover is None \
50-
else self.cover
51-
cover = clone(cover)
52-
clustering = TrivialClustering() if self.clustering is None \
53-
else self.clustering
54-
clustering = clone(clustering)
55-
n_jobs = self.n_jobs
56-
y = X if y is None else y
57-
itm_lbls = mapper_connected_components(
58-
X,
59-
y,
60-
cover,
61-
clustering,
62-
n_jobs=n_jobs,
63-
)
64-
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
65-
return self
66-
67-
68-
class MapperClustering(EstimatorMixin, _MapperClustering, ParamsMixin):
41+
class MapperClustering(EstimatorMixin, ParamsMixin):
6942
"""
7043
A clustering algorithm based on the Mapper graph.
7144
@@ -92,4 +65,27 @@ class MapperClustering(EstimatorMixin, _MapperClustering, ParamsMixin):
9265
"""
9366

9467
def __init__(self, cover=None, clustering=None, n_jobs=1):
95-
super().__init__(cover=cover, clustering=clustering, n_jobs=n_jobs)
68+
self.cover = cover
69+
self.clustering = clustering
70+
self.n_jobs = n_jobs
71+
72+
def fit(self, X, y=None):
73+
X, y = self._validate_X_y(X, y)
74+
cover = TrivialCover() if self.cover is None \
75+
else self.cover
76+
cover = clone(cover)
77+
clustering = TrivialClustering() if self.clustering is None \
78+
else self.clustering
79+
clustering = clone(clustering)
80+
n_jobs = self.n_jobs
81+
y = X if y is None else y
82+
itm_lbls = mapper_connected_components(
83+
X,
84+
y,
85+
cover,
86+
clustering,
87+
n_jobs=n_jobs,
88+
)
89+
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
90+
self._set_n_features_in(X)
91+
return self

src/tdamapper/core.py

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -364,53 +364,7 @@ def apply(self, X):
364364
yield list(range(0, len(X)))
365365

366366

367-
class _MapperAlgorithm:
368-
369-
def __init__(
370-
self,
371-
cover=None,
372-
clustering=None,
373-
failsafe=True,
374-
verbose=True,
375-
n_jobs=1,
376-
):
377-
self.cover = cover
378-
self.clustering = clustering
379-
self.failsafe = failsafe
380-
self.verbose = verbose
381-
self.n_jobs = n_jobs
382-
383-
def fit(self, X, y=None):
384-
self.__cover = TrivialCover() if self.cover is None \
385-
else self.cover
386-
self.__clustering = TrivialClustering() if self.clustering is None \
387-
else self.clustering
388-
self.__verbose = self.verbose
389-
self.__failsafe = self.failsafe
390-
if self.__failsafe:
391-
self.__clustering = FailSafeClustering(
392-
clustering=self.__clustering,
393-
verbose=self.__verbose,
394-
)
395-
self.__cover = clone(self.__cover)
396-
self.__clustering = clone(self.__clustering)
397-
self.__n_jobs = self.n_jobs
398-
y = X if y is None else y
399-
self.graph_ = mapper_graph(
400-
X,
401-
y,
402-
self.__cover,
403-
self.__clustering,
404-
n_jobs=self.__n_jobs,
405-
)
406-
return self
407-
408-
def fit_transform(self, X, y):
409-
self.fit(X, y)
410-
return self.graph_
411-
412-
413-
class MapperAlgorithm(EstimatorMixin, _MapperAlgorithm, ParamsMixin):
367+
class MapperAlgorithm(EstimatorMixin, ParamsMixin):
414368
"""
415369
A class for creating and analyzing Mapper graphs.
416370
@@ -458,13 +412,11 @@ def __init__(
458412
verbose=True,
459413
n_jobs=1,
460414
):
461-
super().__init__(
462-
cover=cover,
463-
clustering=clustering,
464-
failsafe=failsafe,
465-
verbose=verbose,
466-
n_jobs=n_jobs,
467-
)
415+
self.cover = cover
416+
self.clustering = clustering
417+
self.failsafe = failsafe
418+
self.verbose = verbose
419+
self.n_jobs = n_jobs
468420

469421
def fit(self, X, y=None):
470422
"""
@@ -479,7 +431,31 @@ def fit(self, X, y=None):
479431
:type y: array-like of shape (n, k) or list-like of length n
480432
:return: The object itself.
481433
"""
482-
return super().fit(X, y)
434+
X, y = self._validate_X_y(X, y)
435+
self.__cover = TrivialCover() if self.cover is None \
436+
else self.cover
437+
self.__clustering = TrivialClustering() if self.clustering is None \
438+
else self.clustering
439+
self.__verbose = self.verbose
440+
self.__failsafe = self.failsafe
441+
if self.__failsafe:
442+
self.__clustering = FailSafeClustering(
443+
clustering=self.__clustering,
444+
verbose=self.__verbose,
445+
)
446+
self.__cover = clone(self.__cover)
447+
self.__clustering = clone(self.__clustering)
448+
self.__n_jobs = self.n_jobs
449+
y = X if y is None else y
450+
self.graph_ = mapper_graph(
451+
X,
452+
y,
453+
self.__cover,
454+
self.__clustering,
455+
n_jobs=self.__n_jobs,
456+
)
457+
self._set_n_features_in(X)
458+
return self
483459

484460
def fit_transform(self, X, y):
485461
"""
@@ -495,7 +471,8 @@ def fit_transform(self, X, y):
495471
:return: The Mapper graph.
496472
:rtype: :class:`networkx.Graph`
497473
"""
498-
return super().fit_transform(X, y)
474+
self.fit(X, y)
475+
return self.graph_
499476

500477

501478
class FailSafeClustering(ParamsMixin):

0 commit comments

Comments
 (0)