Skip to content

Commit 58c71fb

Browse files
committed
modAL.acquisition UCB implemented
1 parent 634c326 commit 58c71fb

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

modAL/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .models import ActiveLearner, Committee, CommitteeRegressor
2-
from .acquisition import PI, max_PI
2+
from .acquisition import PI, EI, UCB, max_PI, max_EI, max_UCB
33
from .uncertainty import classifier_uncertainty, classifier_margin, classifier_entropy, \
44
uncertainty_sampling, margin_sampling, entropy_sampling
55
from .disagreement import vote_entropy, consensus_entropy, KL_max_disagreement, \
@@ -8,7 +8,7 @@
88

99
__all__ = [
1010
'ActiveLearner', 'Committee', 'CommitteeRegressor',
11-
'PI', 'max_PI',
11+
'PI', 'EI', 'UCB', 'max_PI', 'max_EI', 'max_UCB',
1212
'classifier_uncertainty', 'classifier_margin', 'classifier_entropy',
1313
'uncertainty_sampling', 'margin_sampling', 'entropy_sampling',
1414
'vote_entropy', 'consensus_entropy', 'KL_max_disagreement',

modAL/acquisition.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,15 @@ def EI(optimizer, X, tradeoff=0):
2525
return (mean - optimizer.max_val - tradeoff)*ndtr(z) + std*norm.pdf(z)
2626

2727

28+
def UCB(optimizer, X, beta=1):
29+
"""
30+
Ref: https://arxiv.org/abs/0912.3995
31+
"""
32+
mean, std = optimizer.predict(X, return_std=True)
33+
34+
return mean + beta*std
35+
36+
2837
def max_PI(optimizer, X, tradeoff=0, n_instances=1):
2938
pi = PI(optimizer, X, tradeoff=tradeoff)
3039
query_idx = multi_argmax(pi, n_instances=n_instances)
@@ -37,3 +46,10 @@ def max_EI(optimizer, X, tradeoff=0, n_instances=1):
3746
query_idx = multi_argmax(ei, n_instances=n_instances)
3847

3948
return query_idx, X[query_idx]
49+
50+
51+
def max_UCB(optimizer, X, beta=0, n_instances=1):
52+
ucb = UCB(optimizer, X, beta=beta)
53+
query_idx = multi_argmax(ucb, n_instances=n_instances)
54+
55+
return query_idx, X[query_idx]

0 commit comments

Comments
 (0)