Skip to content

Commit 797df4c

Browse files
authored
Merge pull request #159 from lucasimi/feature/scikit-learn-compatibility
Feature/scikit learn compatibility
2 parents 354ffb8 + ee48407 commit 797df4c

File tree

4 files changed

+146
-114
lines changed

4 files changed

+146
-114
lines changed

src/tdamapper/_common.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""
44
import warnings
55

6+
import numpy as np
7+
68

79
def warn_deprecated(deprecated, substitute):
810
msg = f'{deprecated} is deprecated and will be removed in a future version. Use {substitute} instead.'
@@ -17,6 +19,52 @@ def warn_user(msg):
1719
warnings.warn(msg, UserWarning, stacklevel=2)
1820

1921

22+
class EstimatorMixin:
    """
    Mixin that validates the input of ``fit`` before delegating to the next
    ``fit`` found in the method resolution order.

    The mixin has to be listed before the class implementing the actual
    fitting logic, because :meth:`fit` forwards the validated data through
    ``super().fit``. After a successful fit the number of columns of ``X`` is
    stored in the attribute ``n_features_in_``.
    """

    def _is_sparse(self, X):
        """
        Tell whether ``X`` looks like a sparse matrix.

        The check is duck-typed on the presence of a ``toarray`` attribute;
        ``scipy.sparse.issparse`` is a possible alternative.

        :param X: The object to inspect.
        :return: True when ``X`` exposes ``toarray``, False otherwise.
        """
        return hasattr(X, 'toarray')

    def _validate_X_y(self, X, y):
        """
        Validate the dataset and, when supplied, the companion values.

        ``X`` is converted with ``np.asarray``. ``y`` is converted in the same
        way only when it is not ``None``: ``None`` is the default of
        :meth:`fit` and is returned untouched, so that the wrapped ``fit`` can
        apply its own default.

        :param X: The dataset.
        :param y: The values associated to ``X``, or ``None``.
        :return: The pair ``(X, y)`` after conversion.
        :raises ValueError: If ``X`` is sparse, empty, one-dimensional or
            complex, if ``y`` is empty or complex, or if any of them
            contains NaNs or infinite values.
        """
        if self._is_sparse(X):
            raise ValueError('Sparse data not supported.')

        X = np.asarray(X)
        # np.asarray(None) would yield a 0-d object array that later turns
        # into NaN and is rejected, making the ``y=None`` default unusable.
        if y is not None:
            y = np.asarray(y)

        if X.size == 0:
            msg = f'0 feature(s) (shape={X.shape}) while a minimum of 1 is required.'
            raise ValueError(msg)

        if y is not None and y.size == 0:
            msg = f'0 feature(s) (shape={y.shape}) while a minimum of 1 is required.'
            raise ValueError(msg)

        if X.ndim == 1:
            raise ValueError('1d-arrays not supported.')

        if np.iscomplexobj(X) or (y is not None and np.iscomplexobj(y)):
            raise ValueError('Complex data not supported.')

        # Object arrays are coerced to float so that the NaN/inf checks
        # below can be evaluated on them.
        if X.dtype == np.object_:
            X = np.array(X, dtype=float)

        if y is not None and y.dtype == np.object_:
            y = np.array(y, dtype=float)

        to_check = (X,) if y is None else (X, y)
        if any(np.isnan(a).any() or np.isinf(a).any() for a in to_check):
            raise ValueError('NaNs or infinite values not supported.')

        return X, y

    def fit(self, X, y=None):
        """
        Validate the input and delegate to the next ``fit`` in the MRO.

        :param X: The dataset, expected to be two-dimensional.
        :param y: Optional values associated to ``X``.
        :return: Whatever the wrapped ``fit`` returns.
        :raises ValueError: If the input does not pass validation.
        """
        X, y = self._validate_X_y(X, y)
        res = super().fit(X, y)
        # Set only after the wrapped fit succeeded.
        self.n_features_in_ = X.shape[1]
        return res
