Skip to content

Commit 634c326

Browse files
committed
modAL.acquisition EI implemented
1 parent a47e4e9 commit 634c326

File tree

2 files changed

+40
-2
lines changed

2 files changed

+40
-2
lines changed

modAL/acquisition.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,33 @@
77
-----------------------------------------------
88
"""
99

10+
from scipy.stats import norm
1011
from scipy.special import ndtr
1112
from modAL.utils.selection import multi_argmax
1213

1314

1415
def PI(optimizer, X, tradeoff=0):
1516
mean, std = optimizer.predict(X, return_std=True)
17+
1618
return ndtr((mean - optimizer.max_val - tradeoff)/std)
1719

1820

21+
def EI(optimizer, X, tradeoff=0):
22+
mean, std = optimizer.predict(X, return_std=True)
23+
z = (mean - optimizer.max_val - tradeoff)/std
24+
25+
return (mean - optimizer.max_val - tradeoff)*ndtr(z) + std*norm.pdf(z)
26+
27+
1928
def max_PI(optimizer, X, tradeoff=0, n_instances=1):
2029
pi = PI(optimizer, X, tradeoff=tradeoff)
2130
query_idx = multi_argmax(pi, n_instances=n_instances)
2231

23-
return query_idx, X[query_idx]
32+
return query_idx, X[query_idx]
33+
34+
35+
def max_EI(optimizer, X, tradeoff=0, n_instances=1):
36+
ei = EI(optimizer, X, tradeoff=tradeoff)
37+
query_idx = multi_argmax(ei, n_instances=n_instances)
38+
39+
return query_idx, X[query_idx]

tests/core_tests.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import namedtuple
1717
from sklearn.ensemble import RandomForestClassifier
1818
from sklearn.metrics import confusion_matrix
19-
from scipy.stats import entropy
19+
from scipy.stats import entropy, norm
2020
from scipy.special import ndtr
2121

2222

@@ -145,6 +145,28 @@ def test_PI(self):
145145
modAL.acquisition.PI(optimizer, np.random.rand(n_samples, 2), tradeoff)
146146
)
147147

148+
def test_EI(self):
149+
for n_samples in range(1, 100):
150+
mean = np.random.rand(n_samples, )
151+
std = np.random.rand(n_samples, )
152+
tradeoff = np.random.rand()
153+
max_val = np.random.rand()
154+
155+
mock_estimator = mock.MockEstimator(
156+
predict_return=(mean, std)
157+
)
158+
159+
optimizer = modAL.models.BayesianOptimizer(estimator=mock_estimator)
160+
optimizer._set_max([max_val])
161+
162+
true_EI = (mean - optimizer.max_val - tradeoff) * ndtr((mean - optimizer.max_val - tradeoff)/std)\
163+
+ std * norm.pdf((mean - optimizer.max_val - tradeoff)/std)
164+
165+
np.testing.assert_almost_equal(
166+
true_EI,
167+
modAL.acquisition.EI(optimizer, np.random.rand(n_samples, 2), tradeoff)
168+
)
169+
148170

149171
class TestUncertainties(unittest.TestCase):
150172

0 commit comments

Comments
 (0)