Skip to content

Commit 20be487

Browse files
committed
fix: use float32 RmsNorm for Metal GPU compatibility in Gemma embedding
Replace candle_transformers::quantized_nn::RmsNorm (which lacks a Metal kernel) with candle_nn::RmsNorm throughout the Gemma embedding code. QTensor weights are dequantized to an f32 Tensor at load time so the standard RmsNorm forward pass runs on Metal without error. Also restore embeddinggemma as the default model (256-dim), replace the eprint-based indexing progress output with an indicatif progress bar, and fix the store tests to match the new default dimension.
1 parent 4892309 commit 20be487

3 files changed

Lines changed: 38 additions & 26 deletions

File tree

src/indexer.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::time::{Duration, Instant};
55
use anyhow::{Context, Result, anyhow};
66
use ignore::WalkBuilder;
77
use sha2::{Digest, Sha256};
8+
use indicatif::{ProgressBar, ProgressStyle};
89
use tracing::info;
910

1011
use crate::chunker::{chunk_markdown, split_oversized_chunks};
@@ -561,15 +562,22 @@ fn run_index_inner(
561562
let mut total_chunks = 0usize;
562563
let mut indexed_rel_paths: Vec<String> = Vec::new();
563564

564-
let total_files = file_contents.len();
565+
let pb = ProgressBar::new(file_contents.len() as u64);
566+
pb.set_style(
567+
ProgressStyle::with_template(" [{bar:40.cyan/blue}] {pos}/{len} {msg} ({eta})")
568+
.unwrap()
569+
.progress_chars("=>-"),
570+
);
571+
565572
store.conn().execute_batch("BEGIN DEFERRED")?;
566-
for (i, (rel_str, content, hash)) in file_contents.iter().enumerate() {
567-
eprint!("\r [{}/{}] {}", i + 1, total_files, rel_str);
573+
for (rel_str, content, hash) in &file_contents {
574+
pb.set_message(rel_str.clone());
568575
let result = index_file(rel_str, content, hash, store, embedder, vault_path, config)?;
569576
total_chunks += result.total_chunks;
570577
indexed_rel_paths.push(rel_str.clone());
578+
pb.inc(1);
571579
}
572-
eprintln!("\r [{}/{}] done{}", total_files, total_files, " ".repeat(60));
580+
pb.finish_with_message("done");
573581
store.commit()?;
574582

575583
// Step 9: Build vault graph edges.

src/llm.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -595,8 +595,8 @@ pub struct ModelDefaults {
595595
impl Default for ModelDefaults {
596596
fn default() -> Self {
597597
Self {
598-
embed_uri: "hf:leliuga/all-MiniLM-L6-v2-GGUF/all-MiniLM-L6-v2.Q8_0.gguf".into(),
599-
embed_dim: 384,
598+
embed_uri: "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf".into(),
599+
embed_dim: 256,
600600
rerank_uri: "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf"
601601
.into(),
602602
expand_uri: "hf:Qwen/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf".into(),
@@ -630,12 +630,12 @@ struct EmbedLayer {
630630
attention_wk: CandleQMatMul,
631631
attention_wv: CandleQMatMul,
632632
attention_wo: CandleQMatMul,
633-
attention_q_norm: candle_transformers::quantized_nn::RmsNorm,
634-
attention_k_norm: candle_transformers::quantized_nn::RmsNorm,
635-
attention_norm: candle_transformers::quantized_nn::RmsNorm,
636-
post_attention_norm: candle_transformers::quantized_nn::RmsNorm,
637-
ffn_norm: candle_transformers::quantized_nn::RmsNorm,
638-
post_ffn_norm: candle_transformers::quantized_nn::RmsNorm,
633+
attention_q_norm: candle_nn::RmsNorm,
634+
attention_k_norm: candle_nn::RmsNorm,
635+
attention_norm: candle_nn::RmsNorm,
636+
post_attention_norm: candle_nn::RmsNorm,
637+
ffn_norm: candle_nn::RmsNorm,
638+
post_ffn_norm: candle_nn::RmsNorm,
639639
ffn_gate: CandleQMatMul,
640640
ffn_up: CandleQMatMul,
641641
ffn_down: CandleQMatMul,
@@ -804,7 +804,7 @@ enum EmbedModelVariant {
804804
Gemma {
805805
layers: Vec<EmbedLayer>,
806806
tok_embeddings: Embedding,
807-
norm: candle_transformers::quantized_nn::RmsNorm,
807+
norm: candle_nn::RmsNorm,
808808
embedding_length: usize,
809809
},
810810
Bert {
@@ -962,7 +962,7 @@ impl CandleEmbed {
962962
) -> Result<(
963963
Vec<EmbedLayer>,
964964
Embedding,
965-
candle_transformers::quantized_nn::RmsNorm,
965+
candle_nn::RmsNorm,
966966
usize,
967967
)> {
968968
use candle_core::quantized::gguf_file;
@@ -1027,12 +1027,14 @@ impl CandleEmbed {
10271027
.map_err(|e| anyhow::anyhow!("dequantizing token_embd: {e}"))?;
10281028
let tok_embeddings = Embedding::new(tok_embd_deq, embedding_length);
10291029

1030-
// Final norm.
1030+
// Final norm (dequantize to f32 for Metal compatibility).
10311031
let norm_qt = ct
10321032
.tensor(&mut file, "output_norm.weight", device)
10331033
.map_err(|e| anyhow::anyhow!("loading output_norm.weight: {e}"))?;
1034-
let norm = candle_transformers::quantized_nn::RmsNorm::from_qtensor(norm_qt, rms_norm_eps)
1035-
.map_err(|e| anyhow::anyhow!("creating RmsNorm: {e}"))?;
1034+
let norm_weight = norm_qt
1035+
.dequantize(device)
1036+
.map_err(|e| anyhow::anyhow!("dequantizing output_norm.weight: {e}"))?;
1037+
let norm = candle_nn::RmsNorm::new(norm_weight, rms_norm_eps);
10361038

10371039
// Load transformer layers.
10381040
let mut layers = Vec::with_capacity(block_count);
@@ -1051,15 +1053,17 @@ impl CandleEmbed {
10511053
}};
10521054
}
10531055

1054-
// Helper: load a norm weight tensor as RmsNorm.
1056+
// Helper: load a norm weight tensor as RmsNorm (dequantize for Metal).
10551057
macro_rules! load_norm {
10561058
($name:expr) => {{
10571059
let full = format!("{}.{}", p, $name);
10581060
let qt = ct
10591061
.tensor(&mut file, &full, device)
10601062
.map_err(|e| anyhow::anyhow!("loading {full}: {e}"))?;
1061-
candle_transformers::quantized_nn::RmsNorm::from_qtensor(qt, rms_norm_eps)
1062-
.map_err(|e| anyhow::anyhow!("RmsNorm for {full}: {e}"))?
1063+
let weight = qt
1064+
.dequantize(device)
1065+
.map_err(|e| anyhow::anyhow!("dequantizing {full}: {e}"))?;
1066+
candle_nn::RmsNorm::new(weight, rms_norm_eps)
10631067
}};
10641068
}
10651069

@@ -1991,10 +1995,10 @@ mod tests {
19911995
fn test_model_defaults() {
19921996
let defaults = ModelDefaults::default();
19931997
assert!(defaults.embed_uri.starts_with("hf:"));
1994-
assert_eq!(defaults.embed_dim, 384);
1998+
assert_eq!(defaults.embed_dim, 256);
19951999
assert!(
1996-
defaults.embed_uri.contains("all-MiniLM-L6-v2"),
1997-
"default embed model should be all-MiniLM-L6-v2-GGUF"
2000+
defaults.embed_uri.contains("embeddinggemma"),
2001+
"default embed model should be embeddinggemma"
19982002
);
19992003
}
20002004

src/store.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ impl Store {
145145
let dim = self
146146
.get_meta("embedding_dim")?
147147
.and_then(|s| s.parse::<usize>().ok())
148-
.unwrap_or(384);
148+
.unwrap_or(256);
149149
crate::vecstore::init_vec_table(&self.conn, dim)?;
150150
self.migrate_vectors_to_vec0()?;
151151
Ok(())
@@ -2264,7 +2264,7 @@ mod tests {
22642264
#[test]
22652265
fn test_store_vec_roundtrip() {
22662266
let store = Store::open_memory().unwrap();
2267-
let vector: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
2267+
let vector: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect();
22682268
store.insert_vec(0, &vector).unwrap();
22692269

22702270
let results = store
@@ -2282,7 +2282,7 @@ mod tests {
22822282
let file_id = store
22832283
.insert_file("test.md", "hash123", 0, &[], "abc123", None)
22842284
.unwrap();
2285-
let vector: Vec<f32> = (0..384).map(|i| (i as f32) / 384.0).collect();
2285+
let vector: Vec<f32> = (0..256).map(|i| (i as f32) / 256.0).collect();
22862286
store
22872287
.insert_chunk_with_vector(file_id, "heading", "snippet", 0, 100, &vector)
22882288
.unwrap();

0 commit comments

Comments (0)