Skip to content

Commit fd269ab

Browse files
committed
Bayesian optimization example added
1 parent 40ebcb0 commit fd269ab

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

examples/bayesian_optimization.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
from sklearn.gaussian_process import GaussianProcessRegressor
4+
from sklearn.gaussian_process.kernels import Matern
5+
from modAL.models import BayesianOptimizer
6+
from modAL.acquisition import PI, EI, UCB, max_PI, max_EI, max_UCB
7+
8+
9+
# generating the data
10+
X = np.linspace(0, 20, 1000).reshape(-1, 1)
11+
y = np.sin(X)/2 - ((10 - X)**2)/50 + 2
12+
13+
# assembling initial training set
14+
X_initial, y_initial = X[150].reshape(1, -1), y[150].reshape(1, -1)
15+
16+
# defining the kernel for the Gaussian process
17+
kernel = Matern(length_scale=1.0)
18+
19+
# initializing the optimizer
20+
optimizer = BayesianOptimizer(
21+
estimator=GaussianProcessRegressor(kernel=kernel),
22+
X_training=X_initial, y_training=y_initial,
23+
query_strategy=max_EI
24+
)
25+
26+
# plotting the initial estimation
27+
with plt.style.context('seaborn-white'):
28+
plt.figure(figsize=(30, 6))
29+
for n_query in range(5):
30+
# plot current prediction
31+
plt.subplot(2, 5, n_query + 1)
32+
plt.title('Query no. %d' %(n_query + 1))
33+
if n_query == 0:
34+
plt.ylabel('Predictions')
35+
plt.ylim([-1.5, 3])
36+
pred, std = optimizer.predict(X.reshape(-1, 1), return_std=True)
37+
utility = EI(optimizer, X)
38+
plt.plot(X, pred)
39+
plt.fill_between(X.reshape(-1, ), pred.reshape(-1, ) - std, pred.reshape(-1, ) + std, alpha=0.2)
40+
plt.plot(X, y, c='k', linewidth=3)
41+
# plotting acquired values
42+
plt.scatter(optimizer.X_training[-1], optimizer.y_training[-1], c='w', s=40, zorder=20)
43+
plt.scatter(optimizer.X_training, optimizer.y_training, c='k', s=80, zorder=1)
44+
45+
plt.subplot(2, 5, 5 + n_query + 1)
46+
if n_query == 0:
47+
plt.ylabel('Expected improvement')
48+
plt.plot(X, 5*utility, c='r')
49+
plt.ylim([-0.1, 1])
50+
51+
# query
52+
query_idx, query_inst = optimizer.query(X)
53+
optimizer.teach(X[query_idx].reshape(1, -1), y[query_idx].reshape(1, -1))
54+
plt.show()

0 commit comments

Comments
 (0)