Skip to content

Commit b07de8a

Browse files
authored
Merge pull request #178 from lucasimi/bugfix/nested-params
Fixed clone
2 parents 1876d57 + b2fbacd commit b07de8a

File tree

4 files changed

129 additions and 38 deletions

src/tdamapper/_common.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,18 @@ def set_params(self, **params):
126126
return self
127127

128128
def __repr__(self):
    """
    Build a constructor-style representation of the object.

    Every attribute accepted by ``__is_param_public`` is compared, by repr,
    with the same attribute on an instance created without arguments. Only
    the attributes whose repr differs are listed, so an object built with
    no arguments is rendered as ``ClassName()``.

    :return: A string of the form ``ClassName(k1=v1, k2=v2)``.
    :rtype: str
    """
    # Sentinel marking attributes that the default instance does not define.
    missing = object()
    obj_noargs = type(self)()
    args_repr = []
    for k, v in self.__dict__.items():
        # Filter first: attributes that are not public parameters (for
        # example state stored by fit) may be absent on a fresh instance,
        # and looking them up there would raise AttributeError.
        if not self.__is_param_public(k):
            continue
        v_repr = repr(v)
        v_default = getattr(obj_noargs, k, missing)
        # Reprs are compared rather than the values themselves; an
        # attribute with no counterpart on the default instance is always
        # reported.
        if v_default is missing or v_repr != repr(v_default):
            args_repr.append(f'{k}={v_repr}')
    return f"{self.__class__.__name__}({', '.join(args_repr)})"
137138

138139

139-
def clone(estimator):
140+
def clone(obj):
140141
"""
141142
Clone an estimator, returning a new one, unfitted, having the same public
142143
parameters.
@@ -146,5 +147,7 @@ def clone(estimator):
146147
:return: A new estimator with the same parameters.
147148
:rtype: A scikit-learn compatible estimator
148149
"""
149-
params = estimator.get_params(deep=True)
150-
return type(estimator)(**params)
150+
params = obj.get_params(deep=True)
151+
obj_noargs = type(obj)()
152+
obj_noargs.set_params(**params)
153+
return obj_noargs

