@@ -164,12 +164,12 @@ def _fit_to_known(self, bootstrap=False, **fit_kwargs):
164164 """
165165 if not bootstrap :
166166 self .estimator .fit (self .X_training , self .y_training , ** fit_kwargs )
167- return self
168167 else :
169168 n_instances = len (self .X_training )
170169 bootstrap_idx = np .random .choice (range (n_instances ), n_instances , replace = True )
171170 self .estimator .fit (self .X_training [bootstrap_idx ], self .y_training [bootstrap_idx ], ** fit_kwargs )
172- return self
171+
172+ return self
173173
174174 def _fit_on_new (self , X , y , bootstrap = False , ** fit_kwargs ):
175175 """
@@ -194,11 +194,11 @@ def _fit_on_new(self, X, y, bootstrap=False, **fit_kwargs):
194194
195195 if not bootstrap :
196196 self .estimator .fit (X , y , ** fit_kwargs )
197- return self
198197 else :
199198 bootstrap_idx = np .random .choice (range (len (X )), len (X ), replace = True )
200199 self .estimator .fit (X [bootstrap_idx ], y [bootstrap_idx ])
201- return self
200+
201+ return self
202202
203203 def fit (self , X , y , bootstrap = False , ** fit_kwargs ):
204204 """
@@ -407,8 +407,7 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
407407 Parameters
408408 ----------
409409 X: numpy.ndarray of shape (n_samples, n_features)
410- The new samples for which the labels are supplied
411- by the expert.
410+ The new samples for which the labels are supplied by the expert.
412411
413412 y: numpy.ndarray of shape (n_samples, )
414413 Labels corresponding to the new instances in X.
@@ -423,8 +422,7 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
423422 doesn't retrain the model from scratch. (For example, in tensorflow or keras.)
424423
425424 fit_kwargs: keyword arguments
426- Keyword arguments to be passed to the fit method
427- of the predictor.
425+ Keyword arguments to be passed to the fit method of the predictor.
428426 """
429427 self ._add_training_data (X , y )
430428 if not only_new :
@@ -436,22 +434,50 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
436434class BayesianOptimizer (BaseLearner ):
437435 def __init__ (self , * args , ** kwargs ):
438436 super (BayesianOptimizer , self ).__init__ (* args , ** kwargs )
439- self ._set_max ()
440-
441- def _set_max (self ):
437+ # setting the maximum value
442438 if self .y_training is not None :
443439 self .max_val = np .max (self .y_training )
444440 else :
445- self .max_val = None
441+ self .max_val = - np . inf
446442
447- def _check_max (self , y ):
448- if self .max_val is not None :
449- y_max = np .max (y )
450- if y_max > self .max_val :
451- self .max_val = y_max
443+ def _set_max (self , y ):
444+ y_max = np .max (y )
445+ if y_max > self .max_val :
446+ self .max_val = y_max
452447
453448 def teach (self ):
454- pass
449+ """
450+ Adds X and y to the known training data and retrains the predictor with the
451+ augmented dataset. This method also keeps track of the maximum value encountered
452+ in the training data.
453+
454+ Parameters
455+ ----------
456+ X: numpy.ndarray of shape (n_samples, n_features)
457+ The new samples for which the values are supplied.
458+
459+ y: numpy.ndarray of shape (n_samples, )
460+ Values corresponding to the new instances in X.
461+
462+ bootstrap: boolean
463+ If True, training is done on a bootstrapped dataset. Useful for building
464+ Committee models with bagging.
465+
466+ only_new: boolean
467+ If True, the model is retrained using only X and y, ignoring the previously
468+ provided examples. Useful when working with models where the .fit() method
469+ doesn't retrain the model from scratch. (For example, in tensorflow or keras.)
470+
471+ fit_kwargs: keyword arguments
472+ Keyword arguments to be passed to the fit method of the predictor.
473+ """
474+ self ._add_training_data (X , y )
475+ if not only_new :
476+ self ._fit_to_known (bootstrap = bootstrap , ** fit_kwargs )
477+ self ._set_max (y )
478+ else :
479+ self ._fit_on_new (X , y , bootstrap = bootstrap , ** fit_kwargs )
480+ self ._set_max (y )
455481
456482
457483class BaseCommittee (ABC , BaseEstimator ):
0 commit comments