Skip to content

Commit 69dfda1

Browse files
authored
Merge pull request #6 from MayCXC/mmr-reranking-vlasky
MMR reranking via mmr_lambda hidden column
2 parents c069089 + d7930b5 commit 69dfda1

5 files changed

Lines changed: 725 additions & 11 deletions

File tree

sqlite-vec.c

Lines changed: 224 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3497,6 +3497,7 @@ static sqlite3_module vec_npy_eachModule = {
34973497
#define VEC0_COLUMN_OFFSET_DISTANCE 1
34983498
#define VEC0_COLUMN_OFFSET_K 2
34993499
#define VEC0_COLUMN_OFFSET_TABLE_NAME 3
3500+
#define VEC0_COLUMN_OFFSET_MMR_LAMBDA 4
35003501

35013502
#define VEC0_SHADOW_INFO_NAME "\"%w\".\"%w_info\""
35023503

@@ -3777,6 +3778,14 @@ int vec0_column_table_name_idx(vec0_vtab *p) {
37773778
VEC0_COLUMN_OFFSET_TABLE_NAME;
37783779
}
37793780

3781+
/**
3782+
* Returns the column index for the hidden "mmr_lambda" column.
3783+
*/
3784+
int vec0_column_mmr_lambda_idx(vec0_vtab *p) {
3785+
return VEC0_COLUMN_USERN_START + (vec0_num_defined_user_columns(p) - 1) +
3786+
VEC0_COLUMN_OFFSET_MMR_LAMBDA;
3787+
}
3788+
37803789
/**
37813790
* Returns 1 if the given column-based index is a valid vector column,
37823791
* 0 otherwise.
@@ -5039,7 +5048,7 @@ static int vec0_init(sqlite3 *db, void *pAux, int argc, const char *const *argv,
50395048

50405049
}
50415050
sqlite3_str_appendall(createStr, " distance hidden, k hidden, ");
5042-
sqlite3_str_appendf(createStr, "%s hidden) ", tableName);
5051+
sqlite3_str_appendf(createStr, "%s hidden, mmr_lambda hidden) ", tableName);
50435052
if (pkColumnName) {
50445053
sqlite3_str_appendall(createStr, "without rowid ");
50455054
}
@@ -5454,6 +5463,7 @@ typedef enum {
54545463

54555464
// ~~~ ??? ~~~ //
54565465
VEC0_IDXSTR_KIND_METADATA_CONSTRAINT = '&',
5466+
VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA = '#',
54575467
} vec0_idxstr_kind;
54585468

54595469
// The different SQLITE_INDEX_CONSTRAINT values that vec0 partition key columns
@@ -5523,6 +5533,7 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55235533
int iLimitTerm = -1;
55245534
int iRowidTerm = -1;
55255535
int iKTerm = -1;
5536+
int iMmrLambdaTerm = -1;
55265537
int iRowidInTerm = -1;
55275538
int hasAuxConstraint = 0;
55285539

@@ -5579,6 +5590,9 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
55795590
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_k_idx(p)) {
55805591
iKTerm = i;
55815592
}
5593+
if (op == SQLITE_INDEX_CONSTRAINT_EQ && iColumn == vec0_column_mmr_lambda_idx(p)) {
5594+
iMmrLambdaTerm = i;
5595+
}
55825596
if(
55835597
(op != SQLITE_INDEX_CONSTRAINT_LIMIT && op != SQLITE_INDEX_CONSTRAINT_OFFSET)
55845598
&& vec0_column_idx_is_auxiliary(p, iColumn)) {
@@ -5909,7 +5923,12 @@ static int vec0BestIndex(sqlite3_vtab *pVTab, sqlite3_index_info *pIdxInfo) {
59095923
sqlite3_str_appendchar(idxStr, 1, '_');
59105924
}
59115925

5912-
5926+
if (iMmrLambdaTerm >= 0) {
5927+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].argvIndex = argvIndex++;
5928+
pIdxInfo->aConstraintUsage[iMmrLambdaTerm].omit = 1;
5929+
sqlite3_str_appendchar(idxStr, 1, VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA);
5930+
sqlite3_str_appendchar(idxStr, 3, '_');
5931+
}
59135932

59145933
pIdxInfo->idxNum = iMatchVectorTerm;
59155934
pIdxInfo->estimatedCost = 30.0;
@@ -7440,6 +7459,165 @@ int vec0Filter_knn_chunks_iter(vec0_vtab *p, sqlite3_stmt *stmtChunks,
74407459
return rc;
74417460
}
74427461

7462+
/**
7463+
* Compute pairwise distance between two vectors stored in the vec0 table's
7464+
* native format. Handles float32, int8, and bit element types with the
7465+
* appropriate metric (L2, cosine, L1, hamming).
7466+
*/
7467+
static f32 vec0_compute_distance(struct VectorColumnDefinition *vector_column,
7468+
const void *a, const void *b) {
7469+
size_t dims = vector_column->dimensions;
7470+
switch (vector_column->element_type) {
7471+
case SQLITE_VEC_ELEMENT_TYPE_FLOAT32:
7472+
switch (vector_column->distance_metric) {
7473+
case VEC0_DISTANCE_METRIC_L2:
7474+
return distance_l2_sqr_float(a, b, &dims);
7475+
case VEC0_DISTANCE_METRIC_L1:
7476+
return (f32)distance_l1_f32(a, b, &dims);
7477+
case VEC0_DISTANCE_METRIC_COSINE:
7478+
return distance_cosine_float(a, b, &dims);
7479+
}
7480+
break;
7481+
case SQLITE_VEC_ELEMENT_TYPE_INT8:
7482+
switch (vector_column->distance_metric) {
7483+
case VEC0_DISTANCE_METRIC_L2:
7484+
return distance_l2_sqr_int8(a, b, &dims);
7485+
case VEC0_DISTANCE_METRIC_L1:
7486+
return (f32)distance_l1_int8(a, b, &dims);
7487+
case VEC0_DISTANCE_METRIC_COSINE:
7488+
return distance_cosine_int8(a, b, &dims);
7489+
}
7490+
break;
7491+
case SQLITE_VEC_ELEMENT_TYPE_BIT:
7492+
return distance_hamming(a, b, &dims);
7493+
}
7494+
return 0.0f;
7495+
}
7496+
7497+
/**
7498+
* MMR greedy reranking of KNN results.
7499+
*
7500+
* Loads vectors for top-k candidates, then iteratively selects the
7501+
* candidate with the best MMR score:
7502+
* MMR(d) = lambda * relevance(d) - (1-lambda) * max_sim(d, S)
7503+
*
7504+
* where relevance = 1 - normalized_distance, and max_sim is the maximum
7505+
* cosine similarity between d and any already-selected result.
7506+
*
7507+
* Reorders topk_rowids and topk_distances in place.
7508+
* After return, the first *out_n_selected entries are the MMR-selected results.
7509+
* out_n_selected may be less than k_target if the greedy selection exhausts
7510+
* candidates early.
7511+
*/
7512+
static int vec0_mmr_rerank(
7513+
vec0_vtab *p,
7514+
int vectorColumnIdx,
7515+
struct VectorColumnDefinition *vector_column,
7516+
i64 *topk_rowids,
7517+
f32 *topk_distances,
7518+
i64 k_used,
7519+
i64 k_target,
7520+
f32 mmr_lambda,
7521+
i64 *out_n_selected
7522+
) {
7523+
int rc = SQLITE_OK;
7524+
7525+
// 1. Allocate vector storage for all candidates
7526+
void **vectors = sqlite3_malloc64(k_used * sizeof(void *));
7527+
if (!vectors) return SQLITE_NOMEM;
7528+
memset(vectors, 0, k_used * sizeof(void *));
7529+
7530+
f32 *relevance = NULL;
7531+
i64 *out_rowids = NULL;
7532+
f32 *out_distances = NULL;
7533+
void **out_vectors = NULL;
7534+
u8 *selected = NULL;
7535+
7536+
// 2. Load vectors from shadow tables
7537+
for (i64 i = 0; i < k_used; i++) {
7538+
rc = vec0_get_vector_data(p, topk_rowids[i], vectorColumnIdx,
7539+
&vectors[i], NULL);
7540+
if (rc != SQLITE_OK) goto cleanup;
7541+
}
7542+
7543+
// 3. Normalize distances to [0, 1] for relevance scoring
7544+
f32 max_dist = 0.0f;
7545+
for (i64 i = 0; i < k_used; i++) {
7546+
if (topk_distances[i] > max_dist) max_dist = topk_distances[i];
7547+
}
7548+
if (max_dist < 1e-9f) max_dist = 1.0f;
7549+
7550+
relevance = sqlite3_malloc64(k_used * sizeof(f32));
7551+
if (!relevance) { rc = SQLITE_NOMEM; goto cleanup; }
7552+
for (i64 i = 0; i < k_used; i++) {
7553+
relevance[i] = 1.0f - (topk_distances[i] / max_dist);
7554+
}
7555+
7556+
// 4. Greedy MMR selection
7557+
out_rowids = sqlite3_malloc64(k_target * sizeof(i64));
7558+
out_distances = sqlite3_malloc64(k_target * sizeof(f32));
7559+
out_vectors = sqlite3_malloc64(k_target * sizeof(void *));
7560+
selected = sqlite3_malloc64(k_used);
7561+
if (!out_rowids || !out_distances || !out_vectors || !selected) {
7562+
rc = SQLITE_NOMEM; goto cleanup;
7563+
}
7564+
memset(selected, 0, k_used);
7565+
7566+
i64 n_selected = 0;
7567+
for (i64 step = 0; step < k_target && step < k_used; step++) {
7568+
f32 best_mmr = -FLT_MAX;
7569+
i64 best_idx = -1;
7570+
7571+
for (i64 i = 0; i < k_used; i++) {
7572+
if (selected[i]) continue;
7573+
7574+
// max similarity to already-selected results
7575+
f32 max_sim = 0.0f;
7576+
for (i64 j = 0; j < step; j++) {
7577+
f32 d = vec0_compute_distance(vector_column,
7578+
vectors[i], out_vectors[j]);
7579+
f32 sim = 1.0f - d;
7580+
if (sim > max_sim) max_sim = sim;
7581+
}
7582+
7583+
f32 mmr_score = mmr_lambda * relevance[i]
7584+
- (1.0f - mmr_lambda) * max_sim;
7585+
if (mmr_score > best_mmr) {
7586+
best_mmr = mmr_score;
7587+
best_idx = i;
7588+
}
7589+
}
7590+
7591+
if (best_idx < 0) break;
7592+
selected[best_idx] = 1;
7593+
out_rowids[step] = topk_rowids[best_idx];
7594+
out_distances[step] = topk_distances[best_idx];
7595+
out_vectors[step] = vectors[best_idx];
7596+
n_selected++;
7597+
}
7598+
7599+
// 5. Copy only the actually-selected entries back to input arrays
7600+
for (i64 i = 0; i < n_selected; i++) {
7601+
topk_rowids[i] = out_rowids[i];
7602+
topk_distances[i] = out_distances[i];
7603+
}
7604+
*out_n_selected = n_selected;
7605+
7606+
cleanup:
7607+
if (vectors) {
7608+
for (i64 i = 0; i < k_used; i++) {
7609+
sqlite3_free(vectors[i]);
7610+
}
7611+
sqlite3_free(vectors);
7612+
}
7613+
sqlite3_free(relevance);
7614+
sqlite3_free(out_rowids);
7615+
sqlite3_free(out_distances);
7616+
sqlite3_free(out_vectors);
7617+
sqlite3_free(selected);
7618+
return rc;
7619+
}
7620+
74437621
int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74447622
const char *idxStr, int argc, sqlite3_value **argv) {
74457623
assert(argc == (int)((strlen(idxStr)-1) / 4));
@@ -7468,6 +7646,7 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74687646
int query_idx =-1;
74697647
int k_idx = -1;
74707648
int rowid_in_idx = -1;
7649+
int mmr_lambda_idx = -1;
74717650
for(int i = 0; i < argc; i++) {
74727651
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MATCH) {
74737652
query_idx = i;
@@ -7478,6 +7657,9 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
74787657
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_ROWID_IN) {
74797658
rowid_in_idx = i;
74807659
}
7660+
if(idxStr[1 + (i*4)] == VEC0_IDXSTR_KIND_KNN_MMR_LAMBDA) {
7661+
mmr_lambda_idx = i;
7662+
}
74817663
}
74827664
assert(query_idx >= 0);
74837665
assert(k_idx >= 0);
@@ -7540,6 +7722,29 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
75407722
goto cleanup;
75417723
}
75427724

