Skip to content

Commit 9b56bdf

Browse files
committed
modAL.models.BayesianOptimizer .teach() method added, _set_max(), _check_max() methods refactored: max_val set upon initialization, _check_max() removed, _set_max() takes the job of both
1 parent 8d9f0ca commit 9b56bdf

File tree

1 file changed

+44
-18
lines changed

1 file changed

+44
-18
lines changed

modAL/models.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -164,12 +164,12 @@ def _fit_to_known(self, bootstrap=False, **fit_kwargs):
164164
"""
165165
if not bootstrap:
166166
self.estimator.fit(self.X_training, self.y_training, **fit_kwargs)
167-
return self
168167
else:
169168
n_instances = len(self.X_training)
170169
bootstrap_idx = np.random.choice(range(n_instances), n_instances, replace=True)
171170
self.estimator.fit(self.X_training[bootstrap_idx], self.y_training[bootstrap_idx], **fit_kwargs)
172-
return self
171+
172+
return self
173173

174174
def _fit_on_new(self, X, y, bootstrap=False, **fit_kwargs):
175175
"""
@@ -194,11 +194,11 @@ def _fit_on_new(self, X, y, bootstrap=False, **fit_kwargs):
194194

195195
if not bootstrap:
196196
self.estimator.fit(X, y, **fit_kwargs)
197-
return self
198197
else:
199198
bootstrap_idx = np.random.choice(range(len(X)), len(X), replace=True)
200199
self.estimator.fit(X[bootstrap_idx], y[bootstrap_idx])
201-
return self
200+
201+
return self
202202

203203
def fit(self, X, y, bootstrap=False, **fit_kwargs):
204204
"""
@@ -407,8 +407,7 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
407407
Parameters
408408
----------
409409
X: numpy.ndarray of shape (n_samples, n_features)
410-
The new samples for which the labels are supplied
411-
by the expert.
410+
The new samples for which the labels are supplied by the expert.
412411
413412
y: numpy.ndarray of shape (n_samples, )
414413
Labels corresponding to the new instances in X.
@@ -423,8 +422,7 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
423422
doesn't retrain the model from scratch. (For example, in tensorflow or keras.)
424423
425424
fit_kwargs: keyword arguments
426-
Keyword arguments to be passed to the fit method
427-
of the predictor.
425+
Keyword arguments to be passed to the fit method of the predictor.
428426
"""
429427
self._add_training_data(X, y)
430428
if not only_new:
@@ -436,22 +434,50 @@ def teach(self, X, y, bootstrap=False, only_new=False, **fit_kwargs):
436434
class BayesianOptimizer(BaseLearner):
437435
def __init__(self, *args, **kwargs):
438436
super(BayesianOptimizer, self).__init__(*args, **kwargs)
439-
self._set_max()
440-
441-
def _set_max(self):
437+
# setting the maximum value
442438
if self.y_training is not None:
443439
self.max_val = np.max(self.y_training)
444440
else:
445-
self.max_val = None
441+
self.max_val = -np.inf
446442

447-
def _check_max(self, y):
448-
if self.max_val is not None:
449-
y_max = np.max(y)
450-
if y_max > self.max_val:
451-
self.max_val = y_max
443+
def _set_max(self, y):
444+
y_max = np.max(y)
445+
if y_max > self.max_val:
446+
self.max_val = y_max
452447

453448
def teach(self):
454-
pass
449+
"""
450+
Adds X and y to the known training data and retrains the predictor with the
451+
augmented dataset. This method also keeps track of the maximum value encountered
452+
in the training data.
453+
454+
Parameters
455+
----------
456+
X: numpy.ndarray of shape (n_samples, n_features)
457+
The new samples for which the values are supplied.
458+
459+
y: numpy.ndarray of shape (n_samples, )
460+
Values corresponding to the new instances in X.
461+
462+
bootstrap: boolean
463+
If True, training is done on a bootstrapped dataset. Useful for building
464+
Committee models with bagging.
465+
466+
only_new: boolean
467+
If True, the model is retrained using only X and y, ignoring the previously
468+
provided examples. Useful when working with models where the .fit() method
469+
doesn't retrain the model from scratch. (For example, in tensorflow or keras.)
470+
471+
fit_kwargs: keyword arguments
472+
Keyword arguments to be passed to the fit method of the predictor.
473+
"""
474+
self._add_training_data(X, y)
475+
if not only_new:
476+
self._fit_to_known(bootstrap=bootstrap, **fit_kwargs)
477+
self._set_max(y)
478+
else:
479+
self._fit_on_new(X, y, bootstrap=bootstrap, **fit_kwargs)
480+
self._set_max(y)
455481

456482

457483
class BaseCommittee(ABC, BaseEstimator):

0 commit comments

Comments
 (0)