Skip to content

Commit 84e9ce9

Browse files
committed
modAL.models.BayesianOptimizer.teach() tested
1 parent 2b513c9 commit 84e9ce9

File tree

1 file changed

+26
-1
lines changed

1 file changed

+26
-1
lines changed

tests/core_tests.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import modAL.utils.combination
1313

1414
from copy import deepcopy
15-
from itertools import chain
15+
from itertools import chain, product
1616
from collections import namedtuple
1717
from sklearn.ensemble import RandomForestClassifier
1818
from sklearn.metrics import confusion_matrix
@@ -494,6 +494,31 @@ def test_set_new_max(self):
494494
learner._set_max(y_new)
495495
np.testing.assert_almost_equal(np.max(y_new), learner.max_val)
496496

497+
def test_teach(self):
498+
# case 1. optimizer is uninitialized
499+
for bootstrap, only_new in product([True, False], [True, False]):
500+
for n_samples in range(1, 100):
501+
for n_features in range(1, 100):
502+
regressor = mock.MockClassifier()
503+
learner = modAL.models.BayesianOptimizer(estimator=regressor)
504+
505+
X = np.random.rand(n_samples, 2)
506+
y = np.random.rand(n_samples)
507+
learner.teach(X, y, bootstrap=bootstrap, only_new=only_new)
508+
509+
# case 2. optimizer is initialized
510+
for n_samples in range(1, 100):
511+
for n_features in range(1, 100):
512+
X = np.random.rand(n_samples, 2)
513+
y = np.random.rand(n_samples)
514+
515+
regressor = mock.MockClassifier()
516+
learner = modAL.models.BayesianOptimizer(
517+
estimator=regressor,
518+
X_training=X, y_training=y
519+
)
520+
learner.teach(X, y, bootstrap=bootstrap, only_new=only_new)
521+
497522

498523
class TestCommittee(unittest.TestCase):
499524

0 commit comments

Comments
 (0)