Skip to content

Commit 1876d57

Browse files
authored
Merge pull request #177 from lucasimi/bugfix/nested-params
Improved handling of nested params. Added repr method
2 parents f220c41 + ba94681 commit 1876d57

File tree

2 files changed

+88
-6
lines changed

2 files changed

+88
-6
lines changed

src/tdamapper/_common.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,14 @@ class ParamsMixin:
8181
scikit-learn `get_params` and `set_params`.
8282
"""
8383

84-
def __is_param_internal(self, k):
85-
return k.startswith('_') or k.endswith('_')
84+
def __is_param_public(self, k):
85+
return (not k.startswith('_')) and (not k.endswith('_'))
86+
87+
def __split_param(self, k):
88+
k_split = k.split('__')
89+
outer = k_split[0]
90+
inner = '__'.join(k_split[1:])
91+
return outer, inner
8692

8793
def get_params(self, deep=True):
8894
"""
@@ -91,18 +97,44 @@ def get_params(self, deep=True):
9197
:param deep: A flag for returning also nested parameters.
9298
:type deep: bool, optional.
9399
"""
94-
params = self.__dict__.items()
95-
return {k: v for k, v in params if not self.__is_param_internal(k)}
100+
params = {}
101+
for k, v in self.__dict__.items():
102+
if self.__is_param_public(k):
103+
params[k] = v
104+
if hasattr(v, 'get_params') and deep:
105+
for _k, _v in v.get_params().items():
106+
params[f'{k}__{_k}'] = _v
107+
return params
96108

97109
def set_params(self, **params):
98110
"""
99111
Set public parameters. Only updates attributes that already exist.
100112
"""
113+
nested_params = []
101114
for k, v in params.items():
102-
if hasattr(self, k) and not self.__is_param_internal(k):
103-
setattr(self, k, v)
115+
if self.__is_param_public(k):
116+
k_outer, k_inner = self.__split_param(k)
117+
if not k_inner:
118+
if hasattr(self, k_outer):
119+
setattr(self, k_outer, v)
120+
else:
121+
nested_params.append((k_outer, k_inner, v))
122+
for k_outer, k_inner, v in nested_params:
123+
if hasattr(self, k_outer):
124+
k_attr = getattr(self, k_outer)
125+
k_attr.set_params(**{k_inner: v})
104126
return self
105127

128+
def __repr__(self):
129+
obj = type(self)()
130+
rep = f'{self.__class__.__name__}('
131+
for k, v in self.__dict__.items():
132+
obj_v = getattr(obj, k)
133+
if self.__is_param_public(k) and not v == obj_v:
134+
rep += f'{k}={v}, '
135+
rep += ')'
136+
return rep
137+
106138

107139
def clone(estimator):
108140
"""

tests/test_unit_params.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import unittest
2+
3+
from tdamapper.core import MapperAlgorithm
4+
from tdamapper.cover import (
5+
BallCover,
6+
KNNCover,
7+
CubicalCover
8+
)
9+
from tdamapper.clustering import MapperClustering
10+
11+
12+
class TestParams(unittest.TestCase):
13+
14+
def test_params_mapper(self):
15+
est = MapperAlgorithm(
16+
cover=CubicalCover(
17+
n_intervals=3,
18+
overlap_frac=0.3,
19+
),
20+
)
21+
params = est.get_params(deep=False)
22+
self.assertEquals(5, len(params))
23+
params = est.get_params()
24+
self.assertEquals(12, len(params))
25+
self.assertEquals(3, params['cover__n_intervals'])
26+
self.assertEquals(0.3, params['cover__overlap_frac'])
27+
est.set_params(cover__n_intervals=2, cover__overlap_frac=0.2)
28+
params = est.get_params()
29+
self.assertEquals(12, len(params))
30+
self.assertEquals(2, params['cover__n_intervals'])
31+
self.assertEquals(0.2, params['cover__overlap_frac'])
32+
33+
def test_params_clust(self):
34+
est = MapperClustering(
35+
cover=CubicalCover(
36+
n_intervals=3,
37+
overlap_frac=0.3,
38+
),
39+
)
40+
params = est.get_params(deep=False)
41+
self.assertEquals(3, len(params))
42+
params = est.get_params()
43+
self.assertEquals(10, len(params))
44+
self.assertEquals(3, params['cover__n_intervals'])
45+
self.assertEquals(0.3, params['cover__overlap_frac'])
46+
est.set_params(cover__n_intervals=2, cover__overlap_frac=0.2)
47+
params = est.get_params()
48+
self.assertEquals(10, len(params))
49+
self.assertEquals(2, params['cover__n_intervals'])
50+
self.assertEquals(0.2, params['cover__overlap_frac'])

0 commit comments

Comments
 (0)