Skip to content

Commit bc60431

Browse files
committed
feat: add mmr_lambda hidden column and MMR reranking functionality for KNN queries -- see vlasky#6
1 parent c9cd23d commit bc60431

3 files changed

Lines changed: 714 additions & 2 deletions

File tree

sqlite-vec.c

Lines changed: 222 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3520,6 +3520,7 @@ static sqlite3_module vec_npy_eachModule = {
35203520
#define VEC0_COLUMN_OFFSET_DISTANCE 1
35213521
#define VEC0_COLUMN_OFFSET_K 2
35223522
#define VEC0_COLUMN_OFFSET_TABLE_NAME 3
3523+
#define VEC0_COLUMN_OFFSET_MMR_LAMBDA 4
35233524

35243525
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
35253526

@@ -3817,6 +3818,14 @@ int vec0_column_table_name_idx(vec0_vtab *p) {
38173818
VEC0_COLUMN_OFFSET_TABLE_NAME;
38183819
}
38193820

3821+
/**
3822+
* Returns the column index for the hidden "mmr_lambda" column.
3823+
*/
3824+
int vec0_column_mmr_lambda_idx(vec0_vtab *p) {
3825+
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
3826+
VEC0_COLUMN_OFFSET_MMR_LAMBDA;
3827+
}
3828+
38203829
/**
38213830
* Returns 1 if the given column-based index is a valid vector column,
38223831
* 0 otherwise.
@@ -5083,7 +5092,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
50835092

50845093
}
50855094
sqlite3_str_appendall(createStr, " distance hidden, k hidden, ");
5086-
sqlite3_str_appendf(createStr, "%s hidden) ", tableName);
5095+
sqlite3_str_appendf(createStr, "%s hidden, mmr_lambda hidden) ", tableName);
50875096
if (pkColumnName) {
50885097
sqlite3_str_appendall(createStr, "without rowid ");
50895098
}
@@ -5495,6 +5504,7 @@ typedef enum {
54955504

54965505
// ~~~ ??? ~~~ //
54975506
VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&',
5507+
VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#',
54985508
} vec0_idxstr_kind;
54995509

55005510
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5564,6 +5574,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55645574
int iLimitTerm = -1;
55655575
int iRowidTerm = -1;
55665576
int iKTerm = -1;
5577+
int iMmrLambdaTerm = -1;
55675578
int iRowidInTerm = -1;
55685579
int hasAuxConstraint = 0;
55695580

@@ -5620,6 +5631,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
56205631
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx(p)) {
56215632
iKTerm = i;
56225633
}
5634+
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx(p)) {
5635+
iMmrLambdaTerm = i;
5636+
}
56235637
if(
56245638
(op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET)
56255639
&& vec0_column_idx_is_auxiliary(p, iColumn)) {
@@ -5950,7 +5964,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
59505964
sqlite3_str_appendchar(idxStr, 1, '_');
59515965
}
59525966

5953-
5967+
if (iMmrLambdaTerm >= 0) {
5968+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].argvIndex = argvIndex++;
5969+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].omit = 1;
5970+
sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA);
5971+
sqlite3_str_appendchar(idxStr, 3, '_');
5972+
}
59545973

59555974
pIdxInfo->idxNum = iMatchVectorTerm;
59565975
pIdxInfo->estimatedCost = 30.0;
@@ -7490,6 +7509,163 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
74907509
return rc;
74917510
}
74927511

7512+
/**
7513+
* Compute pairwise distance between two vectors stored in the vec0 table's
7514+
* native format. Handles float32, int8, and bit element types with the
7515+
* appropriate metric (L2, cosine, L1, hamming).
7516+
*/
7517+
static f32 vec0_compute_distance(struct VectorColumnDefinition *vector_column,
7518+
const void *a, const void *b) {
7519+
size_t dims = vector_column->dimensions;
7520+
switch (vector_column->element_type) {
7521+
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
7522+
switch (vector_column->distance_metric) {
7523+
case VEC0_DISTANCE_METRIC_L2:
7524+
return distance_l2_sqr_float(a, b, &dims);
7525+
case VEC0_DISTANCE_METRIC_L1:
7526+
return (f32)distance_l1_f32(a, b, &dims);
7527+
case VEC0_DISTANCE_METRIC_COSINE:
7528+
return distance_cosine_float(a, b, &dims);
7529+
}
7530+
break;
7531+
case SQLITE_VEC_ELEMENT_TYPE_INT8:
7532+
switch (vector_column->distance_metric) {
7533+
case VEC0_DISTANCE_METRIC_L2:
7534+
return distance_l2_sqr_int8(a, b, &dims);
7535+
case VEC0_DISTANCE_METRIC_L1:
7536+
return (f32)distance_l1_int8(a, b, &dims);
7537+
case VEC0_DISTANCE_METRIC_COSINE:
7538+
return distance_cosine_int8(a, b, &dims);
7539+
}
7540+
break;
7541+
case SQLITE_VEC_ELEMENT_TYPE_BIT:
7542+
return distance_hamming(a, b, &dims);
7543+
}
7544+
return 0.0f;
7545+
}
7546+
7547+
/**
7548+
* MMR greedy reranking of KNN results.
7549+
*
7550+
* Loads vectors for top-k candidates, then iteratively selects the
7551+
* candidate with the best MMR score:
7552+
* MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
7553+
*
7554+
* where relevance = 1 - normalized_distance, and max_sim is the maximum
7555+
* cosine similarity between d and any already-selected result.
7556+
*
7557+
* Reorders topk_rowids and topk_distances in place.
7558+
* After return, the first *out_n_selected entries are the MMR-selected results.
7559+
*/
7560+
static int vec0_mmr_rerank(
7561+
vec0_vtab *p,
7562+
int vectorColumnIdx,
7563+
struct VectorColumnDefinition *vector_column,
7564+
i64 *topk_rowids,
7565+
f32 *topk_distances,
7566+
i64 k_used,
7567+
i64 k_target,
7568+
f32 mmr_lambda,
7569+
i64 *out_n_selected
7570+
) {
7571+
int rc = SQLITE_OK;
7572+
7573+
// 1. Allocate vector storage for all candidates
7574+
void **vectors = sqlite3_malloc64(k_used * sizeof(void *));
7575+
if (!vectors) return SQLITE_NOMEM;
7576+
memset(vectors, 0, k_used * sizeof(void *));
7577+
7578+
f32 *relevance = NULL;
7579+
i64 *out_rowids = NULL;
7580+
f32 *out_distances = NULL;
7581+
void **out_vectors = NULL;
7582+
u8 *selected = NULL;
7583+
7584+
// 2. Load vectors from shadow tables
7585+
for (i64 i = 0; i < k_used; i++) {
7586+
rc = vec0_get_vector_data(p, topk_rowids[i], vectorColumnIdx,
7587+
&vectors[i], NULL);
7588+
if (rc != SQLITE_OK) goto cleanup;
7589+
}
7590+
7591+
// 3. Normalize distances to [0, 1] for relevance scoring
7592+
f32 max_dist = 0.0f;
7593+
for (i64 i = 0; i < k_used; i++) {
7594+
if (topk_distances[i] > max_dist) max_dist = topk_distances[i];
7595+
}
7596+
if (max_dist < 1e-9f) max_dist = 1.0f;
7597+
7598+
relevance = sqlite3_malloc64(k_used * sizeof(f32));
7599+
if (!relevance) { rc = SQLITE_NOMEM; goto cleanup; }
7600+
for (i64 i = 0; i < k_used; i++) {
7601+
relevance[i] = 1.0f - (topk_distances[i] / max_dist);
7602+
}
7603+
7604+
// 4. Greedy MMR selection
7605+
out_rowids = sqlite3_malloc64(k_target * sizeof(i64));
7606+
out_distances = sqlite3_malloc64(k_target * sizeof(f32));
7607+
out_vectors = sqlite3_malloc64(k_target * sizeof(void *));
7608+
selected = sqlite3_malloc64(k_used);
7609+
if (!out_rowids || !out_distances || !out_vectors || !selected) {
7610+
rc = SQLITE_NOMEM; goto cleanup;
7611+
}
7612+
memset(selected, 0, k_used);
7613+
7614+
i64 n_selected = 0;
7615+
for (i64 step = 0; step < k_target && step < k_used; step++) {
7616+
f32 best_mmr = -FLT_MAX;
7617+
i64 best_idx = -1;
7618+
7619+
for (i64 i = 0; i < k_used; i++) {
7620+
if (selected[i]) continue;
7621+
7622+
// max similarity to already-selected results
7623+
f32 max_sim = 0.0f;
7624+
for (i64 j = 0; j < step; j++) {
7625+
f32 d = vec0_compute_distance(vector_column,
7626+
vectors[i], out_vectors[j]);
7627+
f32 sim = 1.0f - d;
7628+
if (sim > max_sim) max_sim = sim;
7629+
}
7630+
7631+
f32 mmr_score = mmr_lambda * relevance[i]
7632+
- (1.0f - mmr_lambda) * max_sim;
7633+
if (mmr_score > best_mmr) {
7634+
best_mmr = mmr_score;
7635+
best_idx = i;
7636+
}
7637+
}
7638+
7639+
if (best_idx < 0) break;
7640+
selected[best_idx] = 1;
7641+
out_rowids[step] = topk_rowids[best_idx];
7642+
out_distances[step] = topk_distances[best_idx];
7643+
out_vectors[step] = vectors[best_idx];
7644+
n_selected++;
7645+
}
7646+
7647+
// 5. Copy results back to input arrays
7648+
for (i64 i = 0; i < n_selected; i++) {
7649+
topk_rowids[i] = out_rowids[i];
7650+
topk_distances[i] = out_distances[i];
7651+
}
7652+
*out_n_selected = n_selected;
7653+
7654+
cleanup:
7655+
if (vectors) {
7656+
for (i64 i = 0; i < k_used; i++) {
7657+
sqlite3_free(vectors[i]);
7658+
}
7659+
sqlite3_free(vectors);
7660+
}
7661+
sqlite3_free(relevance);
7662+
sqlite3_free(out_rowids);
7663+
sqlite3_free(out_distances);
7664+
sqlite3_free(out_vectors);
7665+
sqlite3_free(selected);
7666+
return rc;
7667+
}
7668+
74937669
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74947670
const char *idxStr, int argc, sqlite3_value **argv) {
74957671
assert(argc == (int)((strlen(idxStr)-1) / 4));
@@ -7518,6 +7694,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75187694
int query_idx =-1;
75197695
int k_idx = -1;
75207696
int rowid_in_idx = -1;
7697+
int mmr_lambda_idx = -1;
75217698
for(int i = 0; i < argc; i++) {
75227699
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
75237700
query_idx = i;
@@ -7528,6 +7705,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75287705
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
75297706
rowid_in_idx = i;
75307707
}
7708+
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA) {
7709+
mmr_lambda_idx = i;
7710+
}
75317711
}
75327712
assert(query_idx >= 0);
75337713
assert(k_idx >= 0);
@@ -7590,6 +7770,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75907770
goto cleanup;
75917771
}
75927772