66+
67+
2068
class ParamsMixin:
2169
"""
2270
Mixin to add setters and getters for public parameters, compatible with

src/tdamapper/clustering.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44

55
from tdamapper.core import mapper_connected_components, TrivialCover
66
import tdamapper.core
7-
from tdamapper._common import ParamsMixin, clone, warn_deprecated
7+
from tdamapper._common import (
8+
ParamsMixin,
9+
EstimatorMixin,
10+
clone,
11+
warn_deprecated,
12+
)
813

914

1015
class TrivialClustering(tdamapper.core.TrivialClustering):
@@ -33,7 +38,34 @@ def __init__(self, clustering=None, verbose=True):
3338
super().__init__(clustering, verbose)
3439

3540

36-
class MapperClustering(ParamsMixin):
41+
class _MapperClustering:
    """
    Clustering based on the connected components of the Mapper graph,
    without any validation of the input.

    :param cover: The cover algorithm. When ``None`` a :class:`TrivialCover`
        is created at fit time.
    :param clustering: The clustering algorithm. When ``None`` a
        :class:`TrivialClustering` is created at fit time.
    :param n_jobs: Forwarded as ``n_jobs`` to
        ``mapper_connected_components``.
    """

    def __init__(self, cover=None, clustering=None, n_jobs=1):
        self.cover = cover
        self.clustering = clustering
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """
        Compute one label per item of ``X`` and store them in ``labels_``.

        The configured cover and clustering are cloned before use, so the
        objects held as parameters are not the ones handed to the
        computation.

        :param X: The dataset.
        :param y: The values passed along with ``X``; ``X`` itself is used
            when ``None``.
        :return: The object itself.
        """
        cover = self.cover
        if cover is None:
            cover = TrivialCover()
        cover = clone(cover)

        clustering = self.clustering
        if clustering is None:
            clustering = TrivialClustering()
        clustering = clone(clustering)

        if y is None:
            y = X

        item_labels = mapper_connected_components(
            X,
            y,
            cover,
            clustering,
            n_jobs=self.n_jobs,
        )
        # Labels are looked up by position, one for each item of X.
        self.labels_ = [item_labels[pos] for pos, _ in enumerate(X)]
        return self
66+
67+
68+
class MapperClustering(EstimatorMixin, _MapperClustering, ParamsMixin):
3769
"""
3870
A clustering algorithm based on the Mapper graph.
3971
@@ -60,25 +92,4 @@ class MapperClustering(ParamsMixin):
6092
"""
6193

6294
def __init__(self, cover=None, clustering=None, n_jobs=1):
63-
self.cover = cover
64-
self.clustering = clustering
65-
self.n_jobs = n_jobs
66-
67-
def fit(self, X, y=None):
68-
cover = TrivialCover() if self.cover is None \
69-
else self.cover
70-
cover = clone(cover)
71-
clustering = TrivialClustering() if self.clustering is None \
72-
else self.clustering
73-
clustering = clone(clustering)
74-
n_jobs = self.n_jobs
75-
y = X if y is None else y
76-
itm_lbls = mapper_connected_components(
77-
X,
78-
y,
79-
cover,
80-
clustering,
81-
n_jobs=n_jobs,
82-
)
83-
self.labels_ = [itm_lbls[i] for i, _ in enumerate(X)]
84-
return self
95+
super().__init__(cover=cover, clustering=clustering, n_jobs=n_jobs)

src/tdamapper/core.py

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from joblib import Parallel, delayed
3434

3535
from tdamapper.utils.unionfind import UnionFind
36-
from tdamapper._common import ParamsMixin, clone
36+
from tdamapper._common import ParamsMixin, EstimatorMixin, clone
3737

3838

3939
ATTR_IDS = 'ids'
@@ -364,7 +364,53 @@ def apply(self, X):
364364
yield list(range(0, len(X)))
365365

366366

367-
class MapperAlgorithm(ParamsMixin):
367+
class _MapperAlgorithm:
    """
    Construction of the Mapper graph, without any validation of the input.

    :param cover: The cover algorithm. When ``None`` a :class:`TrivialCover`
        is created at fit time.
    :param clustering: The clustering algorithm. When ``None`` a
        :class:`TrivialClustering` is created at fit time.
    :param failsafe: When true, the clustering is wrapped in a
        :class:`FailSafeClustering` before being used.
    :param verbose: Forwarded as ``verbose`` to :class:`FailSafeClustering`.
    :param n_jobs: Forwarded as ``n_jobs`` to ``mapper_graph``.
    """

    def __init__(
        self,
        cover=None,
        clustering=None,
        failsafe=True,
        verbose=True,
        n_jobs=1,
    ):
        self.cover = cover
        self.clustering = clustering
        self.failsafe = failsafe
        self.verbose = verbose
        self.n_jobs = n_jobs

    def fit(self, X, y=None):
        """
        Build the Mapper graph and store it in the attribute ``graph_``.

        The cover and the (possibly wrapped) clustering are cloned before
        use. The resolved settings are also kept on the instance in private,
        name-mangled attributes.

        :param X: The dataset.
        :param y: The values passed along with ``X``; ``X`` itself is used
            when ``None``.
        :return: The object itself.
        """
        cover = TrivialCover() if self.cover is None else self.cover
        clustering = (
            TrivialClustering() if self.clustering is None
            else self.clustering
        )
        verbose = self.verbose
        failsafe = self.failsafe
        if failsafe:
            clustering = FailSafeClustering(
                clustering=clustering,
                verbose=verbose,
            )

        # These double-underscore names are mangled by Python to
        # ``_MapperAlgorithm__*``; they are kept because code outside this
        # block may read them.
        self.__verbose = verbose
        self.__failsafe = failsafe
        self.__cover = clone(cover)
        self.__clustering = clone(clustering)
        self.__n_jobs = self.n_jobs

        y = X if y is None else y
        self.graph_ = mapper_graph(
            X,
            y,
            self.__cover,
            self.__clustering,
            n_jobs=self.__n_jobs,
        )
        return self

    def fit_transform(self, X, y=None):
        """
        Fit on the given data and return the resulting Mapper graph.

        ``y`` is optional, consistently with :meth:`fit`.

        :param X: The dataset.
        :param y: The values passed along with ``X``; ``X`` itself is used
            when ``None``.
        :return: The content of ``graph_`` after fitting.
        """
        self.fit(X, y)
        return self.graph_
411+
412+
413+
class MapperAlgorithm(EstimatorMixin, _MapperAlgorithm, ParamsMixin):
368414
"""
369415
A class for creating and analyzing Mapper graphs.
370416
@@ -412,11 +458,13 @@ def __init__(
412458
verbose=True,
413459
n_jobs=1,
414460
):
415-
self.cover = cover
416-
self.clustering = clustering
417-
self.failsafe = failsafe
418-
self.verbose = verbose
419-
self.n_jobs = n_jobs
461+
super().__init__(
462+
cover=cover,
463+
clustering=clustering,
464+
failsafe=failsafe,
465+
verbose=verbose,
466+
n_jobs=n_jobs,
467+
)
420468