7725+
// MMR: validate lambda and over-fetch candidates
7726+
#define SQLITE_VEC_MMR_OVERFETCH_FACTOR 5
7727+
f32 mmr_lambda = -1.0f;
7728+
i64 k_original = k;
7729+
if (mmr_lambda_idx >= 0) {
7730+
mmr_lambda = (f32)sqlite3_value_double(argv[mmr_lambda_idx]);
7731+
if (mmr_lambda < 0.0f || mmr_lambda > 1.0f) {
7732+
vtab_set_error(
7733+
&p->base,
7734+
"mmr_lambda value in knn query must be between 0.0 and 1.0, "
7735+
"provided %f",
7736+
(double)mmr_lambda);
7737+
rc = SQLITE_ERROR;
7738+
goto cleanup;
7739+
}
7740+
if (mmr_lambda < 1.0f) {
7741+
i64 k_internal = k * SQLITE_VEC_MMR_OVERFETCH_FACTOR;
7742+
if (k_internal > SQLITE_VEC_VEC0_K_MAX) k_internal = SQLITE_VEC_VEC0_K_MAX;
7743+
if (k_internal < k) k_internal = k; // overflow guard
7744+
k = k_internal;
7745+
}
7746+
}
7747+
75437748
// handle when a `rowid in (...)` operation was provided
75447749
// Array of all the rowids that appear in any `rowid in (...)` constraint.
75457750
// NULL if none were provided, which means a "full" scan.
@@ -7690,6 +7895,17 @@ int vec0Filter_knn(vec0_cursor *pCur, vec0_vtab *p, int idxNum,
76907895
goto cleanup;
76917896
}
76927897