7773+
// MMR: validate lambda and over-fetch candidates
7774+
#define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7775+
f32 mmr_lambda = -1.0f;
7776+
i64 k_original = k;
7777+
if (mmr_lambda_idx >= 0) {
7778+
mmr_lambda = (f32)sqlite3_value_double(argv[mmr_lambda_idx]);
7779+
if (mmr_lambda < 0.0f || mmr_lambda > 1.0f) {
7780+
vtab_set_error(
7781+
&p->base,
7782+
"mmr_lambda value in knn query must be between 0.0 and 1.0, "
7783+
"provided %f",
7784+
(double)mmr_lambda);
7785+
rc = SQLITE_ERROR;
7786+
goto cleanup;
7787+
}
7788+
if (mmr_lambda < 1.0f) {
7789+
i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR;
7790+
if (k_internal > SQLITE_VEC_VEC0_K_MAX) k_internal = SQLITE_VEC_VEC0_K_MAX;
7791+
if (k_internal < k) k_internal = k; // overflow guard
7792+
k = k_internal;
7793+
}
7794+
}
7795+
75937796
// handle when a `rowid in (...)` operation was provided
75947797
// Array of all the rowids that appear in any `rowid in (...)` constraint.
75957798
// NULL if none were provided, which means a "full" scan.
@@ -7757,6 +7960,17 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
77577960
goto cleanup;
77587961
}
77597962