421469
def fit(self, X, y=None):
422470
"""
@@ -431,29 +479,7 @@ def fit(self, X, y=None):
431479
:type y: array-like of shape (n, k) or list-like of length n
432480
:return: The object itself.
433481
"""
434-
self.__cover = TrivialCover() if self.cover is None \
435-
else self.cover
436-
self.__clustering = TrivialClustering() if self.clustering is None \
437-
else self.clustering
438-
self.__verbose = self.verbose
439-
self.__failsafe = self.failsafe
440-
if self.__failsafe:
441-
self.__clustering = FailSafeClustering(
442-
clustering=self.__clustering,
443-
verbose=self.__verbose,
444-
)
445-
self.__cover = clone(self.__cover)
446-
self.__clustering = clone(self.__clustering)
447-
self.__n_jobs = self.n_jobs
448-
y = X if y is None else y
449-
self.graph_ = mapper_graph(
450-
X,
451-
y,
452-
self.__cover,
453-
self.__clustering,
454-
n_jobs=self.__n_jobs,
455-
)
456-
return self
482+
return super().fit(X, y)
457483

458484
def fit_transform(self, X, y):
459485
"""
@@ -469,8 +495,7 @@ def fit_transform(self, X, y):
469495
:return: The Mapper graph.
470496
:rtype: :class:`networkx.Graph`
471497
"""
472-
self.fit(X, y)
473-
return self.graph_
498+
return super().fit_transform(X, y)
474499

