@@ -3376,6 +3376,7 @@ static sqlite3_module vec_npy_eachModule = {
33763376#define VEC0_COLUMN_USERN_START 1
33773377#define VEC0_COLUMN_OFFSET_DISTANCE 1
33783378#define VEC0_COLUMN_OFFSET_K 2
3379+ #define VEC0_COLUMN_OFFSET_MMR_LAMBDA 3
33793380
33803381#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
33813382
@@ -3645,6 +3646,14 @@ int vec0_column_k_idx(vec0_vtab *p) {
36453646 VEC0_COLUMN_OFFSET_K ;
36463647}
36473648
3649+ /**
3650+ * Returns the column index for the hidden "mmr_lambda" column.
3651+ */
3652+ int vec0_column_mmr_lambda_idx (vec0_vtab * p ) {
3653+ return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns (p ) - 1 ) +
3654+ VEC0_COLUMN_OFFSET_MMR_LAMBDA ;
3655+ }
3656+
36483657/**
36493658 * Returns 1 if the given column-based index is a valid vector column,
36503659 * 0 otherwise.
@@ -4903,7 +4912,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
49034912 }
49044913
49054914 }
4906- sqlite3_str_appendall (createStr , " distance hidden, k hidden) " );
4915+ sqlite3_str_appendall (createStr , " distance hidden, k hidden, mmr_lambda hidden ) " );
49074916 if (pkColumnName ) {
49084917 sqlite3_str_appendall (createStr , "without rowid " );
49094918 }
@@ -5321,6 +5330,7 @@ typedef enum {
53215330
53225331 // ~~~ ??? ~~~ //
53235332 VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&' ,
5333+ VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#' ,
53245334} vec0_idxstr_kind ;
53255335
53265336// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5384,6 +5394,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
53845394 int iLimitTerm = -1 ;
53855395 int iRowidTerm = -1 ;
53865396 int iKTerm = -1 ;
5397+ int iMmrLambdaTerm = -1 ;
53875398 int iRowidInTerm = -1 ;
53885399 int hasAuxConstraint = 0 ;
53895400
@@ -5440,6 +5451,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
54405451 if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx (p )) {
54415452 iKTerm = i ;
54425453 }
5454+ if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx (p )) {
5455+ iMmrLambdaTerm = i ;
5456+ }
54435457 if (
54445458 (op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET )
54455459 && vec0_column_idx_is_auxiliary (p , iColumn )) {
@@ -5728,7 +5742,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
57285742 sqlite3_str_appendchar (idxStr , 1 , '_' );
57295743 }
57305744
5731-
5745+ if (iMmrLambdaTerm >= 0 ) {
5746+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].argvIndex = argvIndex ++ ;
5747+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].omit = 1 ;
5748+ sqlite3_str_appendchar (idxStr , 1 , VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA );
5749+ sqlite3_str_appendchar (idxStr , 3 , '_' );
5750+ }
57325751
57335752 pIdxInfo -> idxNum = iMatchVectorTerm ;
57345753 pIdxInfo -> estimatedCost = 30.0 ;
@@ -6936,6 +6955,159 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
69366955 return rc ;
69376956}
69386957
6958+ /**
6959+ * Compute pairwise distance between two vectors stored in the vec0 table's
6960+ * native format. Handles float32, int8, and bit element types with the
6961+ * appropriate metric (L2, cosine, L1, hamming).
6962+ */
6963+ static f32 vec0_compute_distance (struct VectorColumnDefinition * vector_column ,
6964+ const void * a , const void * b ) {
6965+ size_t dims = vector_column -> dimensions ;
6966+ switch (vector_column -> element_type ) {
6967+ case SQLITE_VEC_ELEMENT_TYPE_FLOAT32 :
6968+ switch (vector_column -> distance_metric ) {
6969+ case VEC0_DISTANCE_METRIC_L2 :
6970+ return distance_l2_sqr_float (a , b , & dims );
6971+ case VEC0_DISTANCE_METRIC_L1 :
6972+ return (f32 )distance_l1_f32 (a , b , & dims );
6973+ case VEC0_DISTANCE_METRIC_COSINE :
6974+ return distance_cosine_float (a , b , & dims );
6975+ }
6976+ break ;
6977+ case SQLITE_VEC_ELEMENT_TYPE_INT8 :
6978+ switch (vector_column -> distance_metric ) {
6979+ case VEC0_DISTANCE_METRIC_L2 :
6980+ return distance_l2_sqr_int8 (a , b , & dims );
6981+ case VEC0_DISTANCE_METRIC_L1 :
6982+ return (f32 )distance_l1_int8 (a , b , & dims );
6983+ case VEC0_DISTANCE_METRIC_COSINE :
6984+ return distance_cosine_int8 (a , b , & dims );
6985+ }
6986+ break ;
6987+ case SQLITE_VEC_ELEMENT_TYPE_BIT :
6988+ return distance_hamming (a , b , & dims );
6989+ }
6990+ return 0.0f ;
6991+ }
6992+
6993+ /**
6994+ * MMR greedy reranking of KNN results.
6995+ *
6996+ * Loads vectors for top-k candidates, then iteratively selects the
6997+ * candidate with the best MMR score:
6998+ * MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
6999+ *
7000+ * where relevance = 1 - normalized_distance, and max_sim is the maximum
7001+ * cosine similarity between d and any already-selected result.
7002+ *
7003+ * Reorders topk_rowids and topk_distances in place.
7004+ * After return, the first k_target entries are the MMR-selected results.
7005+ */
7006+ static int vec0_mmr_rerank (
7007+ vec0_vtab * p ,
7008+ int vectorColumnIdx ,
7009+ struct VectorColumnDefinition * vector_column ,
7010+ i64 * topk_rowids ,
7011+ f32 * topk_distances ,
7012+ i64 k_used ,
7013+ i64 k_target ,
7014+ f32 mmr_lambda
7015+ ) {
7016+ int rc = SQLITE_OK ;
7017+
7018+ // 1. Allocate vector storage for all candidates
7019+ void * * vectors = sqlite3_malloc64 (k_used * sizeof (void * ));
7020+ if (!vectors ) return SQLITE_NOMEM ;
7021+ memset (vectors , 0 , k_used * sizeof (void * ));
7022+
7023+ f32 * relevance = NULL ;
7024+ i64 * out_rowids = NULL ;
7025+ f32 * out_distances = NULL ;
7026+ void * * out_vectors = NULL ;
7027+ u8 * selected = NULL ;
7028+
7029+ // 2. Load vectors from shadow tables
7030+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7031+ rc = vec0_get_vector_data (p , topk_rowids [i ], vectorColumnIdx ,
7032+ & vectors [i ], NULL );
7033+ if (rc != SQLITE_OK ) goto cleanup ;
7034+ }
7035+
7036+ // 3. Normalize distances to [0, 1] for relevance scoring
7037+ f32 max_dist = 0.0f ;
7038+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7039+ if (topk_distances [i ] > max_dist ) max_dist = topk_distances [i ];
7040+ }
7041+ if (max_dist < 1e-9f ) max_dist = 1.0f ;
7042+
7043+ relevance = sqlite3_malloc64 (k_used * sizeof (f32 ));
7044+ if (!relevance ) { rc = SQLITE_NOMEM ; goto cleanup ; }
7045+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7046+ relevance [i ] = 1.0f - (topk_distances [i ] / max_dist );
7047+ }
7048+
7049+ // 4. Greedy MMR selection
7050+ out_rowids = sqlite3_malloc64 (k_target * sizeof (i64 ));
7051+ out_distances = sqlite3_malloc64 (k_target * sizeof (f32 ));
7052+ out_vectors = sqlite3_malloc64 (k_target * sizeof (void * ));
7053+ selected = sqlite3_malloc64 (k_used );
7054+ if (!out_rowids || !out_distances || !out_vectors || !selected ) {
7055+ rc = SQLITE_NOMEM ; goto cleanup ;
7056+ }
7057+ memset (selected , 0 , k_used );
7058+
7059+ for (i64 step = 0 ; step < k_target && step < k_used ; step ++ ) {
7060+ f32 best_mmr = - FLT_MAX ;
7061+ i64 best_idx = -1 ;
7062+
7063+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7064+ if (selected [i ]) continue ;
7065+
7066+ // max similarity to already-selected results
7067+ f32 max_sim = 0.0f ;
7068+ for (i64 j = 0 ; j < step ; j ++ ) {
7069+ f32 d = vec0_compute_distance (vector_column ,
7070+ vectors [i ], out_vectors [j ]);
7071+ f32 sim = 1.0f - d ;
7072+ if (sim > max_sim ) max_sim = sim ;
7073+ }
7074+
7075+ f32 mmr_score = mmr_lambda * relevance [i ]
7076+ - (1.0f - mmr_lambda ) * max_sim ;
7077+ if (mmr_score > best_mmr ) {
7078+ best_mmr = mmr_score ;
7079+ best_idx = i ;
7080+ }
7081+ }
7082+
7083+ if (best_idx < 0 ) break ;
7084+ selected [best_idx ] = 1 ;
7085+ out_rowids [step ] = topk_rowids [best_idx ];
7086+ out_distances [step ] = topk_distances [best_idx ];
7087+ out_vectors [step ] = vectors [best_idx ];
7088+ }
7089+
7090+ // 5. Copy results back to input arrays
7091+ for (i64 i = 0 ; i < k_target ; i ++ ) {
7092+ topk_rowids [i ] = out_rowids [i ];
7093+ topk_distances [i ] = out_distances [i ];
7094+ }
7095+
7096+ cleanup :
7097+ if (vectors ) {
7098+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7099+ sqlite3_free (vectors [i ]);
7100+ }
7101+ sqlite3_free (vectors );
7102+ }
7103+ sqlite3_free (relevance );
7104+ sqlite3_free (out_rowids );
7105+ sqlite3_free (out_distances );
7106+ sqlite3_free (out_vectors );
7107+ sqlite3_free (selected );
7108+ return rc ;
7109+ }
7110+
69397111int vec0Filter_knn (vec0_cursor * pCur , vec0_vtab * p , int idxNum ,
69407112 const char * idxStr , int argc , sqlite3_value * * argv ) {
69417113 assert (argc == (strlen (idxStr )- 1 ) / 4 );
@@ -6964,6 +7136,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
69647136 int query_idx = -1 ;
69657137 int k_idx = -1 ;
69667138 int rowid_in_idx = -1 ;
7139+ int mmr_lambda_idx = -1 ;
69677140 for (int i = 0 ; i < argc ; i ++ ) {
69687141 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MATCH ) {
69697142 query_idx = i ;
@@ -6974,6 +7147,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
69747147 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_ROWID_IN ) {
69757148 rowid_in_idx = i ;
69767149 }
7150+ if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA ) {
7151+ mmr_lambda_idx = i ;
7152+ }
69777153 }
69787154 assert (query_idx >= 0 );
69797155 assert (k_idx >= 0 );
@@ -7036,6 +7212,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
70367212 goto cleanup ;
70377213 }
70387214
7215+ // MMR: validate lambda and over-fetch candidates
7216+ #define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7217+ f32 mmr_lambda = -1.0f ;
7218+ i64 k_original = k ;
7219+ if (mmr_lambda_idx >= 0 ) {
7220+ mmr_lambda = (f32 )sqlite3_value_double (argv [mmr_lambda_idx ]);
7221+ if (mmr_lambda < 0.0f || mmr_lambda > 1.0f ) {
7222+ vtab_set_error (
7223+ & p -> base ,
7224+ "mmr_lambda value in knn query must be between 0.0 and 1.0, "
7225+ "provided %f" ,
7226+ (double )mmr_lambda );
7227+ rc = SQLITE_ERROR ;
7228+ goto cleanup ;
7229+ }
7230+ if (mmr_lambda < 1.0f ) {
7231+ i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR ;
7232+ if (k_internal > SQLITE_VEC_VEC0_K_MAX ) k_internal = SQLITE_VEC_VEC0_K_MAX ;
7233+ if (k_internal < k ) k_internal = k ; // overflow guard
7234+ k = k_internal ;
7235+ }
7236+ }
7237+
70397238// handle when a `rowid in (...)` operation was provided
70407239// Array of all the rowids that appear in any `rowid in (...)` constraint.
70417240// NULL if none were provided, which means a "full" scan.
@@ -7186,6 +7385,16 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
71867385 goto cleanup ;
71877386 }
71887387
7388+ // MMR reranking: select diverse subset from over-fetched candidates
7389+ if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original ) {
7390+ rc = vec0_mmr_rerank (p , vectorColumnIdx , vector_column ,
7391+ topk_rowids , topk_distances , k_used , k_original ,
7392+ mmr_lambda );
7393+ if (rc != SQLITE_OK ) goto cleanup ;
7394+ k_used = k_original ;
7395+ k = k_original ;
7396+ }
7397+
71897398 knn_data -> current_idx = 0 ;
71907399 knn_data -> k = k ;
71917400 knn_data -> rowids = topk_rowids ;
@@ -8364,6 +8573,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
83648573 rc = SQLITE_ERROR ;
83658574 goto cleanup ;
83668575 }
8576+ // Cannot insert a value in the hidden "mmr_lambda" column
8577+ if (sqlite3_value_type (argv [2 + vec0_column_mmr_lambda_idx (p )]) != SQLITE_NULL ) {
8578+ vtab_set_error (pVTab , "A value was provided for the hidden \"mmr_lambda\" column." );
8579+ rc = SQLITE_ERROR ;
8580+ goto cleanup ;
8581+ }
83678582
83688583 // Step #1: Insert/get a rowid for this row, from the _rowids table.
83698584 rc = vec0Update_InsertRowidStep (p , argv [2 + VEC0_COLUMN_ID ], & rowid );
0 commit comments