Skip to content

Commit 2e22d09

Browse files
committed
modAL.models.BayesianOptimizer .teach(), ._set_max() tests rewritten according to the changes
1 parent 9b56bdf commit 2e22d09

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

tests/core_tests.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def test_set_max(self):
438438
# case 1: the estimator is not fitted yet
439439
regressor = mock.MockClassifier()
440440
learner = modAL.models.BayesianOptimizer(estimator=regressor)
441-
self.assertEqual(None, learner.max_val)
441+
self.assertEqual(-np.inf, learner.max_val)
442442

443443
# case 2: the estimator is fitted already
444444
for n_samples in range(1, 100):
@@ -453,15 +453,15 @@ def test_set_max(self):
453453
)
454454
np.testing.assert_almost_equal(max_val, learner.max_val)
455455

456-
def test_check_max(self):
456+
def test_set_new_max(self):
457457
for n_reps in range(100):
458458
# case 1: the learner is not fitted yet
459459
for n_samples in range(1, 10):
460460
y = np.random.rand(n_samples)
461461
regressor = mock.MockClassifier()
462462
learner = modAL.models.BayesianOptimizer(estimator=regressor)
463-
learner._check_max(y)
464-
self.assertEqual(learner.max_val, None)
463+
learner._set_max(y)
464+
self.assertEqual(learner.max_val, np.max(y))
465465

466466
# case 2: new value is not a maximum
467467
for n_samples in range(1, 10):
@@ -476,7 +476,7 @@ def test_check_max(self):
476476

477477
y_new = y - np.random.rand()
478478
old_max = learner.max_val
479-
learner._check_max(y_new)
479+
learner._set_max(y_new)
480480
np.testing.assert_almost_equal(old_max, learner.max_val)
481481

482482
# case 3: new value is a maximum
@@ -491,14 +491,10 @@ def test_check_max(self):
491491
)
492492

493493
y_new = y + np.random.rand()
494-
learner._check_max(y_new)
494+
learner._set_max(y_new)
495495
np.testing.assert_almost_equal(np.max(y_new), learner.max_val)
496496

497497

498-
499-
500-
501-
502498
class TestCommittee(unittest.TestCase):
503499

504500
def test_set_classes(self):
@@ -649,4 +645,4 @@ def test_examples(self):
649645

650646

651647
if __name__ == '__main__':
652-
unittest.main()
648+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)