@@ -205,7 +205,11 @@ void llm_set_model_options (struct llama_model_params *model_params, llm_options
205205}
206206
207207void llm_set_context_options (struct llama_context_params * llama_context , llm_options * options ) {
208- if (options -> generate_embedding ) llama_context -> embeddings = true;
208+ if (options -> generate_embedding ) {
209+ llama_context -> embeddings = true;
210+ llama_context -> pooling_type = LLAMA_POOLING_TYPE_LAST ;
211+ }
212+
209213 if (options -> context_size ) {
210214 llama_context -> n_ctx = options -> context_size ;
211215 llama_context -> n_batch = options -> context_size ;
@@ -547,18 +551,6 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
547551 ai_logger (GGML_LOG_LEVEL_WARN , buffer , sqlite3_context_db_handle (context ));
548552 }
549553
550- /*
551- if (llama_vocab_get_add_sep(vocab)) {
552- const char *sep = llama_vocab_get_text(vocab, llama_vocab_sep(vocab));
553- printf("sep: %s\n", sep);
554- }
555-
556- if (llama_vocab_get_add_eos(vocab)) {
557- const char *eos = llama_vocab_get_text(vocab, llama_vocab_eos(vocab));
558- printf("eos: %s\n", eos);
559- }
560- */
561-
562554 // sanity check embedding memory
563555 int dimension = llama_model_n_embd (llama_get_model (ctx ));
564556 float * embedding = (float * )sqlite3_malloc64 (sizeof (float ) * dimension );
@@ -568,7 +560,7 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
568560 }
569561
570562 // get token count
571- int32_t n_tokens = - llama_tokenize (vocab , text , text_len , NULL , 0 , true, false );
563+ int32_t n_tokens = - llama_tokenize (vocab , text , text_len , NULL , 0 , true, true );
572564 if (n_tokens == 0 ) {
573565 sqlite3_free (embedding );
574566 sqlite_context_result_error (context , SQLITE_ERROR , "Tokenization failed: returned %d tokens" , n_tokens );
@@ -599,16 +591,18 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
599591
600592 // set up batch for processing
601593 llama_batch batch = llama_batch_init (n_tokens , 0 , 1 );
594+ llama_seq_id sequence_id = ai -> sequence_id ;
602595 for (int i = 0 ; i < n_tokens ; ++ i ) {
603596 batch .token [batch .n_tokens ] = tokens [i ];
604597 batch .pos [batch .n_tokens ] = i ;
605- batch .n_seq_id [batch .n_tokens ]= 1 ;
606- batch .seq_id [batch .n_tokens ][0 ] = ai -> sequence_id ++ ;
607- batch .logits [batch .n_tokens ] = true ;
598+ batch .n_seq_id [batch .n_tokens ] = 1 ;
599+ batch .seq_id [batch .n_tokens ][0 ] = sequence_id ;
600+ batch .logits [batch .n_tokens ] = i == ( n_tokens - 1 ) ;
608601 batch .n_tokens ++ ;
609602 }
603+ ai -> sequence_id ++ ;
610604
611- // do real processing
605+ // run model ( do real processing)
612606 llama_memory_t memory = llama_get_memory (ctx );
613607 int32_t rc = (memory ) ? llama_decode (ctx , batch ) : llama_encode (ctx , batch );
614608
@@ -620,11 +614,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
620614 return ;
621615 }
622616
623- // retrieve embeddings
617+ // retrieve embeddings (context set to LLAMA_POOLING_TYPE_LAST in llama_init_from_model)
624618 const float * result = NULL ;
625619 const enum llama_pooling_type pooling_type = llama_pooling_type (ctx );
626620 if (pooling_type == LLAMA_POOLING_TYPE_NONE ) result = llama_get_embeddings (ctx );
627- else result = llama_get_embeddings_seq (ctx , 0 );
621+ else result = llama_get_embeddings_seq (ctx , sequence_id );
628622
629623 if (result == NULL ) {
630624 sqlite3_free (tokens );
@@ -637,6 +631,11 @@ static void llm_embed_generate_run (sqlite3_context *context, const char *text,
637631 // check if normalization is needed (default true)
638632 (ai -> options .normalize_embedding ) ? llm_embed_normalize (result , embedding , dimension ) : memcpy (embedding , result , sizeof (float ) * dimension );
639633
634+ if (memory ) {
635+ llama_memory_clear (memory , true);
636+ llama_memory_seq_rm (memory , sequence_id , 0 , -1 );
637+ }
638+
640639 // check if JSON output is set
641640 if (ai -> options .json_output ) {
642641 sqlite3_str * s = sqlite3_str_new (sqlite3_context_db_handle (context ));
0 commit comments