|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +from sklearn.gaussian_process import GaussianProcessRegressor |
| 4 | +from sklearn.gaussian_process.kernels import Matern |
| 5 | +from modAL.models import BayesianOptimizer |
| 6 | +from modAL.acquisition import PI, EI, UCB, max_PI, max_EI, max_UCB |
| 7 | + |
| 8 | + |
| 9 | +# generating the data |
| 10 | +X = np.linspace(0, 20, 1000).reshape(-1, 1) |
| 11 | +y = np.sin(X)/2 - ((10 - X)**2)/50 + 2 |
| 12 | + |
| 13 | +# assembling initial training set |
| 14 | +X_initial, y_initial = X[150].reshape(1, -1), y[150].reshape(1, -1) |
| 15 | + |
| 16 | +# defining the kernel for the Gaussian process |
| 17 | +kernel = Matern(length_scale=1.0) |
| 18 | + |
| 19 | +# initializing the optimizer |
| 20 | +optimizer = BayesianOptimizer( |
| 21 | + estimator=GaussianProcessRegressor(kernel=kernel), |
| 22 | + X_training=X_initial, y_training=y_initial, |
| 23 | + query_strategy=max_EI |
| 24 | +) |
| 25 | + |
| 26 | +# plotting the initial estimation |
| 27 | +with plt.style.context('seaborn-white'): |
| 28 | + plt.figure(figsize=(30, 6)) |
| 29 | + for n_query in range(5): |
| 30 | + # plot current prediction |
| 31 | + plt.subplot(2, 5, n_query + 1) |
| 32 | + plt.title('Query no. %d' %(n_query + 1)) |
| 33 | + if n_query == 0: |
| 34 | + plt.ylabel('Predictions') |
| 35 | + plt.ylim([-1.5, 3]) |
| 36 | + pred, std = optimizer.predict(X.reshape(-1, 1), return_std=True) |
| 37 | + utility = EI(optimizer, X) |
| 38 | + plt.plot(X, pred) |
| 39 | + plt.fill_between(X.reshape(-1, ), pred.reshape(-1, ) - std, pred.reshape(-1, ) + std, alpha=0.2) |
| 40 | + plt.plot(X, y, c='k', linewidth=3) |
| 41 | + # plotting acquired values |
| 42 | + plt.scatter(optimizer.X_training[-1], optimizer.y_training[-1], c='w', s=40, zorder=20) |
| 43 | + plt.scatter(optimizer.X_training, optimizer.y_training, c='k', s=80, zorder=1) |
| 44 | + |
| 45 | + plt.subplot(2, 5, 5 + n_query + 1) |
| 46 | + if n_query == 0: |
| 47 | + plt.ylabel('Expected improvement') |
| 48 | + plt.plot(X, 5*utility, c='r') |
| 49 | + plt.ylim([-0.1, 1]) |
| 50 | + |
| 51 | + # query |
| 52 | + query_idx, query_inst = optimizer.query(X) |
| 53 | + optimizer.teach(X[query_idx].reshape(1, -1), y[query_idx].reshape(1, -1)) |
| 54 | + plt.show() |
0 commit comments