Skip to content

Commit 820037b

Browse files
committed
modAL.acquisition std shape bug fixed
1 parent 6bc0cb7 commit 820037b

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

examples/bayesian_optimization.py

Whitespace-only changes.

modAL/acquisition.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414

1515
def PI(optimizer, X, tradeoff=0):
1616
mean, std = optimizer.predict(X, return_std=True)
17+
std = std.reshape(-1, 1)
1718

1819
return ndtr((mean - optimizer.max_val - tradeoff)/std)
1920

2021

2122
def EI(optimizer, X, tradeoff=0):
2223
mean, std = optimizer.predict(X, return_std=True)
24+
std = std.reshape(-1, 1)
2325
z = (mean - optimizer.max_val - tradeoff)/std
2426

2527
return (mean - optimizer.max_val - tradeoff)*ndtr(z) + std*norm.pdf(z)
@@ -30,6 +32,7 @@ def UCB(optimizer, X, beta=1):
3032
Ref: https://arxiv.org/abs/0912.3995
3133
"""
3234
mean, std = optimizer.predict(X, return_std=True)
35+
std = std.reshape(-1, 1)
3336

3437
return mean + beta*std
3538

tests/core_tests.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ def test_make_query_strategy(self):
128128
class TestAcquisitionFunctions(unittest.TestCase):
129129
def test_PI(self):
130130
for n_samples in range(1, 100):
131-
mean = np.random.rand(n_samples, )
132-
std = np.random.rand(n_samples, )
131+
mean = np.random.rand(n_samples, 1)
132+
std = np.random.rand(n_samples, 1)
133133
tradeoff = np.random.rand()
134134
max_val = np.random.rand()
135135

@@ -147,8 +147,8 @@ def test_PI(self):
147147

148148
def test_EI(self):
149149
for n_samples in range(1, 100):
150-
mean = np.random.rand(n_samples, )
151-
std = np.random.rand(n_samples, )
150+
mean = np.random.rand(n_samples, 1)
151+
std = np.random.rand(n_samples, 1)
152152
tradeoff = np.random.rand()
153153
max_val = np.random.rand()
154154

@@ -169,8 +169,8 @@ def test_EI(self):
169169

170170
def test_UCB(self):
171171
for n_samples in range(1, 100):
172-
mean = np.random.rand(n_samples, )
173-
std = np.random.rand(n_samples, )
172+
mean = np.random.rand(n_samples, 1)
173+
std = np.random.rand(n_samples, 1)
174174
beta = np.random.rand()
175175

176176
mock_estimator = mock.MockEstimator(

0 commit comments

Comments
 (0)