Skip to content

Commit 8d9f0ca

Browse files
committed
modAL.models.BayesianOptimizer._check_max() tested
1 parent 694337b commit 8d9f0ca

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

tests/core_tests.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,6 @@ def test_set_max(self):
438438
# case 1: the estimator is not fitted yet
439439
regressor = mock.MockClassifier()
440440
learner = modAL.models.BayesianOptimizer(estimator=regressor)
441-
#learner._set_max()
442441
self.assertEqual(None, learner.max_val)
443442

444443
# case 2: the estimator is fitted already
@@ -452,9 +451,53 @@ def test_set_max(self):
452451
estimator=regressor,
453452
X_training=X, y_training=y
454453
)
455-
#learner._set_max()
456454
np.testing.assert_almost_equal(max_val, learner.max_val)
457455

456+
def test_check_max(self):
457+
for n_reps in range(100):
458+
# case 1: the learner is not fitted yet
459+
for n_samples in range(1, 10):
460+
y = np.random.rand(n_samples)
461+
regressor = mock.MockClassifier()
462+
learner = modAL.models.BayesianOptimizer(estimator=regressor)
463+
learner._check_max(y)
464+
self.assertEqual(learner.max_val, None)
465+
466+
# case 2: new value is not a maximum
467+
for n_samples in range(1, 10):
468+
X = np.random.rand(n_samples, 2)
469+
y = np.random.rand(n_samples)
470+
471+
regressor = mock.MockClassifier()
472+
learner = modAL.models.BayesianOptimizer(
473+
estimator=regressor,
474+
X_training=X, y_training=y
475+
)
476+
477+
y_new = y - np.random.rand()
478+
old_max = learner.max_val
479+
learner._check_max(y_new)
480+
np.testing.assert_almost_equal(old_max, learner.max_val)
481+
482+
# case 3: new value is a maximum
483+
for n_samples in range(1, 10):
484+
X = np.random.rand(n_samples, 2)
485+
y = np.random.rand(n_samples)
486+
487+
regressor = mock.MockClassifier()
488+
learner = modAL.models.BayesianOptimizer(
489+
estimator=regressor,
490+
X_training=X, y_training=y
491+
)
492+
493+
y_new = y + np.random.rand()
494+
learner._check_max(y_new)
495+
np.testing.assert_almost_equal(np.max(y_new), learner.max_val)
496+
497+
498+
499+
500+
458501

459502
class TestCommittee(unittest.TestCase):
460503

0 commit comments

Comments
 (0)