Skip to content

Commit 2ec93bb

Browse files
committed
Fixed embedding generation: enable LLAMA_POOLING_TYPE_LAST pooling when embeddings are requested, use a per-call sequence id with logits only on the last token, retrieve embeddings by that sequence id, and clear KV memory after retrieval (version 0.5.6 → 0.5.7)
1 parent 628ea6e commit 2ec93bb

2 files changed

Lines changed: 20 additions & 21 deletions

File tree

src/sqlite-ai.c

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,11 @@ void llm_set_model_options (struct llama_model_params *model_params, llm_options
205205
}
206206

207207
void llm_set_context_options (struct llama_context_params *llama_context, llm_options *options) {
208-
if (options->generate_embedding) llama_context->embeddings = true;
208+
if (options->generate_embedding) {
209+
llama_context->embeddings = true;
210+
llama_context->pooling_type = LLAMA_POOLING_TYPE_LAST;
211+
}
212+
209213
if (options->context_size) {
210214
llama_context->n_ctx = options->context_size;
211215
llama_context->n_batch = options->context_size;
@@ -547,18 +551,6 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
547551
ai_logger(GGML_LOG_LEVEL_WARN, buffer, sqlite3_context_db_handle(context));
548552
}
549553

550-
/*
551-
if (llama_vocab_get_add_sep(vocab)) {
552-
const char *sep = llama_vocab_get_text(vocab, llama_vocab_sep(vocab));
553-
printf("sep: %s\n", sep);
554-
}
555-
556-
if (llama_vocab_get_add_eos(vocab)) {
557-
const char *eos = llama_vocab_get_text(vocab, llama_vocab_eos(vocab));
558-
printf("eos: %s\n", eos);
559-
}
560-
*/
561-
562554
// sanity check embedding memory
563555
int dimension = llama_model_n_embd(llama_get_model(ctx));
564556
float *embedding = (float *)sqlite3_malloc64(sizeof(float) * dimension);
@@ -568,7 +560,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
568560
}
569561

570562
// get token count
571-
int32_t n_tokens = -llama_tokenize(vocab, text, text_len, NULL, 0, true, false);
563+
int32_t n_tokens = -llama_tokenize(vocab, text, text_len, NULL, 0, true, true);
572564
if (n_tokens == 0) {
573565
sqlite3_free(embedding);
574566
sqlite_context_result_error(context, SQLITE_ERROR, "Tokenization failed: returned %d tokens", n_tokens);
@@ -599,16 +591,18 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
599591

600592
// set up batch for processing
601593
llama_batch batch = llama_batch_init(n_tokens, 0, 1);
594+
llama_seq_id sequence_id = ai->sequence_id;
602595
for (int i = 0; i < n_tokens; ++i) {
603596
batch.token[batch.n_tokens] = tokens[i];
604597
batch.pos[batch.n_tokens] = i;
605-
batch.n_seq_id[batch.n_tokens]= 1;
606-
batch.seq_id[batch.n_tokens][0] = ai->sequence_id++;
607-
batch.logits[batch.n_tokens] = true;
598+
batch.n_seq_id[batch.n_tokens] = 1;
599+
batch.seq_id[batch.n_tokens][0] = sequence_id;
600+
batch.logits[batch.n_tokens] = i == (n_tokens - 1);
608601
batch.n_tokens++;
609602
}
603+
ai->sequence_id++;
610604

611-
// do real processing
605+
// run model (do real processing)
612606
llama_memory_t memory = llama_get_memory(ctx);
613607
int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
614608

@@ -620,11 +614,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
620614
return;
621615
}
622616

623-
// retrieve embeddings
617+
// retrieve embeddings (context set to LLAMA_POOLING_TYPE_LAST in llama_init_from_model)
624618
const float *result = NULL;
625619
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
626620
if (pooling_type == LLAMA_POOLING_TYPE_NONE) result = llama_get_embeddings(ctx);
627-
else result = llama_get_embeddings_seq(ctx, 0);
621+
else result = llama_get_embeddings_seq(ctx, sequence_id);
628622

629623
if (result == NULL) {
630624
sqlite3_free(tokens);
@@ -637,6 +631,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
637631
// check if normalization is needed (default true)
638632
(ai->options.normalize_embedding) ? llm_embed_normalize(result, embedding, dimension) : memcpy(embedding, result, sizeof(float) * dimension);
639633

634+
if (memory) {
635+
llama_memory_clear(memory, true);
636+
llama_memory_seq_rm(memory, sequence_id, 0, -1);
637+
}
638+
640639
// check if JSON output is set
641640
if (ai->options.json_output) {
642641
sqlite3_str *s = sqlite3_str_new(sqlite3_context_db_handle(context));

src/sqlite-ai.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_AI_VERSION "0.5.6"
27+
#define SQLITE_AI_VERSION "0.5.7"
2828

2929
SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)