Skip to content

Commit 628ea6e

Browse files
committed
Fixed an issue that occurs in some model with multiple llm_embed_generate invocations
Some models require the seq_id field of the batch to be unique across multiple invocations.
1 parent b4bb31d commit 628ea6e

2 files changed

Lines changed: 14 additions & 10 deletions

File tree

src/sqlite-ai.c

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ typedef struct {
132132
// whisper
133133
struct whisper_context *whisper;
134134

135+
// embedding
136+
llama_seq_id sequence_id; // some models requires to be unique across multiple calls to llm_embed_generate
137+
135138
// chat
136139
struct {
137140
char uuid[UUID_STR_MAXLEN];
@@ -504,21 +507,22 @@ static void llm_embed_normalize (const float *src, float *dest, int dim) {
504507

505508
static void llm_embed_generate_run (sqlite3_context *context, const char *text, int32_t text_len) {
506509
ai_context *ai = (ai_context *)sqlite3_user_data(context);
510+
struct llama_model *model = ai->model;
507511

508512
// sanity check model
509-
if (llama_model_has_encoder(ai->model) && llama_model_has_decoder(ai->model)) {
513+
if (llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
510514
sqlite_context_result_error(context, SQLITE_ERROR, "Computing embeddings in encoder-decoder models is not supported");
511515
return;
512516
}
513517

514518
// sanity check model type (decode is used to create embeddings)
515-
if (llama_model_has_decoder(ai->model) == false) {
519+
if (llama_model_has_decoder(model) == false) {
516520
sqlite_context_result_error(context, SQLITE_ERROR, "Model does not support decoding (required for embedding)");
517521
return;
518522
}
519523

520524
// sanity check vocab
521-
const struct llama_vocab *vocab = llama_model_get_vocab(ai->model);
525+
const struct llama_vocab *vocab = llama_model_get_vocab(model);
522526
if (!vocab) {
523527
sqlite_context_result_error(context, SQLITE_ERROR, "Failed to extract vocabulary from the model");
524528
return;
@@ -535,7 +539,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
535539
llama_set_embeddings(ctx, true);
536540

537541
// sanity check tokens
538-
const int n_ctx_train = llama_model_n_ctx_train(ai->model);
542+
const int n_ctx_train = llama_model_n_ctx_train(model);
539543
const int n_ctx = llama_n_ctx(ctx);
540544
if (n_ctx > n_ctx_train) {
541545
char buffer[512];
@@ -595,24 +599,24 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
595599

596600
// set up batch for processing
597601
llama_batch batch = llama_batch_init(n_tokens, 0, 1);
598-
llama_seq_id seq_id = 0;
599602
for (int i = 0; i < n_tokens; ++i) {
600603
batch.token[batch.n_tokens] = tokens[i];
601604
batch.pos[batch.n_tokens] = i;
602605
batch.n_seq_id[batch.n_tokens]= 1;
603-
batch.seq_id[batch.n_tokens][0] = seq_id;
606+
batch.seq_id[batch.n_tokens][0] = ai->sequence_id++;
604607
batch.logits[batch.n_tokens] = true;
605608
batch.n_tokens++;
606609
}
607610

608611
// do real processing
609612
llama_memory_t memory = llama_get_memory(ctx);
610613
int32_t rc = (memory) ? llama_decode(ctx, batch) : llama_encode(ctx, batch);
614+
611615
if (rc < 0) {
612616
sqlite3_free(tokens);
613617
sqlite3_free(embedding);
614618
llama_batch_free(batch);
615-
sqlite_context_result_error(context, SQLITE_ERROR, "Model decode failed during embedding generation");
619+
sqlite_context_result_error(context, SQLITE_ERROR, "Model decode failed during embedding generation (%d)", rc);
616620
return;
617621
}
618622

@@ -635,11 +639,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
635639

636640
// check if JSON output is set
637641
if (ai->options.json_output) {
638-
sqlite3_str *s = sqlite3_str_new(NULL);
642+
sqlite3_str *s = sqlite3_str_new(sqlite3_context_db_handle(context));
639643
sqlite3_str_appendchar(s, 1, '[');
640644
for (int i = 0; i < dimension; i++) {
641645
if (i != 0) sqlite3_str_appendchar(s, 1, ',');
642-
sqlite3_str_appendf(s, "%f", embedding[i]);
646+
sqlite3_str_appendf(s, "%.6g", embedding[i]);
643647
}
644648
sqlite3_str_appendchar(s, 1, ']');
645649

src/sqlite-ai.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
extern "C" {
2525
#endif
2626

27-
#define SQLITE_AI_VERSION "0.5.5"
27+
#define SQLITE_AI_VERSION "0.5.6"
2828

2929
SQLITE_AI_API int sqlite3_ai_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
3030

0 commit comments

Comments
 (0)