diff --git a/sqlite-lembed.c b/sqlite-lembed.c
index 479a554..9b9e574 100644
--- a/sqlite-lembed.c
+++ b/sqlite-lembed.c
@@ -57,23 +57,30 @@ int embed_single(struct llama_model *model, struct llama_context *context,
                  /** Output float embedding */
                  float **out_embedding,
                  /** Output embedding length (n dimensions) */
-                 int *out_dimensions) {
-  int n_batch = 512;
-  int n_ctx_train = llama_n_ctx_train(model);
+                 int *out_dimensions,
+                 /** Output error message (caller must sqlite3_free if not NULL) */
+                 char **out_error) {
   int n_ctx = llama_n_ctx(context);
 
   llama_token *tokens;
   int token_count;
   int rc = tokenize(model, input, input_length, &token_count, &tokens);
   if(rc != SQLITE_OK) {
-    // TODO error message
+    *out_error = sqlite3_mprintf("Failed to tokenize input");
     return rc;
   }
 
-  struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+  if (token_count > n_ctx) {
+    sqlite3_free(tokens);
+    *out_error = sqlite3_mprintf(
+        "Input too large: %d tokens exceeds model context size of %d",
+        token_count, n_ctx);
+    return SQLITE_ERROR;
+  }
+
+  struct llama_batch batch = llama_batch_init(token_count, 0, 1);
 
   int seq_id = 0;
-  // llama_batch_add(batch, tokens, 0, )
   for (int i = 0; i < token_count; i++) {
     batch.token[batch.n_tokens] = tokens[i];
     batch.pos[batch.n_tokens] = i;
@@ -85,6 +92,9 @@ int embed_single(struct llama_model *model, struct llama_context *context,
     batch.n_tokens++;
   }
 
+  sqlite3_free(tokens);
+  tokens = NULL;
+
   int dimensions = llama_n_embd(model);
   float *output_embedding = sqlite3_malloc(sizeof(float) * dimensions);
   if(!output_embedding) {
@@ -92,11 +102,12 @@ int embed_single(struct llama_model *model, struct llama_context *context,
     return SQLITE_NOMEM;
   }
 
-  llama_kv_cache_clear(context); // KV not needed for embeddings?
+  llama_kv_cache_clear(context);
 
   rc = llama_decode(context, batch);
   if(rc != 0) {
     sqlite3_free(output_embedding);
     llama_batch_free(batch);
+    *out_error = sqlite3_mprintf("Failed to decode input");
     return SQLITE_ERROR;
   }
@@ -110,6 +121,7 @@ int embed_single(struct llama_model *model, struct llama_context *context,
   if(!source_embedding) {
     sqlite3_free(output_embedding);
     llama_batch_free(batch);
+    *out_error = sqlite3_mprintf("Failed to extract embeddings");
     return SQLITE_ERROR;
   }
 
@@ -302,9 +314,15 @@ static void lembed(sqlite3_context *context, int argc, sqlite3_value **argv) {
   int dimensions;
   float *embedding;
-  rc = embed_single(model, ctx, input, input_len, &embedding, &dimensions);
+  char *error_msg = NULL;
+  rc = embed_single(model, ctx, input, input_len, &embedding, &dimensions, &error_msg);
 
   if(rc != SQLITE_OK) {
-    sqlite3_result_error(context, "Error generating embedding", -1);
+    if (error_msg) {
+      sqlite3_result_error(context, error_msg, -1);
+      sqlite3_free(error_msg);
+    } else {
+      sqlite3_result_error(context, "Error generating embedding", -1);
+    }
     return;
   }
   sqlite3_result_blob(context, embedding, sizeof(float) * dimensions, sqlite3_free);