7898+
// MMR reranking: select diverse subset from over-fetched candidates
7899+
if (mmr_lambda >= 0.0f && mmr_lambda < 1.0f && k_used > k_original) {
7900+
i64 n_selected = 0;
7901+
rc = vec0_mmr_rerank(p, vectorColumnIdx, vector_column,
7902+
topk_rowids, topk_distances, k_used, k_original,
7903+
mmr_lambda, &n_selected);
7904+
if (rc != SQLITE_OK) goto cleanup;
7905+
k_used = n_selected;
7906+
k = k_original;
7907+
}
7908+
76937909
knn_data->current_idx = 0;
76947910
knn_data->k = k;
76957911
knn_data->rowids = topk_rowids;
@@ -8870,6 +9086,12 @@ int vec0Update_Insert(sqlite3_vtab *pVTab, int argc, sqlite3_value **argv,
88709086
rc = SQLITE_ERROR;
88719087
goto cleanup;
88729088
}
9089+
// Cannot insert a value in the hidden "mmr_lambda" column
9090+
if (sqlite3_value_type(argv[2 + vec0_column_mmr_lambda_idx(p)]) != SQLITE_NULL) {
9091+
vtab_set_error(pVTab, "A value was provided for the hidden \"mmr_lambda\" column.");
9092+
rc = SQLITE_ERROR;
9093+
goto cleanup;
9094+
}
88739095

