Skip to content

Commit 43bd226

Browse files
committed
feat: add mmr_lambda hidden column for MMR reranking in KNN queries
Add Maximal Marginal Relevance (MMR) support to the vec0 virtual table. When mmr_lambda is provided in a KNN query, candidates are over-fetched and then greedily re-selected to balance relevance against diversity.

API: WHERE embedding MATCH ? AND k = 10 AND mmr_lambda = 0.7
- mmr_lambda range [0.0, 1.0]: 1.0 = pure relevance, 0.0 = pure diversity
- Over-fetch factor: 5x (capped at k_max=4096)
- Supports float32, int8, and bit vector types
- All distance metrics (L2, cosine, L1, hamming)
- Zero impact when mmr_lambda is not provided
1 parent 563a3e6 commit 43bd226

1 file changed

Lines changed: 217 additions & 2 deletions

File tree

sqlite-vec.c

Lines changed: 217 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,6 +3376,7 @@ static sqlite3_module vec_npy_eachModule = {
33763376
#define VEC0_COLUMN_USERN_START 1
33773377
#define VEC0_COLUMN_OFFSET_DISTANCE 1
33783378
#define VEC0_COLUMN_OFFSET_K 2
3379+
#define VEC0_COLUMN_OFFSET_MMR_LAMBDA 3
33793380

33803381
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
33813382

@@ -3645,6 +3646,14 @@ int vec0_column_k_idx(vec0_vtab *p) {
36453646
VEC0_COLUMN_OFFSET_K;
36463647
}
36473648

3649+
/**
3650+
* Returns the column index for the hidden "mmr_lambda" column.
3651+
*/
3652+
int vec0_column_mmr_lambda_idx(vec0_vtab *p) {
3653+
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
3654+
VEC0_COLUMN_OFFSET_MMR_LAMBDA;
3655+
}
3656+
36483657
/**
36493658
* Returns 1 if the given column-based index is a valid vector column,
36503659
* 0 otherwise.
@@ -4903,7 +4912,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
49034912
}
49044913

49054914
}
4906-
sqlite3_str_appendall(createStr, " distance hidden, k hidden) ");
4915+
sqlite3_str_appendall(createStr, " distance hidden, k hidden, mmr_lambda hidden) ");
49074916
if (pkColumnName) {
49084917
sqlite3_str_appendall(createStr, "without rowid ");
49094918
}
@@ -5321,6 +5330,7 @@ typedef enum {
53215330

53225331
// ~~~ ??? ~~~ //
53235332
VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&',
5333+
VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#',
53245334
} vec0_idxstr_kind;
53255335

53265336
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5384,6 +5394,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
53845394
int iLimitTerm = -1;
53855395
int iRowidTerm = -1;
53865396
int iKTerm = -1;
5397+
int iMmrLambdaTerm = -1;
53875398
int iRowidInTerm = -1;
53885399
int hasAuxConstraint = 0;
53895400

@@ -5440,6 +5451,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
54405451
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx(p)) {
54415452
iKTerm = i;
54425453
}
5454+
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx(p)) {
5455+
iMmrLambdaTerm = i;
5456+
}
54435457
if(
54445458
(op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET)
54455459
&& vec0_column_idx_is_auxiliary(p, iColumn)) {
@@ -5728,7 +5742,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
57285742
sqlite3_str_appendchar(idxStr, 1, '_');
57295743
}
57305744

5731-
5745+
if (iMmrLambdaTerm >= 0) {
5746+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].argvIndex = argvIndex++;
5747+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].omit = 1;
5748+
sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA);
5749+
sqlite3_str_appendchar(idxStr, 3, '_');
5750+
}
57325751

57335752
pIdxInfo->idxNum = iMatchVectorTerm;
57345753
pIdxInfo->estimatedCost = 30.0;
@@ -6936,6 +6955,159 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
69366955
return rc;
69376956
}
69386957

6958+
/**
6959+
* Compute pairwise distance between two vectors stored in the vec0 table's
6960+
* native format. Handles float32, int8, and bit element types with the
6961+
* appropriate metric (L2, cosine, L1, hamming).
6962+
*/
6963+
static f32 vec0_compute_distance(struct VectorColumnDefinition *vector_column,
6964+
const void *a, const void *b) {
6965+
size_t dims = vector_column->dimensions;
6966+
switch (vector_column->element_type) {
6967+
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
6968+
switch (vector_column->distance_metric) {
6969+
case VEC0_DISTANCE_METRIC_L2:
6970+
return distance_l2_sqr_float(a, b, &dims);
6971+
case VEC0_DISTANCE_METRIC_L1:
6972+
return (f32)distance_l1_f32(a, b, &dims);
6973+
case VEC0_DISTANCE_METRIC_COSINE:
6974+
return distance_cosine_float(a, b, &dims);
6975+
}
6976+
break;
6977+
case SQLITE_VEC_ELEMENT_TYPE_INT8:
6978+
switch (vector_column->distance_metric) {
6979+
case VEC0_DISTANCE_METRIC_L2:
6980+
return distance_l2_sqr_int8(a, b, &dims);
6981+
case VEC0_DISTANCE_METRIC_L1:
6982+
return (f32)distance_l1_int8(a, b, &dims);
6983+
case VEC0_DISTANCE_METRIC_COSINE:
6984+
return distance_cosine_int8(a, b, &dims);
6985+
}
6986+
break;
6987+
case SQLITE_VEC_ELEMENT_TYPE_BIT:
6988+
return distance_hamming(a, b, &dims);
6989+
}
6990+
return 0.0f;
6991+
}
6992+
6993+
/**
6994+
* MMR greedy reranking of KNN results.
6995+
*
6996+
* Loads vectors for top-k candidates, then iteratively selects the
6997+
* candidate with the best MMR score:
6998+
* MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
6999+
*
7000+
* where relevance = 1 - normalized_distance, and max_sim is the maximum
7001+
* cosine similarity between d and any already-selected result.
7002+
*
7003+
* Reorders topk_rowids and topk_distances in place.
7004+
* After return, the first k_target entries are the MMR-selected results.
7005+
*/
7006+
static int vec0_mmr_rerank(
7007+
vec0_vtab *p,
7008+
int vectorColumnIdx,
7009+
struct VectorColumnDefinition *vector_column,
7010+
i64 *topk_rowids,
7011+
f32 *topk_distances,
7012+
i64 k_used,
7013+
i64 k_target,
7014+
f32 mmr_lambda
7015+
) {
7016+
int rc = SQLITE_OK;
7017+
7018+
// 1. Allocate vector storage for all candidates
7019+
void **vectors = sqlite3_malloc64(k_used * sizeof(void *));
7020+
if (!vectors) return SQLITE_NOMEM;
7021+
memset(vectors, 0, k_used * sizeof(void *));
7022+
7023+
f32 *relevance = NULL;
7024+
i64 *out_rowids = NULL;
7025+
f32 *out_distances = NULL;
7026+
void **out_vectors = NULL;
7027+
u8 *selected = NULL;
7028+
7029+
// 2. Load vectors from shadow tables
7030+
for (i64 i = 0; i < k_used; i++) {
7031+
rc = vec0_get_vector_data(p, topk_rowids[i], vectorColumnIdx,
7032+
&vectors[i], NULL);
7033+
if (rc != SQLITE_OK) goto cleanup;
7034+
}
7035+
7036+
// 3. Normalize distances to [0, 1] for relevance scoring
7037+
f32 max_dist = 0.0f;
7038+
for (i64 i = 0; i < k_used; i++) {
7039+
if (topk_distances[i] > max_dist) max_dist = topk_distances[i];
7040+
}
7041+
if (max_dist < 1e-9f) max_dist = 1.0f;
7042+
7043+
relevance = sqlite3_malloc64(k_used * sizeof(f32));
7044+
if (!relevance) { rc = SQLITE_NOMEM; goto cleanup; }
7045+
for (i64 i = 0; i < k_used; i++) {
7046+
relevance[i] = 1.0f - (topk_distances[i] / max_dist);
7047+
}
7048+
7049+
// 4. Greedy MMR selection
7050+
out_rowids = sqlite3_malloc64(k_target * sizeof(i64));
7051+
out_distances = sqlite3_malloc64(k_target * sizeof(f32));
7052+
out_vectors = sqlite3_malloc64(k_target * sizeof(void *));
7053+
selected = sqlite3_malloc64(k_used);
7054+
if (!out_rowids || !out_distances || !out_vectors || !selected) {
7055+
rc = SQLITE_NOMEM; goto cleanup;
7056+
}
7057+
memset(selected, 0, k_used);
7058+
7059+
for (i64 step = 0; step < k_target && step < k_used; step++) {
7060+
f32 best_mmr = -FLT_MAX;
7061+
i64 best_idx = -1;
7062+
7063+
for (i64 i = 0; i < k_used; i++) {
7064+
if (selected[i]) continue;
7065+
7066+
// max similarity to already-selected results
7067+
f32 max_sim = 0.0f;
7068+
for (i64 j = 0; j < step; j++) {
7069+
f32 d = vec0_compute_distance(vector_column,
7070+
vectors[i], out_vectors[j]);
7071+
f32 sim = 1.0f - d;
7072+
if (sim > max_sim) max_sim = sim;
7073+
}
7074+
7075+
f32 mmr_score = mmr_lambda * relevance[i]
7076+
- (1.0f - mmr_lambda) * max_sim;
7077+
if (mmr_score > best_mmr) {
7078+
best_mmr = mmr_score;
7079+
best_idx = i;
7080+
}
7081+
}
7082+
7083+
if (best_idx < 0) break;
7084+
selected[best_idx] = 1;
7085+
out_rowids[step] = topk_rowids[best_idx];
7086+
out_distances[step] = topk_distances[best_idx];
7087+
out_vectors[step] = vectors[best_idx];
7088+
}
7089+
7090+
// 5. Copy results back to input arrays
7091+
for (i64 i = 0; i < k_target; i++) {
7092+
topk_rowids[i] = out_rowids[i];
7093+
topk_distances[i] = out_distances[i];
7094+
}
7095+
7096+
cleanup:
7097+
if (vectors) {
7098+
for (i64 i = 0; i < k_used; i++) {
7099+
sqlite3_free(vectors[i]);
7100+
}
7101+
sqlite3_free(vectors);
7102+
}
7103+
sqlite3_free(relevance);
7104+
sqlite3_free(out_rowids);
7105+
sqlite3_free(out_distances);
7106+
sqlite3_free(out_vectors);
7107+
sqlite3_free(selected);
7108+
return rc;
7109+
}
7110+
69397111
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
69407112
const char *idxStr, int argc, sqlite3_value **argv) {
69417113
assert(argc == (strlen(idxStr)-1) / 4);
@@ -6964,6 +7136,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
69647136
int query_idx =-1;
69657137
int k_idx = -1;
69667138
int rowid_in_idx = -1;
7139+
int mmr_lambda_idx = -1;
69677140
for(int i = 0; i < argc; i++) {
69687141
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
69697142
query_idx = i;
@@ -6974,6 +7147,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
69747147
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
69757148
rowid_in_idx = i;
69767149
}
7150+
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA) {
7151+
mmr_lambda_idx = i;
7152+
}
69777153
}
69787154
assert(query_idx >= 0);
69797155
assert(k_idx >= 0);
@@ -7036,6 +7212,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
70367212
goto cleanup;
70377213
}
70387214

