Skip to content

Commit 40ebcb0

Browse files
committed
modAL.utils.selection.multi_argmax bugfix, now it works well with objects having shape (n, 1)
1 parent 820037b commit 40ebcb0

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

modAL/utils/selection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def multi_argmax(values, n_instances=1):
2525
"""
2626
assert n_instances <= len(values), 'n_instances must be less or equal than the size of utility'
2727

28-
max_idx = np.argpartition(-values, n_instances-1)[:n_instances]
28+
max_idx = np.argpartition(-values, n_instances-1, axis=0)[:n_instances]
2929
return max_idx
3030

3131

tests/core_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ def test_multi_argmax(self):
366366
for n_instances in range(1, n_pool):
367367
utility = np.zeros(n_pool)
368368
max_idx = np.random.choice(range(n_pool), size=n_instances, replace=False)
369-
utility[max_idx] = 1.0
369+
utility[max_idx] = 1e-10 + np.random.rand(n_instances, )
370370
np.testing.assert_equal(
371371
np.sort(modAL.utils.selection.multi_argmax(utility, n_instances)),
372372
np.sort(max_idx)

0 commit comments

Comments
 (0)