475500

476501
class FailSafeClustering(ParamsMixin):

tests/test_unit_sklearn.py

Lines changed: 6 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55

6-
76
from sklearn.utils.estimator_checks import check_estimator
87

98
from tdamapper.core import MapperAlgorithm
@@ -21,57 +20,6 @@ def euclidean(x, y):
2120
return np.linalg.norm(x - y)
2221

2322

24-
class ValidationMixin:
25-
26-
def _is_sparse(self, X):
27-
# in alternative use scipy.sparse.issparse
28-
return hasattr(X, 'toarray')
29-
30-
def _validate_X_y(self, X, y):
31-
if self._is_sparse(X):
32-
raise ValueError('Sparse data not supported.')
33-
34-
if X.size == 0:
35-
msg = f'0 feature(s) (shape={X.shape}) while a minimum of 1 is required.'
36-
raise ValueError(msg)
37-
38-
if y.size == 0:
39-
msg = f'0 feature(s) (shape={y.shape}) while a minimum of 1 is required.'
40-
raise ValueError(msg)
41-
42-
if X.ndim == 1:
43-
raise ValueError('1d-arrays not supported.')
44-
45-
if np.iscomplexobj(X) or np.iscomplexobj(y):
46-
raise ValueError('Complex data not supported.')
47-
48-
if X.dtype == np.object_:
49-
X = np.array(X, dtype=float)
50-
51-
if y.dtype == np.object_:
52-
y = np.array(y, dtype=float)
53-
54-
if np.isnan(X).any() or np.isinf(X).any() or \
55-
np.isnan(y).any() or np.isinf(y).any():
56-
raise ValueError('NaNs or infinite values not supported.')
57-
58-
return X, y
59-
60-
def fit(self, X, y=None):
61-
X, y = self._validate_X_y(X, y)
62-
res = super().fit(X, y)
63-
self.n_features_in_ = X.shape[1]
64-
return res
65-
66-
67-
class MapperEstimator(ValidationMixin, MapperAlgorithm):
68-
pass
69-
70-
71-
class MapperClusteringEstimator(ValidationMixin, MapperClustering):
72-
pass
73-
74-
7523
class TestSklearn(unittest.TestCase):
7624

7725
setup_logging()
@@ -83,25 +31,25 @@ def run_tests(self, estimator):
8331
check(est)
8432

8533
def test_trivial(self):
86-
est = MapperEstimator()
34+
est = MapperAlgorithm()
8735
self.run_tests(est)
8836

8937
def test_ball(self):
90-
est = MapperEstimator(cover=BallCover(metric=euclidean))
38+
est = MapperAlgorithm(cover=BallCover(metric=euclidean))
9139
self.run_tests(est)
9240

9341
def test_knn(self):
94-
est = MapperEstimator(cover=KNNCover(metric=euclidean))
42+
est = MapperAlgorithm(cover=KNNCover(metric=euclidean))
9543
self.run_tests(est)
9644

9745
def test_cubical(self):
98-
est = MapperEstimator(cover=CubicalCover())
46+
est = MapperAlgorithm(cover=CubicalCover())
9947
self.run_tests(est)
10048

10149
def test_clustering_trivial(self):
102-
est = MapperClusteringEstimator()
50+
est = MapperClustering()
10351
self.run_tests(est)
10452

10553
def test_clustering_ball(self):
106-
est = MapperClusteringEstimator(cover=BallCover(metric=euclidean))
54+
est = MapperClustering(cover=BallCover(metric=euclidean))
10755
self.run_tests(est)

0 commit comments

Comments
 (0)