@@ -3520,6 +3520,7 @@ static sqlite3_module vec_npy_eachModule = {
35203520#define VEC0_COLUMN_OFFSET_DISTANCE 1
35213521#define VEC0_COLUMN_OFFSET_K 2
35223522#define VEC0_COLUMN_OFFSET_TABLE_NAME 3
3523+ #define VEC0_COLUMN_OFFSET_MMR_LAMBDA 4
35233524
35243525#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
35253526
@@ -3817,6 +3818,14 @@ int vec0_column_table_name_idx(vec0_vtab *p) {
38173818 VEC0_COLUMN_OFFSET_TABLE_NAME ;
38183819}
38193820
3821+ /**
3822+ * Returns the column index for the hidden "mmr_lambda" column.
3823+ */
3824+ int vec0_column_mmr_lambda_idx (vec0_vtab * p ) {
3825+ return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns (p ) - 1 ) +
3826+ VEC0_COLUMN_OFFSET_MMR_LAMBDA ;
3827+ }
3828+
38203829/**
38213830 * Returns 1 if the given column-based index is a valid vector column,
38223831 * 0 otherwise.
@@ -5083,7 +5092,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
50835092
50845093 }
50855094 sqlite3_str_appendall (createStr , " distance hidden, k hidden, " );
5086- sqlite3_str_appendf (createStr , "%s hidden) " , tableName );
5095+ sqlite3_str_appendf (createStr , "%s hidden, mmr_lambda hidden ) " , tableName );
50875096 if (pkColumnName ) {
50885097 sqlite3_str_appendall (createStr , "without rowid " );
50895098 }
@@ -5495,6 +5504,7 @@ typedef enum {
54955504
54965505 // ~~~ ??? ~~~ //
54975506 VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&' ,
5507+ VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#' ,
54985508} vec0_idxstr_kind ;
54995509
55005510// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5564,6 +5574,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55645574 int iLimitTerm = -1 ;
55655575 int iRowidTerm = -1 ;
55665576 int iKTerm = -1 ;
5577+ int iMmrLambdaTerm = -1 ;
55675578 int iRowidInTerm = -1 ;
55685579 int hasAuxConstraint = 0 ;
55695580
@@ -5620,6 +5631,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
56205631 if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx (p )) {
56215632 iKTerm = i ;
56225633 }
5634+ if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx (p )) {
5635+ iMmrLambdaTerm = i ;
5636+ }
56235637 if (
56245638 (op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET )
56255639 && vec0_column_idx_is_auxiliary (p , iColumn )) {
@@ -5950,7 +5964,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
59505964 sqlite3_str_appendchar (idxStr , 1 , '_' );
59515965 }
59525966
5953-
5967+ if (iMmrLambdaTerm >= 0 ) {
5968+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].argvIndex = argvIndex ++ ;
5969+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].omit = 1 ;
5970+ sqlite3_str_appendchar (idxStr , 1 , VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA );
5971+ sqlite3_str_appendchar (idxStr , 3 , '_' );
5972+ }
59545973
59555974 pIdxInfo -> idxNum = iMatchVectorTerm ;
59565975 pIdxInfo -> estimatedCost = 30.0 ;
@@ -7490,6 +7509,163 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
74907509 return rc ;
74917510}
74927511
7512+ /**
7513+ * Compute pairwise distance between two vectors stored in the vec0 table's
7514+ * native format. Handles float32, int8, and bit element types with the
7515+ * appropriate metric (L2, cosine, L1, hamming).
7516+ */
7517+ static f32 vec0_compute_distance (struct VectorColumnDefinition * vector_column ,
7518+ const void * a , const void * b ) {
7519+ size_t dims = vector_column -> dimensions ;
7520+ switch (vector_column -> element_type ) {
7521+ case SQLITE_VEC_ELEMENT_TYPE_FLOAT32 :
7522+ switch (vector_column -> distance_metric ) {
7523+ case VEC0_DISTANCE_METRIC_L2 :
7524+ return distance_l2_sqr_float (a , b , & dims );
7525+ case VEC0_DISTANCE_METRIC_L1 :
7526+ return (f32 )distance_l1_f32 (a , b , & dims );
7527+ case VEC0_DISTANCE_METRIC_COSINE :
7528+ return distance_cosine_float (a , b , & dims );
7529+ }
7530+ break ;
7531+ case SQLITE_VEC_ELEMENT_TYPE_INT8 :
7532+ switch (vector_column -> distance_metric ) {
7533+ case VEC0_DISTANCE_METRIC_L2 :
7534+ return distance_l2_sqr_int8 (a , b , & dims );
7535+ case VEC0_DISTANCE_METRIC_L1 :
7536+ return (f32 )distance_l1_int8 (a , b , & dims );
7537+ case VEC0_DISTANCE_METRIC_COSINE :
7538+ return distance_cosine_int8 (a , b , & dims );
7539+ }
7540+ break ;
7541+ case SQLITE_VEC_ELEMENT_TYPE_BIT :
7542+ return distance_hamming (a , b , & dims );
7543+ }
7544+ return 0.0f ;
7545+ }
7546+
7547+ /**
7548+ * MMR greedy reranking of KNN results.
7549+ *
7550+ * Loads vectors for top-k candidates, then iteratively selects the
7551+ * candidate with the best MMR score:
7552+ * MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
7553+ *
7554+ * where relevance = 1 - normalized_distance, and max_sim is the maximum
7555+ * cosine similarity between d and any already-selected result.
7556+ *
7557+ * Reorders topk_rowids and topk_distances in place.
7558+ * After return, the first *out_n_selected entries are the MMR-selected results.
7559+ */
7560+ static int vec0_mmr_rerank (
7561+ vec0_vtab * p ,
7562+ int vectorColumnIdx ,
7563+ struct VectorColumnDefinition * vector_column ,
7564+ i64 * topk_rowids ,
7565+ f32 * topk_distances ,
7566+ i64 k_used ,
7567+ i64 k_target ,
7568+ f32 mmr_lambda ,
7569+ i64 * out_n_selected
7570+ ) {
7571+ int rc = SQLITE_OK ;
7572+
7573+ // 1. Allocate vector storage for all candidates
7574+ void * * vectors = sqlite3_malloc64 (k_used * sizeof (void * ));
7575+ if (!vectors ) return SQLITE_NOMEM ;
7576+ memset (vectors , 0 , k_used * sizeof (void * ));
7577+
7578+ f32 * relevance = NULL ;
7579+ i64 * out_rowids = NULL ;
7580+ f32 * out_distances = NULL ;
7581+ void * * out_vectors = NULL ;
7582+ u8 * selected = NULL ;
7583+
7584+ // 2. Load vectors from shadow tables
7585+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7586+ rc = vec0_get_vector_data (p , topk_rowids [i ], vectorColumnIdx ,
7587+ & vectors [i ], NULL );
7588+ if (rc != SQLITE_OK ) goto cleanup ;
7589+ }
7590+
7591+ // 3. Normalize distances to [0, 1] for relevance scoring
7592+ f32 max_dist = 0.0f ;
7593+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7594+ if (topk_distances [i ] > max_dist ) max_dist = topk_distances [i ];
7595+ }
7596+ if (max_dist < 1e-9f ) max_dist = 1.0f ;
7597+
7598+ relevance = sqlite3_malloc64 (k_used * sizeof (f32 ));
7599+ if (!relevance ) { rc = SQLITE_NOMEM ; goto cleanup ; }
7600+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7601+ relevance [i ] = 1.0f - (topk_distances [i ] / max_dist );
7602+ }
7603+
7604+ // 4. Greedy MMR selection
7605+ out_rowids = sqlite3_malloc64 (k_target * sizeof (i64 ));
7606+ out_distances = sqlite3_malloc64 (k_target * sizeof (f32 ));
7607+ out_vectors = sqlite3_malloc64 (k_target * sizeof (void * ));
7608+ selected = sqlite3_malloc64 (k_used );
7609+ if (!out_rowids || !out_distances || !out_vectors || !selected ) {
7610+ rc = SQLITE_NOMEM ; goto cleanup ;
7611+ }
7612+ memset (selected , 0 , k_used );
7613+
7614+ i64 n_selected = 0 ;
7615+ for (i64 step = 0 ; step < k_target && step < k_used ; step ++ ) {
7616+ f32 best_mmr = - FLT_MAX ;
7617+ i64 best_idx = -1 ;
7618+
7619+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7620+ if (selected [i ]) continue ;
7621+
7622+ // max similarity to already-selected results
7623+ f32 max_sim = 0.0f ;
7624+ for (i64 j = 0 ; j < step ; j ++ ) {
7625+ f32 d = vec0_compute_distance (vector_column ,
7626+ vectors [i ], out_vectors [j ]);
7627+ f32 sim = 1.0f - d ;
7628+ if (sim > max_sim ) max_sim = sim ;
7629+ }
7630+
7631+ f32 mmr_score = mmr_lambda * relevance [i ]
7632+ - (1.0f - mmr_lambda ) * max_sim ;
7633+ if (mmr_score > best_mmr ) {
7634+ best_mmr = mmr_score ;
7635+ best_idx = i ;
7636+ }
7637+ }
7638+
7639+ if (best_idx < 0 ) break ;
7640+ selected [best_idx ] = 1 ;
7641+ out_rowids [step ] = topk_rowids [best_idx ];
7642+ out_distances [step ] = topk_distances [best_idx ];
7643+ out_vectors [step ] = vectors [best_idx ];
7644+ n_selected ++ ;
7645+ }
7646+
7647+ // 5. Copy results back to input arrays
7648+ for (i64 i = 0 ; i < n_selected ; i ++ ) {
7649+ topk_rowids [i ] = out_rowids [i ];
7650+ topk_distances [i ] = out_distances [i ];
7651+ }
7652+ * out_n_selected = n_selected ;
7653+
7654+ cleanup :
7655+ if (vectors ) {
7656+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7657+ sqlite3_free (vectors [i ]);
7658+ }
7659+ sqlite3_free (vectors );
7660+ }
7661+ sqlite3_free (relevance );
7662+ sqlite3_free (out_rowids );
7663+ sqlite3_free (out_distances );
7664+ sqlite3_free (out_vectors );
7665+ sqlite3_free (selected );
7666+ return rc ;
7667+ }
7668+
74937669int vec0Filter_knn (vec0_cursor * pCur , vec0_vtab * p , int idxNum ,
74947670 const char * idxStr , int argc , sqlite3_value * * argv ) {
74957671 assert (argc == (int )((strlen (idxStr )- 1 ) / 4 ));
@@ -7518,6 +7694,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75187694 int query_idx = -1 ;
75197695 int k_idx = -1 ;
75207696 int rowid_in_idx = -1 ;
7697+ int mmr_lambda_idx = -1 ;
75217698 for (int i = 0 ; i < argc ; i ++ ) {
75227699 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MATCH ) {
75237700 query_idx = i ;
@@ -7528,6 +7705,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75287705 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_ROWID_IN ) {
75297706 rowid_in_idx = i ;
75307707 }
7708+ if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA ) {
7709+ mmr_lambda_idx = i ;
7710+ }
75317711 }
75327712 assert (query_idx >= 0 );
75337713 assert (k_idx >= 0 );
@@ -7590,6 +7770,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75907770 goto cleanup ;
75917771 }
75927772
7773+ // MMR: validate lambda and over-fetch candidates
7774+ #define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7775+ f32 mmr_lambda = -1.0f ;
7776+ i64 k_original = k ;
7777+ if (mmr_lambda_idx >= 0 ) {
7778+ mmr_lambda = (f32 )sqlite3_value_double (argv [mmr_lambda_idx ]);
7779+ if (mmr_lambda < 0.0f || mmr_lambda > 1.0f ) {
7780+ vtab_set_error (
7781+ & p -> base ,
7782+ "mmr_lambda value in knn query must be between 0.0 and 1.0, "
7783+ "provided %f" ,
7784+ (double )mmr_lambda );
7785+ rc = SQLITE_ERROR ;
7786+ goto cleanup ;
7787+ }
7788+ if (mmr_lambda < 1.0f ) {
7789+ i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR ;
7790+ if (k_internal > SQLITE_VEC_VEC0_K_MAX ) k_internal = SQLITE_VEC_VEC0_K_MAX ;
7791+ if (k_internal < k ) k_internal = k ; // overflow guard
7792+ k = k_internal ;
7793+ }
7794+ }
7795+
75937796// handle when a `rowid in (...)` operation was provided
75947797// Array of all the rowids that appear in any `rowid in (...)` constraint.
75957798// NULL if none were provided, which means a "full" scan.
@@ -7757,6 +7960,17 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
77577960 goto cleanup ;
77587961 }
77597962
7963+ // MMR reranking: select diverse subset from over-fetched candidates
7964+ if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original ) {
7965+ i64 n_selected = 0 ;
7966+ rc = vec0_mmr_rerank (p , vectorColumnIdx , vector_column ,
7967+ topk_rowids , topk_distances , k_used , k_original ,
7968+ mmr_lambda , & n_selected );
7969+ if (rc != SQLITE_OK ) goto cleanup ;
7970+ k_used = n_selected ;
7971+ k = k_original ;
7972+ }
7973+
77607974 knn_data -> current_idx = 0 ;
77617975 knn_data -> k = k ;
77627976 knn_data -> rowids = topk_rowids ;
@@ -8959,6 +9173,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
89599173 rc = SQLITE_ERROR ;
89609174 goto cleanup ;
89619175 }
9176+ // Cannot insert a value in the hidden "mmr_lambda" column
9177+ if (sqlite3_value_type (argv [2 + vec0_column_mmr_lambda_idx (p )]) != SQLITE_NULL ) {
9178+ vtab_set_error (pVTab , "A value was provided for the hidden \"mmr_lambda\" column." );
9179+ rc = SQLITE_ERROR ;
9180+ goto cleanup ;
9181+ }
89629182
89639183 // Cannot insert a value in the hidden "table_name" column
89649184 if (sqlite3_value_type (argv [2 + vec0_column_table_name_idx (p )]) != SQLITE_NULL ) {
0 commit comments