@@ -595,8 +595,8 @@ pub struct ModelDefaults {
 impl Default for ModelDefaults {
     fn default() -> Self {
         Self {
-            embed_uri: "hf:leliuga/all-MiniLM-L6-v2-GGUF/all-MiniLM-L6-v2.Q8_0.gguf".into(),
-            embed_dim: 384,
+            embed_uri: "hf:ggml-org/embeddinggemma-300M-GGUF/embeddinggemma-300M-Q8_0.gguf".into(),
+            embed_dim: 256,
             rerank_uri: "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf"
                 .into(),
             expand_uri: "hf:Qwen/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf".into(),
@@ -630,12 +630,12 @@ struct EmbedLayer {
     attention_wk: CandleQMatMul,
     attention_wv: CandleQMatMul,
     attention_wo: CandleQMatMul,
-    attention_q_norm: candle_transformers::quantized_nn::RmsNorm,
-    attention_k_norm: candle_transformers::quantized_nn::RmsNorm,
-    attention_norm: candle_transformers::quantized_nn::RmsNorm,
-    post_attention_norm: candle_transformers::quantized_nn::RmsNorm,
-    ffn_norm: candle_transformers::quantized_nn::RmsNorm,
-    post_ffn_norm: candle_transformers::quantized_nn::RmsNorm,
+    attention_q_norm: candle_nn::RmsNorm,
+    attention_k_norm: candle_nn::RmsNorm,
+    attention_norm: candle_nn::RmsNorm,
+    post_attention_norm: candle_nn::RmsNorm,
+    ffn_norm: candle_nn::RmsNorm,
+    post_ffn_norm: candle_nn::RmsNorm,
     ffn_gate: CandleQMatMul,
     ffn_up: CandleQMatMul,
     ffn_down: CandleQMatMul,
@@ -804,7 +804,7 @@ enum EmbedModelVariant {
     Gemma {
         layers: Vec<EmbedLayer>,
         tok_embeddings: Embedding,
-        norm: candle_transformers::quantized_nn::RmsNorm,
+        norm: candle_nn::RmsNorm,
         embedding_length: usize,
     },
     Bert {
@@ -962,7 +962,7 @@ impl CandleEmbed {
     ) -> Result<(
         Vec<EmbedLayer>,
         Embedding,
-        candle_transformers::quantized_nn::RmsNorm,
+        candle_nn::RmsNorm,
         usize,
     )> {
         use candle_core::quantized::gguf_file;
@@ -1027,12 +1027,14 @@ impl CandleEmbed {
             .map_err(|e| anyhow::anyhow!("dequantizing token_embd: {e}"))?;
         let tok_embeddings = Embedding::new(tok_embd_deq, embedding_length);
 
-        // Final norm.
+        // Final norm (dequantize to f32 for Metal compatibility).
         let norm_qt = ct
             .tensor(&mut file, "output_norm.weight", device)
             .map_err(|e| anyhow::anyhow!("loading output_norm.weight: {e}"))?;
-        let norm = candle_transformers::quantized_nn::RmsNorm::from_qtensor(norm_qt, rms_norm_eps)
-            .map_err(|e| anyhow::anyhow!("creating RmsNorm: {e}"))?;
+        let norm_weight = norm_qt
+            .dequantize(device)
+            .map_err(|e| anyhow::anyhow!("dequantizing output_norm.weight: {e}"))?;
+        let norm = candle_nn::RmsNorm::new(norm_weight, rms_norm_eps);
 
         // Load transformer layers.
         let mut layers = Vec::with_capacity(block_count);
@@ -1051,15 +1053,17 @@ impl CandleEmbed {
             }};
         }
 
-        // Helper: load a norm weight tensor as RmsNorm.
+        // Helper: load a norm weight tensor as RmsNorm (dequantize for Metal).
         macro_rules! load_norm {
             ($name:expr) => {{
                 let full = format!("{}.{}", p, $name);
                 let qt = ct
                     .tensor(&mut file, &full, device)
                     .map_err(|e| anyhow::anyhow!("loading {full}: {e}"))?;
-                candle_transformers::quantized_nn::RmsNorm::from_qtensor(qt, rms_norm_eps)
-                    .map_err(|e| anyhow::anyhow!("RmsNorm for {full}: {e}"))?
+                let weight = qt
+                    .dequantize(device)
+                    .map_err(|e| anyhow::anyhow!("dequantizing {full}: {e}"))?;
+                candle_nn::RmsNorm::new(weight, rms_norm_eps)
             }};
         }
 
@@ -1991,10 +1995,10 @@ mod tests {
     fn test_model_defaults() {
         let defaults = ModelDefaults::default();
         assert!(defaults.embed_uri.starts_with("hf:"));
-        assert_eq!(defaults.embed_dim, 384);
+        assert_eq!(defaults.embed_dim, 256);
         assert!(
-            defaults.embed_uri.contains("all-MiniLM-L6-v2"),
-            "default embed model should be all-MiniLM-L6-v2-GGUF"
+            defaults.embed_uri.contains("embeddinggemma"),
+            "default embed model should be embeddinggemma"
         );
     }
 
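For reference, here is the pattern this commit applies throughout the loader, as a minimal standalone sketch: read the quantized norm weight out of the GGUF file, dequantize it to f32, and wrap it in `candle_nn::RmsNorm::new` rather than keeping it quantized via `candle_transformers::quantized_nn::RmsNorm::from_qtensor`, which the updated comments suggest is needed for Metal compatibility. The helper name `load_rms_norm` is illustrative only and not part of the patched file.

```rust
use anyhow::Result;
use candle_core::{quantized::gguf_file, Device};
use candle_nn::RmsNorm;

/// Load a GGUF norm weight and build an f32 `candle_nn::RmsNorm`.
/// Hypothetical helper mirroring the `load_norm!` macro in this commit.
fn load_rms_norm<R: std::io::Read + std::io::Seek>(
    ct: &gguf_file::Content,
    reader: &mut R,
    name: &str,
    rms_norm_eps: f64,
    device: &Device,
) -> Result<RmsNorm> {
    // Pull the (possibly quantized) tensor out of the GGUF container.
    let qt = ct
        .tensor(reader, name, device)
        .map_err(|e| anyhow::anyhow!("loading {name}: {e}"))?;
    // Dequantize once at load time; the norm then runs as a plain f32 op
    // instead of relying on quantized kernels that may be missing on Metal.
    let weight = qt
        .dequantize(device)
        .map_err(|e| anyhow::anyhow!("dequantizing {name}: {e}"))?;
    // Unlike quantized_nn::RmsNorm::from_qtensor, this constructor is infallible.
    Ok(RmsNorm::new(weight, rms_norm_eps))
}
```

Norm weights are one-dimensional and tiny compared to the projection matrices, so dequantizing them up front costs almost no memory while the large matmuls stay quantized as `CandleQMatMul`.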