Skip to content

Commit cac4c1e

Browse files
authored
Standardize _get_scores() to return (np.ndarray, np.ndarray) across all strategies (#197)
1 parent ae33df7 commit cac4c1e

17 files changed

Lines changed: 645 additions & 158 deletions

libact/base/interfaces.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -39,16 +39,27 @@ def update(self, entry_id, label):
3939
pass
4040

4141
def _get_scores(self):
42-
"""Return the score used for making query, the larger the better. Read-only.
42+
"""Return acquisition scores for all unlabeled samples.
4343
44-
No modification to the internal states.
44+
Subclasses should override this method to enable batch mode queries
45+
and score-based strategy composition.
4546
4647
Returns
4748
-------
48-
(ask_id, scores): list of tuple (int, float)
49-
The index of the next unlabeled sample to be queried and the score assigned.
49+
entry_ids : np.ndarray, shape (n_unlabeled,)
50+
Global entry IDs of unlabeled samples.
51+
scores : np.ndarray, shape (n_unlabeled,)
52+
Acquisition scores. Higher = more informative.
53+
54+
Raises
55+
------
56+
NotImplementedError
57+
If the strategy does not support per-sample scoring.
5058
"""
51-
pass
59+
raise NotImplementedError(
60+
f"{self.__class__.__name__} does not implement _get_scores(). "
61+
"This is required for batch mode and score-based composition."
62+
)
5263

5364
@abstractmethod
5465
def make_query(self):

libact/query_strategies/bald.py

Lines changed: 16 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -186,52 +186,14 @@ def update(self, entry_id, label):
186186
# Retrain ensemble with the new labeled data
187187
self._train_ensemble()
188188

189-
@inherit_docstring_from(QueryStrategy)
190-
def make_query(self):
191-
dataset = self.dataset
192-
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
193-
X_pool = np.asarray(X_pool)
194-
195-
if len(unlabeled_entry_ids) == 0:
196-
raise ValueError("No unlabeled samples available")
197-
198-
# Get predictions from all models
199-
all_proba = []
200-
for model in self.models:
201-
proba = model.predict_proba(X_pool)
202-
all_proba.append(np.asarray(proba))
203-
204-
all_proba = np.array(all_proba) # shape: (n_models, n_samples, n_classes)
205-
206-
# Calculate BALD score: H[mean(P)] - mean(H[P])
207-
# Mean probability across ensemble
208-
mean_proba = np.mean(all_proba, axis=0) # shape: (n_samples, n_classes)
209-
210-
# Entropy of mean predictions (total uncertainty)
211-
entropy_mean = self._entropy(mean_proba) # shape: (n_samples,)
212-
213-
# Mean entropy across models (expected data uncertainty)
214-
entropies = np.array([self._entropy(p) for p in all_proba]) # shape: (n_models, n_samples)
215-
mean_entropy = np.mean(entropies, axis=0) # shape: (n_samples,)
216-
217-
# BALD score = mutual information
218-
bald_scores = entropy_mean - mean_entropy # shape: (n_samples,)
219-
220-
# Select sample with highest BALD score (break ties randomly)
221-
max_score = np.max(bald_scores)
222-
candidates = np.where(np.isclose(bald_scores, max_score))[0]
223-
selected_idx = self.random_state_.choice(candidates)
224-
225-
return unlabeled_entry_ids[selected_idx]
226-
227189
def _get_scores(self):
228190
"""Return BALD scores for all unlabeled samples."""
229191
dataset = self.dataset
230192
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
231193
X_pool = np.asarray(X_pool)
232194

233195
if len(unlabeled_entry_ids) == 0:
234-
return []
196+
return np.array([], dtype=int), np.array([], dtype=float)
235197

236198
# Get predictions from all models
237199
all_proba = np.array([
@@ -245,4 +207,18 @@ def _get_scores(self):
245207
mean_entropy = np.mean(entropies, axis=0)
246208
bald_scores = entropy_mean - mean_entropy
247209

248-
return list(zip(unlabeled_entry_ids, bald_scores))
210+
return np.asarray(unlabeled_entry_ids), bald_scores
211+
212+
@inherit_docstring_from(QueryStrategy)
213+
def make_query(self):
214+
unlabeled_entry_ids, bald_scores = self._get_scores()
215+
216+
if len(unlabeled_entry_ids) == 0:
217+
raise ValueError("No unlabeled samples available")
218+
219+
# Select sample with highest BALD score (break ties randomly)
220+
max_score = np.max(bald_scores)
221+
candidates = np.where(np.isclose(bald_scores, max_score))[0]
222+
selected_idx = self.random_state_.choice(candidates)
223+
224+
return unlabeled_entry_ids[selected_idx]

libact/query_strategies/coreset.py

Lines changed: 28 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -79,67 +79,30 @@ def __init__(self, dataset, **kwargs):
7979
random_state = kwargs.pop('random_state', None)
8080
self.random_state_ = seed_random_state(random_state)
8181

82-
@inherit_docstring_from(QueryStrategy)
83-
def make_query(self):
84-
dataset = self.dataset
85-
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
86-
X_pool = np.asarray(X_pool)
87-
88-
if len(unlabeled_entry_ids) == 0:
89-
raise ValueError("No unlabeled samples available")
90-
91-
# Get labeled data
92-
labeled_entries = dataset.get_labeled_entries()
93-
X_labeled = np.asarray(labeled_entries[0])
94-
95-
# Fallback to random if no labeled data
96-
if len(X_labeled) == 0:
97-
idx = self.random_state_.randint(0, len(unlabeled_entry_ids))
98-
return unlabeled_entry_ids[idx]
99-
100-
# Transform features if transformer is provided
101-
if self.transformer is not None:
102-
X_pool_t = np.asarray(self.transformer.transform(X_pool))
103-
X_labeled_t = np.asarray(self.transformer.transform(X_labeled))
104-
else:
105-
X_pool_t = X_pool
106-
X_labeled_t = X_labeled
107-
108-
# Compute pairwise distances: (n_unlabeled, n_labeled)
109-
dist_matrix = cdist(X_pool_t, X_labeled_t, metric=self.metric)
110-
111-
# For each unlabeled point, find minimum distance to any labeled point
112-
min_distances = np.min(dist_matrix, axis=1)
113-
114-
# Select the unlabeled point with maximum min-distance (farthest)
115-
max_dist = np.max(min_distances)
116-
candidates = np.where(np.isclose(min_distances, max_dist))[0]
117-
selected_idx = self.random_state_.choice(candidates)
118-
119-
return unlabeled_entry_ids[selected_idx]
120-
12182
def _get_scores(self):
12283
"""Return min-distances to labeled set for all unlabeled samples.
12384
12485
Returns
12586
-------
126-
scores : list of (entry_id, score) tuples
127-
Each score is the minimum distance from that unlabeled point
128-
to any labeled point. Higher score means more informative.
87+
entry_ids : np.ndarray, shape (n_unlabeled,)
88+
Global entry IDs of unlabeled samples.
89+
scores : np.ndarray, shape (n_unlabeled,)
90+
Min-distance from each unlabeled point to any labeled point.
91+
Higher score means more informative.
12992
"""
13093
dataset = self.dataset
13194
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
13295
X_pool = np.asarray(X_pool)
13396

13497
if len(unlabeled_entry_ids) == 0:
135-
return []
98+
return np.array([], dtype=int), np.array([], dtype=float)
13699

137100
labeled_entries = dataset.get_labeled_entries()
138101
X_labeled = np.asarray(labeled_entries[0])
139102

140103
if len(X_labeled) == 0:
141-
return list(zip(unlabeled_entry_ids,
142-
[float('inf')] * len(unlabeled_entry_ids)))
104+
return np.asarray(unlabeled_entry_ids), \
105+
np.full(len(unlabeled_entry_ids), float('inf'))
143106

144107
if self.transformer is not None:
145108
X_pool_t = np.asarray(self.transformer.transform(X_pool))
@@ -151,4 +114,23 @@ def _get_scores(self):
151114
dist_matrix = cdist(X_pool_t, X_labeled_t, metric=self.metric)
152115
min_distances = np.min(dist_matrix, axis=1)
153116

154-
return list(zip(unlabeled_entry_ids, min_distances))
117+
return np.asarray(unlabeled_entry_ids), min_distances
118+
119+
@inherit_docstring_from(QueryStrategy)
120+
def make_query(self):
121+
unlabeled_entry_ids, min_distances = self._get_scores()
122+
123+
if len(unlabeled_entry_ids) == 0:
124+
raise ValueError("No unlabeled samples available")
125+
126+
# Fallback to random if no labeled data (scores are all inf)
127+
if np.all(np.isinf(min_distances)):
128+
idx = self.random_state_.randint(0, len(unlabeled_entry_ids))
129+
return unlabeled_entry_ids[idx]
130+
131+
# Select the unlabeled point with maximum min-distance (farthest)
132+
max_dist = np.max(min_distances)
133+
candidates = np.where(np.isclose(min_distances, max_dist))[0]
134+
selected_idx = self.random_state_.choice(candidates)
135+
136+
return unlabeled_entry_ids[selected_idx]

libact/query_strategies/density_weighted_meta.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -99,10 +99,12 @@ def update(self, entry_id, label):
9999
@inherit_docstring_from(QueryStrategy)
100100
def _get_scores(self):
101101
dataset = self.dataset
102-
X, _ = zip(*dataset.data)
103-
scores = self.base_query_strategy._get_scores()
104-
_, X_pool = dataset.get_unlabeled_entries()
105-
unlabeled_entry_ids, base_scores = zip(*scores)
102+
X, _ = dataset.get_entries()
103+
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
104+
105+
if len(unlabeled_entry_ids) == 0:
106+
return np.array([], dtype=int), np.array([], dtype=float)
107+
_, base_scores = self.base_query_strategy._get_scores()
106108

107109
self.clustering_method.fit(X)
108110
pool_cluster = self.clustering_method.predict(X_pool)
@@ -118,13 +120,16 @@ def _get_scores(self):
118120
similarity = np.asarray(similarity)
119121

120122
scores = base_scores * similarity**self.beta
121-
return zip(unlabeled_entry_ids, scores)
123+
return np.asarray(unlabeled_entry_ids), np.asarray(scores)
122124

123125
@inherit_docstring_from(QueryStrategy)
124126
def make_query(self):
125-
dataset = self.dataset
127+
unlabeled_entry_ids, scores = self._get_scores()
128+
129+
if len(unlabeled_entry_ids) == 0:
130+
raise ValueError("No unlabeled samples available")
126131

127-
unlabeled_entry_ids, scores = zip(*self._get_scores())
128-
ask_id = self.random_state_.choice(np.where(scores == np.max(scores))[0])
132+
ask_id = self.random_state_.choice(
133+
np.where(np.isclose(scores, np.max(scores)))[0])
129134

130135
return unlabeled_entry_ids[ask_id]

libact/query_strategies/epsilon_uncertainty_sampling.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -170,18 +170,21 @@ def _get_scores(self):
170170
171171
Returns
172172
-------
173-
scores : list of (entry_id, score) tuples
173+
entry_ids : np.ndarray, shape (n_unlabeled,)
174+
Global entry IDs of unlabeled samples.
175+
scores : np.ndarray, shape (n_unlabeled,)
176+
Uncertainty scores. Higher = more uncertain.
174177
"""
175178
dataset = self.dataset
176179
self.model.train(dataset)
177180
unlabeled_entry_ids, X_pool = dataset.get_unlabeled_entries()
178181
X_pool = np.asarray(X_pool)
179182

180183
if len(unlabeled_entry_ids) == 0:
181-
return []
184+
return np.array([], dtype=int), np.array([], dtype=float)
182185

183186
scores = self._get_uncertainty_scores(X_pool)
184-
return list(zip(unlabeled_entry_ids, scores))
187+
return np.asarray(unlabeled_entry_ids), np.asarray(scores)
185188

186189
@inherit_docstring_from(QueryStrategy)
187190
def make_query(self, return_score=False):
@@ -207,7 +210,8 @@ def make_query(self, return_score=False):
207210
ask_id = unlabeled_entry_ids[selected_idx]
208211

209212
if return_score:
210-
return ask_id, self._get_scores()
213+
entry_ids, scores = self._get_scores()
214+
return ask_id, list(zip(entry_ids, scores))
211215
else:
212216
return ask_id
213217

libact/query_strategies/hintsvm.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -129,10 +129,22 @@ def __init__(self, *args, **kwargs):
129129

130130
self.svm_params['C'] = self.cl
131131

132-
@inherit_docstring_from(QueryStrategy)
133-
def make_query(self):
132+
def _get_scores(self):
133+
"""Return absolute decision values for all unlabeled samples.
134+
135+
Returns
136+
-------
137+
entry_ids : np.ndarray, shape (n_unlabeled,)
138+
Global entry IDs of unlabeled samples.
139+
scores : np.ndarray, shape (n_unlabeled,)
140+
Absolute decision values from HintSVM. Higher = more informative.
141+
"""
134142
dataset = self.dataset
135143
unlabeled_entry_ids, unlabeled_pool = dataset.get_unlabeled_entries()
144+
145+
if len(unlabeled_entry_ids) == 0:
146+
return np.array([], dtype=int), np.array([], dtype=float)
147+
136148
labeled_pool, y = dataset.get_labeled_entries()
137149
if len(np.unique(y)) > 2:
138150
raise ValueError("HintSVM query strategy support binary class "
@@ -155,6 +167,15 @@ def make_query(self):
155167
np.array(unlabeled_pool, dtype=np.float64),
156168
self.svm_params)
157169

158-
p_val = [abs(float(val[0])) for val in p_val]
159-
idx = int(np.argmax(p_val))
170+
scores = np.array([abs(float(val[0])) for val in p_val])
171+
return np.asarray(unlabeled_entry_ids), scores
172+
173+
@inherit_docstring_from(QueryStrategy)
174+
def make_query(self):
175+
unlabeled_entry_ids, scores = self._get_scores()
176+
177+
if len(unlabeled_entry_ids) == 0:
178+
raise ValueError("No unlabeled samples available")
179+
180+
idx = int(np.argmax(scores))
160181
return unlabeled_entry_ids[idx]

libact/query_strategies/information_density.py

Lines changed: 6 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ def _get_scores(self):
197197
X_pool = np.asarray(X_pool)
198198

199199
if len(unlabeled_entry_ids) == 0:
200-
return []
200+
return np.array([], dtype=int), np.array([], dtype=float)
201201

202202
uncertainty = self._uncertainty_scores(X_pool)
203203
# Ensure non-negative uncertainty (ContinuousModel predict_real can
@@ -209,26 +209,22 @@ def _get_scores(self):
209209

210210
scores = uncertainty * (density ** self.beta)
211211

212-
return list(zip(unlabeled_entry_ids, scores))
212+
return np.asarray(unlabeled_entry_ids), scores
213213

214214
@inherit_docstring_from(QueryStrategy)
215215
def make_query(self, return_score=False):
216-
dataset = self.dataset
217-
unlabeled_entry_ids, _ = dataset.get_unlabeled_entries()
216+
entry_ids, score_values = self._get_scores()
218217

219-
if len(unlabeled_entry_ids) == 0:
218+
if len(entry_ids) == 0:
220219
raise ValueError("No unlabeled samples available")
221220

222-
scores = self._get_scores()
223-
entry_ids, score_values = zip(*scores)
224-
score_values = np.asarray(list(score_values))
225-
226221
max_score = np.max(score_values)
227222
candidates = np.where(np.isclose(score_values, max_score))[0]
228223
selected_idx = self.random_state_.choice(candidates)
229224

230225
if return_score:
231-
return entry_ids[selected_idx], scores
226+
return entry_ids[selected_idx], \
227+
list(zip(entry_ids, score_values))
232228
else:
233229
return entry_ids[selected_idx]
234230

0 commit comments

Comments
 (0)