88749096
// Cannot insert a value in the hidden "table_name" column
88759097
if (sqlite3_value_type(argv[2 + vec0_column_table_name_idx(p)]) != SQLITE_NULL) {

tests/__snapshots__/test-general.ambr

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@
108108
# ---
109109
# name: test_shadow.1
110110
OrderedDict({
111-
'sql': "select * from pragma_table_list where type = 'shadow'",
111+
'sql': "select * from pragma_table_list where type = 'shadow' order by name",
112112
'rows': list([
113113
OrderedDict({
114114
'schema': 'main',
@@ -136,25 +136,25 @@
136136
}),
137137
OrderedDict({
138138
'schema': 'main',
139-
'name': 'v_rowids',
139+
'name': 'v_metadatachunks00',
140140
'type': 'shadow',
141-
'ncol': 4,
141+
'ncol': 2,
142142
'wr': 0,
143143
'strict': 0,
144144
}),
145145
OrderedDict({
146146
'schema': 'main',
147-
'name': 'v_metadatachunks00',
147+
'name': 'v_metadatatext00',
148148
'type': 'shadow',
149149
'ncol': 2,
150150
'wr': 0,
151151
'strict': 0,
152152
}),
153153
OrderedDict({
154154
'schema': 'main',
155-
'name': 'v_metadatatext00',
155+
'name': 'v_rowids',
156156
'type': 'shadow',
157-
'ncol': 2,
157+
'ncol': 4,
158158
'wr': 0,
159159
'strict': 0,
160160
}),
@@ -163,7 +163,7 @@
163163
# ---
164164
# name: test_shadow.2
165165
OrderedDict({
166-
'sql': "select * from pragma_table_list where type = 'shadow'",
166+
'sql': "select * from pragma_table_list where type = 'shadow' order by name",
167167
'rows': list([
168168
]),
169169
})

0 commit comments

Comments
 (0)