@@ -81,8 +81,14 @@ class ParamsMixin:
8181 scikit-learn `get_params` and `set_params`.
8282 """
8383
84- def __is_param_internal (self , k ):
85- return k .startswith ('_' ) or k .endswith ('_' )
84+ def __is_param_public (self , k ):
85+ return (not k .startswith ('_' )) and (not k .endswith ('_' ))
86+
87+ def __split_param (self , k ):
88+ k_split = k .split ('__' )
89+ outer = k_split [0 ]
90+ inner = '__' .join (k_split [1 :])
91+ return outer , inner
8692
8793 def get_params (self , deep = True ):
8894 """
@@ -91,18 +97,44 @@ def get_params(self, deep=True):
9197 :param deep: A flag for returning also nested parameters.
9298 :type deep: bool, optional.
9399 """
94- params = self .__dict__ .items ()
95- return {k : v for k , v in params if not self .__is_param_internal (k )}
100+ params = {}
101+ for k , v in self .__dict__ .items ():
102+ if self .__is_param_public (k ):
103+ params [k ] = v
104+ if hasattr (v , 'get_params' ) and deep :
105+ for _k , _v in v .get_params ().items ():
106+ params [f'{ k } __{ _k } ' ] = _v
107+ return params
96108
97109 def set_params (self , ** params ):
98110 """
99111 Set public parameters. Only updates attributes that already exist.
100112 """
113+ nested_params = []
101114 for k , v in params .items ():
102- if hasattr (self , k ) and not self .__is_param_internal (k ):
103- setattr (self , k , v )
115+ if self .__is_param_public (k ):
116+ k_outer , k_inner = self .__split_param (k )
117+ if not k_inner :
118+ if hasattr (self , k_outer ):
119+ setattr (self , k_outer , v )
120+ else :
121+ nested_params .append ((k_outer , k_inner , v ))
122+ for k_outer , k_inner , v in nested_params :
123+ if hasattr (self , k_outer ):
124+ k_attr = getattr (self , k_outer )
125+ k_attr .set_params (** {k_inner : v })
104126 return self
105127
128+ def __repr__ (self ):
129+ obj = type (self )()
130+ rep = f'{ self .__class__ .__name__ } ('
131+ for k , v in self .__dict__ .items ():
132+ obj_v = getattr (obj , k )
133+ if self .__is_param_public (k ) and not v == obj_v :
134+ rep += f'{ k } ={ v } , '
135+ rep += ')'
136+ return rep
137+
106138
107139def clone (estimator ):
108140 """
0 commit comments