src/tdamapper/cover.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -512,24 +512,20 @@ def __init__(
512512
self.leaf_capacity = leaf_capacity
513513
self.leaf_radius = leaf_radius
514514
self.pivoting = pivoting
515-
if algorithm == 'proximity':
516-
self.__cubical_cover = ProximityCubicalCover(
517-
n_intervals=n_intervals,
518-
overlap_frac=overlap_frac,
519-
kind=kind,
520-
leaf_capacity=leaf_capacity,
521-
leaf_radius=leaf_radius,
522-
pivoting=pivoting,
523-
)
524-
elif algorithm == 'standard':
525-
self.__cubical_cover = StandardCubicalCover(
526-
n_intervals=n_intervals,
527-
overlap_frac=overlap_frac,
528-
kind=kind,
529-
leaf_capacity=leaf_capacity,
530-
leaf_radius=leaf_radius,
531-
pivoting=pivoting,
532-
)
515+
516+
def __get_cubical_cover(self):
517+
params = dict(
518+
n_intervals=self.n_intervals,
519+
overlap_frac=self.overlap_frac,
520+
kind=self.kind,
521+
leaf_capacity=self.leaf_capacity,
522+
leaf_radius=self.leaf_radius,
523+
pivoting=self.pivoting,
524+
)
525+
if self.algorithm == 'proximity':
526+
return ProximityCubicalCover(**params)
527+
elif self.algorithm == 'standard':
528+
return StandardCubicalCover(**params)
533529
else:
534530
raise ValueError(
535531
"The only possible values for algorithm are 'standard' and "
@@ -548,7 +544,9 @@ def fit(self, X):
548544
:return: The object itself.
549545
:rtype: self
550546
"""
551-
return self.__cubical_cover.fit(X)
547+
self.__cubical_cover = self.__get_cubical_cover()
548+
self.__cubical_cover.fit(X)
549+
return self
552550

553551
def search(self, x):
554552
"""
@@ -576,4 +574,5 @@ def apply(self, X):
576574
:return: A generator of lists of ids.
577575
:rtype: generator of lists of ints
578576
"""
577+
self.__cubical_cover = self.__get_cubical_cover()
579578
return self.__cubical_cover.apply(X)

tests/test_unit_core.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
mapper_labels,
1111
TrivialCover,
1212
)
13-
from tdamapper.cover import BallCover
13+
from tdamapper.cover import BallCover, CubicalCover, ProximityCubicalCover, StandardCubicalCover
1414
from tdamapper.clustering import TrivialClustering
1515

1616

@@ -80,7 +80,7 @@ def test_ball_large_radius(self):
8080
ccs2 = mapper_connected_components(data, data, cover, clustering)
8181
self.assertEqual(len(data), len(ccs2))
8282

83-
def test_two_disconnected_clusters(self):
83+
def test_ball_two_disconnected_clusters(self):
8484
data = [np.array([float(i), 0.0]) for i in range(100)]
8585
data.extend([np.array([float(i), 500.0]) for i in range(100)])
8686
data = np.array(data)
@@ -98,7 +98,7 @@ def test_two_disconnected_clusters(self):
9898
ccs2 = mapper_connected_components(data, data, cover, clustering)
9999
self.assertEqual(len(data), len(ccs2))
100100

101-
def test_two_connected_clusters(self):
101+
def test_ball_two_connected_clusters(self):
102102
data = [
103103
np.array([0.0, 1.0]), np.array([1.0, 0.0]),
104104
np.array([0.0, 0.0]), np.array([1.0, 1.0])]
@@ -116,7 +116,7 @@ def test_two_connected_clusters(self):
116116
ccs2 = mapper_connected_components(data, data, cover, clustering)
117117
self.assertEqual(len(data), len(ccs2))
118118

119-
def test_two_connected_clusters_parallel(self):
119+
def test_ball_two_connected_clusters_parallel(self):
120120
data = [
121121
np.array([0.0, 1.0]), np.array([1.0, 0.0]),
122122
np.array([0.0, 0.0]), np.array([1.0, 1.0])]
@@ -136,7 +136,23 @@ def test_two_connected_clusters_parallel(self):
136136
ccs2 = mapper_connected_components(data, data, cover, clustering)
137137
self.assertEqual(len(data), len(ccs2))
138138

139-
def test_connected_components(self):
139+
def test_proximity_cubical_line(self):
    # 1000 equally spaced points on a line, used both as data and as lens;
    # the test expects a Mapper graph with exactly 4 nodes.
    points = np.arange(1000, dtype=float).reshape(-1, 1)
    mapper = MapperAlgorithm(
        ProximityCubicalCover(n_intervals=4, overlap_frac=0.5),
        TrivialClustering(),
    )
    graph = mapper.fit_transform(points, points)
    self.assertEqual(4, len(graph.nodes))
146+
147+
def test_cubical_line(self):
    # Same line data set as the proximity test, run through CubicalCover;
    # the test expects a Mapper graph with exactly 4 nodes.
    points = np.arange(1000, dtype=float).reshape(-1, 1)
    mapper = MapperAlgorithm(
        CubicalCover(n_intervals=4, overlap_frac=0.5),
        TrivialClustering(),
    )
    graph = mapper.fit_transform(points, points)
    self.assertEqual(4, len(graph.nodes))
154+
155+
def test_mock_connected_components(self):
140156
data = [0, 1, 2, 3]
141157

142158
class MockCover:
@@ -156,7 +172,7 @@ def apply(self, X):
156172
self.assertEqual(cc0, ccs[2])
157173
self.assertEqual(cc0, ccs[3])
158174

159-
def test_labels(self):
175+
def test_mock_labels(self):
160176
data = [0, 1, 2, 3]
161177

162178
class MockCover:

tests/test_unit_params.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import unittest
22

3+
from sklearn.cluster import DBSCAN
4+
5+
from tdamapper._common import clone
36
from tdamapper.core import MapperAlgorithm
47
from tdamapper.cover import (
58
BallCover,
@@ -11,7 +14,23 @@
1114

1215
class TestParams(unittest.TestCase):
1316

14-
def test_params_mapper(self):
17+
def __test_clone(self, obj):
    """
    Check that cloning ``obj`` yields an object with the same repr.

    :param obj: The object to clone.
    """
    obj_repr = repr(obj)
    obj_cln = clone(obj)
    cln_repr = repr(obj_cln)
    # assertEquals is a deprecated alias that Python 3.12 removed.
    self.assertEqual(obj_repr, cln_repr)
22+
23+
def __test_repr(self, obj):
    """
    Check that evaluating ``repr(obj)`` rebuilds an object with the same
    repr.

    :param obj: The object whose repr is round-tripped.
    """
    obj_repr = repr(obj)
    # eval is applied only to reprs of objects built by these tests; do not
    # reuse this helper on text coming from an untrusted source.
    _obj = eval(obj_repr)
    _obj_repr = repr(_obj)
    # assertEquals is a deprecated alias that Python 3.12 removed.
    self.assertEqual(obj_repr, _obj_repr)
28+
29+
def __test_clone_and_repr(self, obj):
    """
    Run the clone check followed by the repr round-trip check on ``obj``.

    :param obj: The object under test.
    """
    for check in (self.__test_clone, self.__test_repr):
        check(obj)
32+
33+
def test_params_mapper_algorithm(self):
1534
est = MapperAlgorithm(
1635
cover=CubicalCover(
1736
n_intervals=3,
@@ -30,7 +49,7 @@ def test_params_mapper(self):
3049
self.assertEquals(2, params['cover__n_intervals'])
3150
self.assertEquals(0.2, params['cover__overlap_frac'])
3251

33-
def test_params_clust(self):
52+
def test_params_mapper_clustering(self):
3453
est = MapperClustering(
3554
cover=CubicalCover(
3655
n_intervals=3,
@@ -47,4 +66,58 @@ def test_params_clust(self):
4766
params = est.get_params()
4867
self.assertEquals(10, len(params))
4968
self.assertEquals(2, params['cover__n_intervals'])
50-
self.assertEquals(0.2, params['cover__overlap_frac'])
69+
self.assertEquals(0.2, params['cover__overlap_frac'])
70+
71+
def test_clone_and_repr_ball_cover(self):
    # Default-constructed cover first, then one with every parameter set.
    self.__test_clone_and_repr(BallCover())
    custom_params = {
        'radius': 2.0,
        'metric': 'test',
        'metric_params': {'f': 4},
        'kind': 'kind_test',
        'leaf_capacity': 3.0,
        'leaf_radius': -2.0,
        'pivoting': 7,
    }
    self.__test_clone_and_repr(BallCover(**custom_params))
82+
83+
def test_clone_and_repr_cubical_cover(self):
    # Default-constructed cover first, then one with every parameter set.
    self.__test_clone_and_repr(CubicalCover())
    custom_params = {
        'n_intervals': 4,
        'overlap_frac': 5,
        'algorithm': 'algo_test',
        'kind': 'simple',
        'leaf_radius': 5,
        'leaf_capacity': 6,
        'pivoting': 'no',
    }
    self.__test_clone_and_repr(CubicalCover(**custom_params))
94+
95+
def test_clone_repr_mapper_algorithm(self):
    # Default-constructed algorithm first.
    self.__test_clone_and_repr(MapperAlgorithm())
    # Then one with nested estimators and every own parameter set.
    nested_cover = CubicalCover(n_intervals=3, overlap_frac=0.3)
    nested_clustering = DBSCAN(eps='none', min_samples=5.4)
    mapper = MapperAlgorithm(
        cover=nested_cover,
        clustering=nested_clustering,
        failsafe=4,
        n_jobs='foo',
        verbose=4,
    )
    self.__test_clone_and_repr(mapper)
110+
111+
def test_clone_repr_mapper_clustering(self):
    # Default-constructed clustering first.
    self.__test_clone_and_repr(MapperClustering())
    # Then one with nested estimators and n_jobs set.
    nested_cover = CubicalCover(n_intervals=3, overlap_frac=0.3)
    nested_clustering = DBSCAN(eps='none', min_samples=5.4)
    mapper_clustering = MapperClustering(
        cover=nested_cover,
        clustering=nested_clustering,
        n_jobs='foo',
    )
    self.__test_clone_and_repr(mapper_clustering)

0 commit comments

Comments
 (0)