@@ -132,6 +132,9 @@ typedef struct {
132132 // whisper
133133 struct whisper_context * whisper ;
134134
135+ // embedding
136+ llama_seq_id sequence_id ; // some models require this ID to be unique across multiple calls to llm_embed_generate
137+
135138 // chat
136139 struct {
137140 char uuid [UUID_STR_MAXLEN ];
@@ -504,21 +507,22 @@ static void llm_embed_normalize (const float *src, float *dest, int dim) {
504507
505508static void llm_embed_generate_run (sqlite3_context * context , const char * text , int32_t text_len ) {
506509 ai_context * ai = (ai_context * )sqlite3_user_data (context );
510+ struct llama_model * model = ai -> model ;
507511
508512 // sanity check model
509- if (llama_model_has_encoder (ai -> model ) && llama_model_has_decoder (ai -> model )) {
513+ if (llama_model_has_encoder (model ) && llama_model_has_decoder (model )) {
510514 sqlite_context_result_error (context , SQLITE_ERROR , "Computing embeddings in encoder-decoder models is not supported" );
511515 return ;
512516 }
513517
514518 // sanity check model type (decode is used to create embeddings)
515- if (llama_model_has_decoder (ai -> model ) == false) {
519+ if (llama_model_has_decoder (model ) == false) {
516520 sqlite_context_result_error (context , SQLITE_ERROR , "Model does not support decoding (required for embedding)" );
517521 return ;
518522 }
519523
520524 // sanity check vocab
521- const struct llama_vocab * vocab = llama_model_get_vocab (ai -> model );
525+ const struct llama_vocab * vocab = llama_model_get_vocab (model );
522526 if (!vocab ) {
523527 sqlite_context_result_error (context , SQLITE_ERROR , "Failed to extract vocabulary from the model" );
524528 return ;
@@ -535,7 +539,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
535539 llama_set_embeddings (ctx , true);
536540
537541 // sanity check tokens
538- const int n_ctx_train = llama_model_n_ctx_train (ai -> model );
542+ const int n_ctx_train = llama_model_n_ctx_train (model );
539543 const int n_ctx = llama_n_ctx (ctx );
540544 if (n_ctx > n_ctx_train ) {
541545 char buffer [512 ];
@@ -595,24 +599,24 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
595599
596600 // set up batch for processing
597601 llama_batch batch = llama_batch_init (n_tokens , 0 , 1 );
598- llama_seq_id seq_id = 0 ;
599602 for (int i = 0 ; i < n_tokens ; ++ i ) {
600603 batch .token [batch .n_tokens ] = tokens [i ];
601604 batch .pos [batch .n_tokens ] = i ;
602605 batch .n_seq_id [batch .n_tokens ]= 1 ;
603- batch .seq_id [batch .n_tokens ][0 ] = seq_id ;
606+ batch .seq_id [batch .n_tokens ][0 ] = ai -> sequence_id ++ ; // NOTE(review): incrementing per token gives every token of this call its own sequence ID; the intent (unique across calls) suggests all tokens of one batch should share one ID, with ai->sequence_id incremented once per call — confirm against llama_batch seq_id semantics
604607 batch .logits [batch .n_tokens ] = true;
605608 batch .n_tokens ++ ;
606609 }
607610
608611 // do real processing
609612 llama_memory_t memory = llama_get_memory (ctx );
610613 int32_t rc = (memory ) ? llama_decode (ctx , batch ) : llama_encode (ctx , batch );
614+
611615 if (rc < 0 ) {
612616 sqlite3_free (tokens );
613617 sqlite3_free (embedding );
614618 llama_batch_free (batch );
615- sqlite_context_result_error (context , SQLITE_ERROR , "Model decode failed during embedding generation" );
619+ sqlite_context_result_error (context , SQLITE_ERROR , "Model decode failed during embedding generation (%d)" , rc );
616620 return ;
617621 }
618622
@@ -635,11 +639,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
635639
636640 // check if JSON output is set
637641 if (ai -> options .json_output ) {
638- sqlite3_str * s = sqlite3_str_new (NULL );
642+ sqlite3_str * s = sqlite3_str_new (sqlite3_context_db_handle ( context ) );
639643 sqlite3_str_appendchar (s , 1 , '[' );
640644 for (int i = 0 ; i < dimension ; i ++ ) {
641645 if (i != 0 ) sqlite3_str_appendchar (s , 1 , ',' );
642- sqlite3_str_appendf (s , "%f " , embedding [i ]);
646+ sqlite3_str_appendf (s , "%.6g " , embedding [i ]);
643647 }
644648 sqlite3_str_appendchar (s , 1 , ']' );
645649
0 commit comments