diff --git a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py index 1579e4ef1b..43c03f4322 100644 --- a/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py +++ b/python/cuvs_bench/cuvs_bench/generate_groundtruth/__main__.py @@ -201,7 +201,8 @@ def calc_truth(dataset, queries, k, metric="sqeuclidean"): else: distances = xp.concatenate([distances, D], axis=1) indices = xp.concatenate([indices, Ind], axis=1) - idx = xp.argsort(distances, axis=1)[:, :k] + sort_keys = -distances if metric == "inner_product" else distances + idx = xp.argsort(sort_keys, axis=1)[:, :k] distances = xp.take_along_axis(distances, idx, axis=1) indices = xp.take_along_axis(indices, idx, axis=1)