33"""Package where the MLBlock class is defined."""
44
55import importlib
6+ import logging
67
78from mlblocks .primitives import load_primitive
89
10+ LOGGER = logging .getLogger (__name__ )
11+
912
1013def import_object (object_name ):
1114 """Import an object from its Fully Qualified Name."""
@@ -83,7 +86,7 @@ def _extract_params(self, kwargs, hyperparameters):
8386 value = param ['default' ]
8487
8588 else :
86- raise TypeError ("Required argument '{}' not found" .format (name ))
89+ raise TypeError ("{} required argument '{}' not found" .format (self . name , name ))
8790
8891 init_params [name ] = value
8992
@@ -107,6 +110,33 @@ def _extract_params(self, kwargs, hyperparameters):
107110
108111 return init_params , fit_params , produce_params
109112
113+ @staticmethod
114+ def _filter_conditional (conditional , init_params ):
115+ condition = conditional ['condition' ]
116+ default = conditional .get ('default' )
117+
118+ if condition not in init_params :
119+ return default
120+
121+ condition_value = init_params [condition ]
122+ values = conditional ['values' ]
123+ return values .get (condition_value , default )
124+
125+ @classmethod
126+ def _get_tunable (cls , hyperparameters , init_params ):
127+ tunable = dict ()
128+ for name , param in hyperparameters .get ('tunable' , dict ()).items ():
129+ if name not in init_params :
130+ if param ['type' ] == 'conditional' :
131+ param = cls ._filter_conditional (param , init_params )
132+ if param is not None :
133+ tunable [name ] = param
134+
135+ else :
136+ tunable [name ] = param
137+
138+ return tunable
139+
110140 def __init__ (self , name , ** kwargs ):
111141
112142 self .name = name
@@ -133,13 +163,7 @@ def __init__(self, name, **kwargs):
133163 self ._fit_params = fit_params
134164 self ._produce_params = produce_params
135165
136- tunable = hyperparameters .get ('tunable' , dict ())
137- self ._tunable = {
138- name : param
139- for name , param in tunable .items ()
140- if name not in init_params
141- # TODO: filter conditionals
142- }
166+ self ._tunable = self ._get_tunable (hyperparameters , init_params )
143167
144168 default = {
145169 name : param ['default' ]
@@ -193,6 +217,7 @@ def set_hyperparameters(self, hyperparameters):
193217 self ._hyperparameters .update (hyperparameters )
194218
195219 if self ._class :
220+ LOGGER .debug ('Creating a new primitive instance for %s' , self .name )
196221 self .instance = self .primitive (** self ._hyperparameters )
197222
198223 def fit (self , ** kwargs ):
0 commit comments