7215+
// MMR: validate lambda and over-fetch candidates
7216+
#define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7217+
f32 mmr_lambda = -1.0f;
7218+
i64 k_original = k;
7219+
if (mmr_lambda_idx >= 0) {
7220+
mmr_lambda = (f32)sqlite3_value_double(argv[mmr_lambda_idx]);
7221+
if (mmr_lambda < 0.0f || mmr_lambda > 1.0f) {
7222+
vtab_set_error(
7223+
&p->base,
7224+
"mmr_lambda value in knn query must be between 0.0 and 1.0, "
7225+
"provided %f",
7226+
(double)mmr_lambda);
7227+
rc = SQLITE_ERROR;
7228+
goto cleanup;
7229+
}
7230+
if (mmr_lambda < 1.0f) {
7231+
i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR;
7232+
if (k_internal > SQLITE_VEC_VEC0_K_MAX) k_internal = SQLITE_VEC_VEC0_K_MAX;
7233+
if (k_internal < k) k_internal = k; // overflow guard
7234+
k = k_internal;
7235+
}
7236+
}
7237+
70397238
// handle when a `rowid in (...)` operation was provided
70407239
// Array of all the rowids that appear in any `rowid in (...)` constraint.
70417240
// NULL if none were provided, which means a "full" scan.
@@ -7186,6 +7385,16 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
71867385
goto cleanup;
71877386
}
71887387

