Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions sqlite-lembed.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,23 +57,30 @@ int embed_single(struct llama_model *model, struct llama_context *context,
/** Output float embedding */
float **out_embedding,
/** Output embedding length (n dimensions) */
int *out_dimensions) {
int n_batch = 512;
int n_ctx_train = llama_n_ctx_train(model);
int *out_dimensions,
/** Output error message (caller must sqlite3_free if not NULL) */
char **out_error) {
int n_ctx = llama_n_ctx(context);

llama_token *tokens;
int token_count;
int rc = tokenize(model, input, input_length, &token_count, &tokens);
if(rc != SQLITE_OK) {
// TODO error message
*out_error = sqlite3_mprintf("Failed to tokenize input");
return rc;
}

struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
if (token_count > n_ctx) {
sqlite3_free(tokens);
*out_error = sqlite3_mprintf(
"Input too large: %d tokens exceeds model context size of %d",
token_count, n_ctx);
return SQLITE_ERROR;
}

struct llama_batch batch = llama_batch_init(token_count, 0, 1);

int seq_id = 0;
// llama_batch_add(batch, tokens, 0, )
for (int i = 0; i < token_count; i++) {
batch.token[batch.n_tokens] = tokens[i];
batch.pos[batch.n_tokens] = i;
Expand All @@ -85,18 +92,22 @@ int embed_single(struct llama_model *model, struct llama_context *context,
batch.n_tokens++;
}

sqlite3_free(tokens);
tokens = NULL;

int dimensions = llama_n_embd(model);
float *output_embedding = sqlite3_malloc(sizeof(float) * dimensions);
if(!output_embedding) {
llama_batch_free(batch);
return SQLITE_NOMEM;
}

llama_kv_cache_clear(context); // KV not needed for embeddings?
llama_kv_cache_clear(context);
rc = llama_decode(context, batch);
if(rc != 0) {
sqlite3_free(output_embedding);
llama_batch_free(batch);
*out_error = sqlite3_mprintf("Failed to decode input");
return SQLITE_ERROR;
}

Expand All @@ -110,6 +121,7 @@ int embed_single(struct llama_model *model, struct llama_context *context,
if(!source_embedding) {
sqlite3_free(output_embedding);
llama_batch_free(batch);
*out_error = sqlite3_mprintf("Failed to extract embeddings");
return SQLITE_ERROR;
}

Expand Down Expand Up @@ -302,9 +314,15 @@ static void lembed(sqlite3_context *context, int argc, sqlite3_value **argv) {

int dimensions;
float *embedding;
rc = embed_single(model, ctx, input, input_len, &embedding, &dimensions);
char *error_msg = NULL;
rc = embed_single(model, ctx, input, input_len, &embedding, &dimensions, &error_msg);
if(rc != SQLITE_OK) {
sqlite3_result_error(context, "Error generating embedding", -1);
if (error_msg) {
sqlite3_result_error(context, error_msg, -1);
sqlite3_free(error_msg);
} else {
sqlite3_result_error(context, "Error generating embedding", -1);
}
return;
}
sqlite3_result_blob(context, embedding, sizeof(float) * dimensions, sqlite3_free);
Expand Down