Skip to content

Commit 799a9fd

Browse files
committed
fix: track actual selection count in vec0_mmr_rerank
The copy-back loop iterated k_target times, but the greedy selection loop can terminate early via `if (best_idx < 0) break`, leaving the tail of out_rowids/out_distances uninitialized. Add an n_selected counter and out_n_selected output parameter so only actually-selected entries are copied back. The caller now sets k_used = n_selected instead of k_used = k_original. Credit: mceachen (vlasky#6)
1 parent be9c9dd commit 799a9fd

File tree

1 file changed

+13
-6
lines changed

1 file changed

+13
-6
lines changed

sqlite-vec.c

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7001,7 +7001,9 @@ static f32 vec0_compute_distance(struct VectorColumnDefinition *vector_column,
70017001
* cosine similarity between d and any already-selected result.
70027002
*
70037003
* Reorders topk_rowids and topk_distances in place.
7004-
* After return, the first k_target entries are the MMR-selected results.
7004+
* After return, the first *out_n_selected entries are the MMR-selected results.
7005+
* out_n_selected may be less than k_target if the greedy selection exhausts
7006+
* candidates early.
70057007
*/
70067008
static int vec0_mmr_rerank(
70077009
vec0_vtab *p,
@@ -7011,7 +7013,8 @@ static int vec0_mmr_rerank(
70117013
f32 *topk_distances,
70127014
i64 k_used,
70137015
i64 k_target,
7014-
f32 mmr_lambda
7016+
f32 mmr_lambda,
7017+
i64 *out_n_selected
70157018
) {
70167019
int rc = SQLITE_OK;
70177020

@@ -7056,6 +7059,7 @@ static int vec0_mmr_rerank(
70567059
}
70577060
memset(selected, 0, k_used);
70587061

7062+
i64 n_selected = 0;
70597063
for (i64 step = 0; step < k_target && step < k_used; step++) {
70607064
f32 best_mmr = -FLT_MAX;
70617065
i64 best_idx = -1;
@@ -7085,13 +7089,15 @@ static int vec0_mmr_rerank(
70857089
out_rowids[step] = topk_rowids[best_idx];
70867090
out_distances[step] = topk_distances[best_idx];
70877091
out_vectors[step] = vectors[best_idx];
7092+
n_selected++;
70887093
}
70897094

7090-
// 5. Copy results back to input arrays
7091-
for (i64 i = 0; i < k_target; i++) {
7095+
// 5. Copy only the actually-selected entries back to input arrays
7096+
for (i64 i = 0; i < n_selected; i++) {
70927097
topk_rowids[i] = out_rowids[i];
70937098
topk_distances[i] = out_distances[i];
70947099
}
7100+
*out_n_selected = n_selected;
70957101

70967102
cleanup:
70977103
if (vectors) {
@@ -7387,11 +7393,12 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
73877393

73887394
// MMR reranking: select diverse subset from over-fetched candidates
73897395
if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original) {
7396+
i64 n_selected = 0;
73907397
rc = vec0_mmr_rerank(p, vectorColumnIdx, vector_column,
73917398
topk_rowids, topk_distances, k_used, k_original,
7392-
mmr_lambda);
7399+
mmr_lambda, &n_selected);
73937400
if (rc != SQLITE_OK) goto cleanup;
7394-
k_used = k_original;
7401+
k_used = n_selected;
73957402
k = k_original;
73967403
}
73977404

0 commit comments

Comments
 (0)