7388+
// MMR reranking: select diverse subset from over-fetched candidates
7389+
if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original) {
7390+
rc = vec0_mmr_rerank(p, vectorColumnIdx, vector_column,
7391+
topk_rowids, topk_distances, k_used, k_original,
7392+
mmr_lambda);
7393+
if (rc != SQLITE_OK) goto cleanup;
7394+
k_used = k_original;
7395+
k = k_original;
7396+
}
7397+
71897398
knn_data->current_idx = 0;
71907399
knn_data->k = k;
71917400
knn_data->rowids = topk_rowids;
@@ -8364,6 +8573,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
83648573
rc = SQLITE_ERROR;
83658574
goto cleanup;
83668575
}
8576+
// Cannot insert a value in the hidden "mmr_lambda" column
8577+
if (sqlite3_value_type(argv[2 + vec0_column_mmr_lambda_idx(p)]) != SQLITE_NULL) {
8578+
vtab_set_error(pVTab, "A value was provided for the hidden \"mmr_lambda\" column.");
8579+
rc = SQLITE_ERROR;
8580+
goto cleanup;
8581+
}
83678582

83688583
// Step #1: Insert/get a rowid for this row, from the _rowids table.
83698584
rc = vec0Update_InsertRowidStep(p, argv[2 + VEC0_COLUMN_ID], &rowid);

0 commit comments

Comments
 (0)