@@ -3497,6 +3497,7 @@ static sqlite3_module vec_npy_eachModule = {
34973497#define VEC0_COLUMN_OFFSET_DISTANCE 1
34983498#define VEC0_COLUMN_OFFSET_K 2
34993499#define VEC0_COLUMN_OFFSET_TABLE_NAME 3
3500+ #define VEC0_COLUMN_OFFSET_MMR_LAMBDA 4
35003501
35013502#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
35023503
@@ -3777,6 +3778,14 @@ int vec0_column_table_name_idx(vec0_vtab *p) {
37773778 VEC0_COLUMN_OFFSET_TABLE_NAME ;
37783779}
37793780
3781+ /**
3782+ * Returns the column index for the hidden "mmr_lambda" column.
3783+ */
3784+ int vec0_column_mmr_lambda_idx (vec0_vtab * p ) {
3785+ return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns (p ) - 1 ) +
3786+ VEC0_COLUMN_OFFSET_MMR_LAMBDA ;
3787+ }
3788+
37803789/**
37813790 * Returns 1 if the given column-based index is a valid vector column,
37823791 * 0 otherwise.
@@ -5039,7 +5048,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
50395048
50405049 }
50415050 sqlite3_str_appendall (createStr , " distance hidden, k hidden, " );
5042- sqlite3_str_appendf (createStr , "%s hidden) " , tableName );
5051+ sqlite3_str_appendf (createStr , "%s hidden, mmr_lambda hidden ) " , tableName );
50435052 if (pkColumnName ) {
50445053 sqlite3_str_appendall (createStr , "without rowid " );
50455054 }
@@ -5454,6 +5463,7 @@ typedef enum {
54545463
54555464 // ~~~ ??? ~~~ //
54565465 VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&' ,
5466+ VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#' ,
54575467} vec0_idxstr_kind ;
54585468
54595469// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5523,6 +5533,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55235533 int iLimitTerm = -1 ;
55245534 int iRowidTerm = -1 ;
55255535 int iKTerm = -1 ;
5536+ int iMmrLambdaTerm = -1 ;
55265537 int iRowidInTerm = -1 ;
55275538 int hasAuxConstraint = 0 ;
55285539
@@ -5579,6 +5590,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55795590 if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx (p )) {
55805591 iKTerm = i ;
55815592 }
5593+ if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx (p )) {
5594+ iMmrLambdaTerm = i ;
5595+ }
55825596 if (
55835597 (op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET )
55845598 && vec0_column_idx_is_auxiliary (p , iColumn )) {
@@ -5909,7 +5923,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
59095923 sqlite3_str_appendchar (idxStr , 1 , '_' );
59105924 }
59115925
5912-
5926+ if (iMmrLambdaTerm >= 0 ) {
5927+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].argvIndex = argvIndex ++ ;
5928+ pIdxInfo -> aConstraintUsage [iMmrLambdaTerm ].omit = 1 ;
5929+ sqlite3_str_appendchar (idxStr , 1 , VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA );
5930+ sqlite3_str_appendchar (idxStr , 3 , '_' );
5931+ }
59135932
59145933 pIdxInfo -> idxNum = iMatchVectorTerm ;
59155934 pIdxInfo -> estimatedCost = 30.0 ;
@@ -7440,6 +7459,165 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
74407459 return rc ;
74417460}
74427461
7462+ /**
7463+ * Compute pairwise distance between two vectors stored in the vec0 table's
7464+ * native format. Handles float32, int8, and bit element types with the
7465+ * appropriate metric (L2, cosine, L1, hamming).
7466+ */
7467+ static f32 vec0_compute_distance (struct VectorColumnDefinition * vector_column ,
7468+ const void * a , const void * b ) {
7469+ size_t dims = vector_column -> dimensions ;
7470+ switch (vector_column -> element_type ) {
7471+ case SQLITE_VEC_ELEMENT_TYPE_FLOAT32 :
7472+ switch (vector_column -> distance_metric ) {
7473+ case VEC0_DISTANCE_METRIC_L2 :
7474+ return distance_l2_sqr_float (a , b , & dims );
7475+ case VEC0_DISTANCE_METRIC_L1 :
7476+ return (f32 )distance_l1_f32 (a , b , & dims );
7477+ case VEC0_DISTANCE_METRIC_COSINE :
7478+ return distance_cosine_float (a , b , & dims );
7479+ }
7480+ break ;
7481+ case SQLITE_VEC_ELEMENT_TYPE_INT8 :
7482+ switch (vector_column -> distance_metric ) {
7483+ case VEC0_DISTANCE_METRIC_L2 :
7484+ return distance_l2_sqr_int8 (a , b , & dims );
7485+ case VEC0_DISTANCE_METRIC_L1 :
7486+ return (f32 )distance_l1_int8 (a , b , & dims );
7487+ case VEC0_DISTANCE_METRIC_COSINE :
7488+ return distance_cosine_int8 (a , b , & dims );
7489+ }
7490+ break ;
7491+ case SQLITE_VEC_ELEMENT_TYPE_BIT :
7492+ return distance_hamming (a , b , & dims );
7493+ }
7494+ return 0.0f ;
7495+ }
7496+
7497+ /**
7498+ * MMR greedy reranking of KNN results.
7499+ *
7500+ * Loads vectors for top-k candidates, then iteratively selects the
7501+ * candidate with the best MMR score:
7502+ * MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
7503+ *
7504+ * where relevance = 1 - normalized_distance, and max_sim is the maximum
7505+ * cosine similarity between d and any already-selected result.
7506+ *
7507+ * Reorders topk_rowids and topk_distances in place.
7508+ * After return, the first *out_n_selected entries are the MMR-selected results.
7509+ * out_n_selected may be less than k_target if the greedy selection exhausts
7510+ * candidates early.
7511+ */
7512+ static int vec0_mmr_rerank (
7513+ vec0_vtab * p ,
7514+ int vectorColumnIdx ,
7515+ struct VectorColumnDefinition * vector_column ,
7516+ i64 * topk_rowids ,
7517+ f32 * topk_distances ,
7518+ i64 k_used ,
7519+ i64 k_target ,
7520+ f32 mmr_lambda ,
7521+ i64 * out_n_selected
7522+ ) {
7523+ int rc = SQLITE_OK ;
7524+
7525+ // 1. Allocate vector storage for all candidates
7526+ void * * vectors = sqlite3_malloc64 (k_used * sizeof (void * ));
7527+ if (!vectors ) return SQLITE_NOMEM ;
7528+ memset (vectors , 0 , k_used * sizeof (void * ));
7529+
7530+ f32 * relevance = NULL ;
7531+ i64 * out_rowids = NULL ;
7532+ f32 * out_distances = NULL ;
7533+ void * * out_vectors = NULL ;
7534+ u8 * selected = NULL ;
7535+
7536+ // 2. Load vectors from shadow tables
7537+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7538+ rc = vec0_get_vector_data (p , topk_rowids [i ], vectorColumnIdx ,
7539+ & vectors [i ], NULL );
7540+ if (rc != SQLITE_OK ) goto cleanup ;
7541+ }
7542+
7543+ // 3. Normalize distances to [0, 1] for relevance scoring
7544+ f32 max_dist = 0.0f ;
7545+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7546+ if (topk_distances [i ] > max_dist ) max_dist = topk_distances [i ];
7547+ }
7548+ if (max_dist < 1e-9f ) max_dist = 1.0f ;
7549+
7550+ relevance = sqlite3_malloc64 (k_used * sizeof (f32 ));
7551+ if (!relevance ) { rc = SQLITE_NOMEM ; goto cleanup ; }
7552+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7553+ relevance [i ] = 1.0f - (topk_distances [i ] / max_dist );
7554+ }
7555+
7556+ // 4. Greedy MMR selection
7557+ out_rowids = sqlite3_malloc64 (k_target * sizeof (i64 ));
7558+ out_distances = sqlite3_malloc64 (k_target * sizeof (f32 ));
7559+ out_vectors = sqlite3_malloc64 (k_target * sizeof (void * ));
7560+ selected = sqlite3_malloc64 (k_used );
7561+ if (!out_rowids || !out_distances || !out_vectors || !selected ) {
7562+ rc = SQLITE_NOMEM ; goto cleanup ;
7563+ }
7564+ memset (selected , 0 , k_used );
7565+
7566+ i64 n_selected = 0 ;
7567+ for (i64 step = 0 ; step < k_target && step < k_used ; step ++ ) {
7568+ f32 best_mmr = - FLT_MAX ;
7569+ i64 best_idx = -1 ;
7570+
7571+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7572+ if (selected [i ]) continue ;
7573+
7574+ // max similarity to already-selected results
7575+ f32 max_sim = 0.0f ;
7576+ for (i64 j = 0 ; j < step ; j ++ ) {
7577+ f32 d = vec0_compute_distance (vector_column ,
7578+ vectors [i ], out_vectors [j ]);
7579+ f32 sim = 1.0f - d ;
7580+ if (sim > max_sim ) max_sim = sim ;
7581+ }
7582+
7583+ f32 mmr_score = mmr_lambda * relevance [i ]
7584+ - (1.0f - mmr_lambda ) * max_sim ;
7585+ if (mmr_score > best_mmr ) {
7586+ best_mmr = mmr_score ;
7587+ best_idx = i ;
7588+ }
7589+ }
7590+
7591+ if (best_idx < 0 ) break ;
7592+ selected [best_idx ] = 1 ;
7593+ out_rowids [step ] = topk_rowids [best_idx ];
7594+ out_distances [step ] = topk_distances [best_idx ];
7595+ out_vectors [step ] = vectors [best_idx ];
7596+ n_selected ++ ;
7597+ }
7598+
7599+ // 5. Copy only the actually-selected entries back to input arrays
7600+ for (i64 i = 0 ; i < n_selected ; i ++ ) {
7601+ topk_rowids [i ] = out_rowids [i ];
7602+ topk_distances [i ] = out_distances [i ];
7603+ }
7604+ * out_n_selected = n_selected ;
7605+
7606+ cleanup :
7607+ if (vectors ) {
7608+ for (i64 i = 0 ; i < k_used ; i ++ ) {
7609+ sqlite3_free (vectors [i ]);
7610+ }
7611+ sqlite3_free (vectors );
7612+ }
7613+ sqlite3_free (relevance );
7614+ sqlite3_free (out_rowids );
7615+ sqlite3_free (out_distances );
7616+ sqlite3_free (out_vectors );
7617+ sqlite3_free (selected );
7618+ return rc ;
7619+ }
7620+
74437621int vec0Filter_knn (vec0_cursor * pCur , vec0_vtab * p , int idxNum ,
74447622 const char * idxStr , int argc , sqlite3_value * * argv ) {
74457623 assert (argc == (int )((strlen (idxStr )- 1 ) / 4 ));
@@ -7468,6 +7646,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74687646 int query_idx = -1 ;
74697647 int k_idx = -1 ;
74707648 int rowid_in_idx = -1 ;
7649+ int mmr_lambda_idx = -1 ;
74717650 for (int i = 0 ; i < argc ; i ++ ) {
74727651 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MATCH ) {
74737652 query_idx = i ;
@@ -7478,6 +7657,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74787657 if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_ROWID_IN ) {
74797658 rowid_in_idx = i ;
74807659 }
7660+ if (idxStr [1 + (i * 4 )] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA ) {
7661+ mmr_lambda_idx = i ;
7662+ }
74817663 }
74827664 assert (query_idx >= 0 );
74837665 assert (k_idx >= 0 );
@@ -7540,6 +7722,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75407722 goto cleanup ;
75417723 }
75427724
7725+ // MMR: validate lambda and over-fetch candidates
7726+ #define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7727+ f32 mmr_lambda = -1.0f ;
7728+ i64 k_original = k ;
7729+ if (mmr_lambda_idx >= 0 ) {
7730+ mmr_lambda = (f32 )sqlite3_value_double (argv [mmr_lambda_idx ]);
7731+ if (mmr_lambda < 0.0f || mmr_lambda > 1.0f ) {
7732+ vtab_set_error (
7733+ & p -> base ,
7734+ "mmr_lambda value in knn query must be between 0.0 and 1.0, "
7735+ "provided %f" ,
7736+ (double )mmr_lambda );
7737+ rc = SQLITE_ERROR ;
7738+ goto cleanup ;
7739+ }
7740+ if (mmr_lambda < 1.0f ) {
7741+ i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR ;
7742+ if (k_internal > SQLITE_VEC_VEC0_K_MAX ) k_internal = SQLITE_VEC_VEC0_K_MAX ;
7743+ if (k_internal < k ) k_internal = k ; // overflow guard
7744+ k = k_internal ;
7745+ }
7746+ }
7747+
75437748// handle when a `rowid in (...)` operation was provided
75447749// Array of all the rowids that appear in any `rowid in (...)` constraint.
75457750// NULL if none were provided, which means a "full" scan.
@@ -7690,6 +7895,17 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
76907895 goto cleanup ;
76917896 }
76927897
7898+ // MMR reranking: select diverse subset from over-fetched candidates
7899+ if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original ) {
7900+ i64 n_selected = 0 ;
7901+ rc = vec0_mmr_rerank (p , vectorColumnIdx , vector_column ,
7902+ topk_rowids , topk_distances , k_used , k_original ,
7903+ mmr_lambda , & n_selected );
7904+ if (rc != SQLITE_OK ) goto cleanup ;
7905+ k_used = n_selected ;
7906+ k = k_original ;
7907+ }
7908+
76937909 knn_data -> current_idx = 0 ;
76947910 knn_data -> k = k ;
76957911 knn_data -> rowids = topk_rowids ;
@@ -8870,6 +9086,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
88709086 rc = SQLITE_ERROR ;
88719087 goto cleanup ;
88729088 }
9089+ // Cannot insert a value in the hidden "mmr_lambda" column
9090+ if (sqlite3_value_type (argv [2 + vec0_column_mmr_lambda_idx (p )]) != SQLITE_NULL ) {
9091+ vtab_set_error (pVTab , "A value was provided for the hidden \"mmr_lambda\" column." );
9092+ rc = SQLITE_ERROR ;
9093+ goto cleanup ;
9094+ }
88739095
88749096 // Cannot insert a value in the hidden "table_name" column
88759097 if (sqlite3_value_type (argv [2 + vec0_column_table_name_idx (p )]) != SQLITE_NULL ) {
0 commit comments