Commit 3db1ae5
feat(llm): switch to llama.cpp backend, fix embedding params
Replace candle with llama-cpp-2 for all ML inference. This gets Metal GPU acceleration (88 files in 70 s vs. 37+ min on CPU). Fixes:

- use encode(), not decode(), for embeddings
- set n_ubatch >= n_tokens
- use AddBos::Never (PromptFormat already adds <bos>)
- force CPU device for quantized ops (candle Metal unsupported; see the sketch below)

Keeps the BERT GGUF support code as a fallback. Default model: embeddinggemma-300M.
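The CPU-device fix applies to the retained fallback path: the BERT GGUF code stays in candle, where quantized ops are unsupported on the Metal backend. A minimal sketch of that device choice, assuming candle_core's Device enum; the function name is hypothetical, not taken from src/llm.rs:

```rust
use candle_core::Device;

/// Pick the device for the retained candle BERT GGUF fallback path.
/// candle's Metal backend cannot run the quantized ops this path needs,
/// so inference is pinned to CPU even on Apple Silicon.
/// (Hypothetical helper; the real code in src/llm.rs may differ.)
fn bert_fallback_device() -> Device {
    Device::Cpu
}
```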
1 parent ebb814b

1 file changed: src/llm.rs (13 additions & 7 deletions)
```diff
@@ -679,32 +679,38 @@ impl LlamaEmbed {
     /// Run embedding inference and return the truncated, L2-normalized embedding.
     fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
         // Tokenize using llama.cpp's built-in tokenizer.
+        // Use AddBos::Never because PromptFormat already adds <bos> for embeddinggemma.
         let tokens = self
             .model
-            .str_to_token(text, AddBos::Always)
+            .str_to_token(text, AddBos::Never)
             .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
         if tokens.is_empty() {
             bail!("tokenizer returned empty token sequence");
         }
 
         // Create a context with embeddings enabled (per-call, since LlamaContext is !Send).
+        // n_ubatch must be >= n_tokens for the encoder, and n_ctx must fit all tokens.
+        let n_tokens = tokens.len() as u32;
+        let n_ctx = std::num::NonZeroU32::new(n_tokens.max(64) + 16);
         let ctx_params = LlamaContextParams::default()
             .with_embeddings(true)
-            .with_n_ctx(std::num::NonZeroU32::new(tokens.len() as u32 + 16));
+            .with_n_ctx(n_ctx)
+            .with_n_ubatch(n_tokens.max(512))
+            .with_n_batch(n_tokens.max(512));
         let mut ctx = self
             .model
             .new_context(&self.backend, ctx_params)
             .map_err(|e| anyhow::anyhow!("creating embedding context: {e}"))?;
 
-        // Create batch and add tokens.
+        // Create batch and add tokens — mark all as outputs for embedding.
         let mut batch = LlamaBatch::new(tokens.len() + 16, 1);
         batch
-            .add_sequence(&tokens, 0, false)
+            .add_sequence(&tokens, 0, true)
             .map_err(|e| anyhow::anyhow!("adding sequence to batch: {e}"))?;
 
-        // Decode (compute embeddings).
-        ctx.decode(&mut batch)
-            .map_err(|e| anyhow::anyhow!("embedding decode failed: {e}"))?;
+        // Encode (compute embeddings). Use encode() for embedding models.
+        ctx.encode(&mut batch)
+            .map_err(|e| anyhow::anyhow!("embedding encode failed: {e}"))?;
 
         // Get embeddings for sequence 0 (mean pooled by llama.cpp).
         let embeddings = ctx
```

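For context, here is a minimal self-contained sketch of the post-change embedding path end to end, reconstructed from this hunk and its comments. It assumes llama-cpp-2's embeddings_seq_ith accessor for the pooled sequence embedding; the free-function signature, the dim truncation parameter, and the normalization code stand in for parts of src/llm.rs the diff does not show:

```rust
use std::num::NonZeroU32;

use anyhow::{bail, Result};
use llama_cpp_2::{
    context::params::LlamaContextParams,
    llama_backend::LlamaBackend,
    llama_batch::LlamaBatch,
    model::{AddBos, LlamaModel},
};

/// Embed `text`, optionally truncate to `dim`, and L2-normalize.
fn embed_text(
    backend: &LlamaBackend,
    model: &LlamaModel,
    text: &str,
    dim: Option<usize>,
) -> Result<Vec<f32>> {
    // AddBos::Never: the prompt formatting step is assumed to have added <bos>.
    let tokens = model
        .str_to_token(text, AddBos::Never)
        .map_err(|e| anyhow::anyhow!("tokenization failed: {e}"))?;
    if tokens.is_empty() {
        bail!("tokenizer returned empty token sequence");
    }

    // The encoder wants the whole sequence in one micro-batch, so n_ubatch
    // (and n_batch) must cover every token, and n_ctx must fit them all.
    let n_tokens = tokens.len() as u32;
    let ctx_params = LlamaContextParams::default()
        .with_embeddings(true)
        .with_n_ctx(NonZeroU32::new(n_tokens.max(64) + 16))
        .with_n_ubatch(n_tokens.max(512))
        .with_n_batch(n_tokens.max(512));
    let mut ctx = model.new_context(backend, ctx_params)?;

    // Mark every token as an output so llama.cpp can mean-pool over all of them.
    let mut batch = LlamaBatch::new(tokens.len() + 16, 1);
    batch.add_sequence(&tokens, 0, true)?;

    // encode(), not decode(): embedding models run the encoder path.
    ctx.encode(&mut batch)?;

    // Pooled embedding for sequence 0 (mean pooling is done by llama.cpp).
    let emb = ctx.embeddings_seq_ith(0)?;

    // Truncate if requested (embeddinggemma supports Matryoshka truncation),
    // then L2-normalize so dot product equals cosine similarity.
    let mut out = match dim {
        Some(d) if d < emb.len() => emb[..d].to_vec(),
        _ => emb.to_vec(),
    };
    let norm = out.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        out.iter_mut().for_each(|x| *x /= norm);
    }
    Ok(out)
}
```

The `n_tokens.max(512)` floor keeps llama.cpp's default batch sizes for short inputs while still guaranteeing `n_ubatch >= n_tokens` for long ones, which is exactly the constraint the commit message calls out.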