7963+
// MMR reranking: select diverse subset from over-fetched candidates
7964+
if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original) {
7965+
i64 n_selected = 0;
7966+
rc = vec0_mmr_rerank(p, vectorColumnIdx, vector_column,
7967+
topk_rowids, topk_distances, k_used, k_original,
7968+
mmr_lambda, &n_selected);
7969+
if (rc != SQLITE_OK) goto cleanup;
7970+
k_used = n_selected;
7971+
k = k_original;
7972+
}
7973+
77607974
knn_data->current_idx = 0;
77617975
knn_data->k = k;
77627976
knn_data->rowids = topk_rowids;
@@ -8959,6 +9173,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
89599173
rc = SQLITE_ERROR;
89609174
goto cleanup;
89619175
}
9176+
// Cannot insert a value in the hidden "mmr_lambda" column
9177+
if (sqlite3_value_type(argv[2 + vec0_column_mmr_lambda_idx(p)]) != SQLITE_NULL) {
9178+
vtab_set_error(pVTab, "A value was provided for the hidden \"mmr_lambda\" column.");
9179+
rc = SQLITE_ERROR;
9180+
goto cleanup;
9181+
}
89629182

89639183
// Cannot insert a value in the hidden "table_name" column
89649184
if (sqlite3_value_type(argv[2 + vec0_column_table_name_idx(p)]) != SQLITE_NULL) {

0 commit comments

Comments
 (0)