Skip to content

Commit 7fc2eeb

Browse files
committed
Merge branch 'main' of github.com:lucasimi/tda-mapper-python into add-type-hints
2 parents 994d875 + fa618f0 commit 7fc2eeb

4 files changed

Lines changed: 64 additions & 413 deletions

File tree

src/tdamapper/clustering.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

src/tdamapper/core.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import networkx as nx
3737
from joblib import Parallel, delayed
3838

39-
from tdamapper._common import ArrayLike, EstimatorMixin, ParamsMixin, clone, deprecated
39+
from tdamapper._common import ArrayLike, ParamsMixin, clone
4040
from tdamapper.utils.unionfind import UnionFind
4141

4242
ATTR_IDS = "ids"
@@ -394,68 +394,6 @@ def apply(self, X: ArrayLike) -> Generator[list[int]]:
394394
yield list(range(0, len(X)))
395395

396396

397-
class _MapperAlgorithm(EstimatorMixin, ParamsMixin):
398-
399-
def __init__(
400-
self,
401-
cover=None,
402-
clustering=None,
403-
failsafe=True,
404-
verbose=True,
405-
n_jobs=1,
406-
):
407-
self.cover = cover
408-
self.clustering = clustering
409-
self.failsafe = failsafe
410-
self.verbose = verbose
411-
self.n_jobs = n_jobs
412-
413-
def fit(self, X, y=None):
414-
X, y = self._validate_X_y(X, y)
415-
self._cover = TrivialCover() if self.cover is None else self.cover
416-
self._clustering = (
417-
TrivialClustering() if self.clustering is None else self.clustering
418-
)
419-
self._verbose = self.verbose
420-
self._failsafe = self.failsafe
421-
if self._failsafe:
422-
self._clustering = FailSafeClustering(
423-
clustering=self._clustering,
424-
verbose=self._verbose,
425-
)
426-
self._cover = clone(self._cover)
427-
self._clustering = clone(self._clustering)
428-
self._n_jobs = self.n_jobs
429-
y = X if y is None else y
430-
self.graph_ = mapper_graph(
431-
X,
432-
y,
433-
self._cover,
434-
self._clustering,
435-
n_jobs=self._n_jobs,
436-
)
437-
self._set_n_features_in(X)
438-
return self
439-
440-
def fit_transform(self, X, y):
441-
self.fit(X, y)
442-
return self.graph_
443-
444-
445-
class MapperAlgorithm(_MapperAlgorithm):
446-
"""
447-
**DEPRECATED**: This class is deprecated and will be removed in a future
448-
release. Use :class:`tdamapper.learn.MapperAlgorithm`.
449-
"""
450-
451-
@deprecated(
452-
"This class is deprecated and will be removed in a future release. "
453-
"Use tdamapper.learn.MapperAlgorithm."
454-
)
455-
def __init__(self, *args, **kwargs):
456-
super().__init__(*args, **kwargs)
457-
458-
459397
class FailSafeClustering(ParamsMixin):
460398
"""
461399
A delegating clustering algorithm that prevents failure.

src/tdamapper/learn.py

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99
scikit-learn's conventions for estimators.
1010
"""
1111

12-
from tdamapper.clustering import _MapperClustering
13-
from tdamapper.core import _MapperAlgorithm
12+
from tdamapper._common import EstimatorMixin, ParamsMixin, clone
13+
from tdamapper.core import (
14+
FailSafeClustering,
15+
TrivialClustering,
16+
TrivialCover,
17+
mapper_connected_components,
18+
mapper_graph,
19+
)
1420

1521

16-
class MapperClustering(_MapperClustering):
22+
class MapperClustering(EstimatorMixin, ParamsMixin):
1723
"""
1824
A clustering algorithm based on the Mapper graph.
1925
@@ -45,11 +51,9 @@ def __init__(
4551
clustering=None,
4652
n_jobs=1,
4753
):
48-
super().__init__(
49-
cover=cover,
50-
clustering=clustering,
51-
n_jobs=n_jobs,
52-
)
54+
self.cover = cover
55+
self.clustering = clustering
56+
self.n_jobs = n_jobs
5357

5458
def fit(self, X, y=None):
5559
"""
@@ -60,10 +64,26 @@ def fit(self, X, y=None):
6064
:param y: Ignored.
6165
:return: self
6266
"""
63-
return super().fit(X, y)
67+
y = X if y is None else y
68+
X, y = self._validate_X_y(X, y)
69+
cover = TrivialCover() if self.cover is None else self.cover
70+
cover = clone(cover)
71+
clustering = TrivialClustering() if self.clustering is None else self.clustering
72+
clustering = clone(clustering)
73+
n_jobs = self.n_jobs
74+
itm_lbls = mapper_connected_components(
75+
X,
76+
y,
77+
cover,
78+
clustering,
79+
n_jobs=n_jobs,
80+
)
81+
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
82+
self._set_n_features_in(X)
83+
return self
6484

6585

66-
class MapperAlgorithm(_MapperAlgorithm):
86+
class MapperAlgorithm(EstimatorMixin, ParamsMixin):
6787
"""
6888
A class for creating and analyzing Mapper graphs.
6989
@@ -111,13 +131,11 @@ def __init__(
111131
verbose=True,
112132
n_jobs=1,
113133
):
114-
super().__init__(
115-
cover=cover,
116-
clustering=clustering,
117-
failsafe=failsafe,
118-
verbose=verbose,
119-
n_jobs=n_jobs,
120-
)
134+
self.cover = cover
135+
self.clustering = clustering
136+
self.failsafe = failsafe
137+
self.verbose = verbose
138+
self.n_jobs = n_jobs
121139

122140
def fit(self, X, y=None):
123141
"""
@@ -132,7 +150,31 @@ def fit(self, X, y=None):
132150
:type y: array-like of shape (n, k) or list-like of length n
133151
:return: The object itself.
134152
"""
135-
return super().fit(X, y)
153+
X, y = self._validate_X_y(X, y)
154+
self._cover = TrivialCover() if self.cover is None else self.cover
155+
self._clustering = (
156+
TrivialClustering() if self.clustering is None else self.clustering
157+
)
158+
self._verbose = self.verbose
159+
self._failsafe = self.failsafe
160+
if self._failsafe:
161+
self._clustering = FailSafeClustering(
162+
clustering=self._clustering,
163+
verbose=self._verbose,
164+
)
165+
self._cover = clone(self._cover)
166+
self._clustering = clone(self._clustering)
167+
self._n_jobs = self.n_jobs
168+
y = X if y is None else y
169+
self.graph_ = mapper_graph(
170+
X,
171+
y,
172+
self._cover,
173+
self._clustering,
174+
n_jobs=self._n_jobs,
175+
)
176+
self._set_n_features_in(X)
177+
return self
136178

137179
def fit_transform(self, X, y):
138180
"""
@@ -148,4 +190,5 @@ def fit_transform(self, X, y):
148190
:return: The Mapper graph.
149191
:rtype: :class:`networkx.Graph`
150192
"""
151-
return super().fit_transform(X, y)
193+
self.fit(X, y)
194+
return self.graph_

0 commit comments

Comments
 (0)