11//! Embedding-based semantic search using fastembed.
22//!
33//! Stores per-feature embeddings for each entity, enabling max-cosine similarity
4- //! search that preserves multi-role entity semantics. Uses BGE-small-en-v1.5
5- //! (384 dimensions) via the fastembed crate.
4+ //! search that preserves multi-role entity semantics. The embedding model is
5+ //! configurable via `navigation.embedding_model` (default BGE-small-en-v1.5, 384d;
6+ //! "jina-code" selects a 768d code-specialized model) — see `resolve_embedding_model`.
67
78use anyhow:: { Context , Result , ensure} ;
89use fastembed:: { EmbeddingModel , TextEmbedding } ;
@@ -14,7 +15,10 @@ use std::path::{Path, PathBuf};
1415/// Magic bytes for the binary embedding file format.
1516const MAGIC : u32 = 0x5250_4745 ; // "RPGE"
1617const FORMAT_VERSION : u32 = 1 ;
17- const DIMENSION : usize = 384 ;
18+ /// Dimension of the default model (BGE-small-en-v1.5). Other models report their
19+ /// own dimension via `resolve_embedding_model`; the active dimension is stored on
20+ /// the index and in the binary header.
21+ const DEFAULT_DIMENSION : usize = 384 ;
1822
1923/// Loaded embeddings: entity map + fingerprints from disk.
2024type LoadedEmbeddings = ( HashMap < String , EntityEmbeddings > , BTreeMap < String , String > ) ;
@@ -41,6 +45,10 @@ struct EntityEmbeddings {
4145/// In-memory embedding index for semantic search.
4246pub struct EmbeddingIndex {
4347 model : TextEmbedding ,
48+ /// Canonical name of the active embedding model (e.g. "BAAI/bge-small-en-v1.5").
49+ model_name : String ,
50+ /// Vector dimension of the active model.
51+ dimension : usize ,
4452 /// Map from entity_id → feature-level embeddings.
4553 entities : HashMap < String , EntityEmbeddings > ,
4654 /// Path to the .rpg directory for persistence.
@@ -51,6 +59,31 @@ pub struct EmbeddingIndex {
5159 fingerprints : BTreeMap < String , String > ,
5260}
5361
62+ /// Resolve a configured model name to its fastembed model, canonical name, and
63+ /// vector dimension. Unknown names fall back to the default (BGE-small-en-v1.5).
64+ fn resolve_embedding_model ( name : & str ) -> ( EmbeddingModel , & ' static str , usize ) {
65+ match name. trim ( ) . to_lowercase ( ) . as_str ( ) {
66+ "jina-code" | "jinaai/jina-embeddings-v2-base-code" => (
67+ EmbeddingModel :: JinaEmbeddingsV2BaseCode ,
68+ "jinaai/jina-embeddings-v2-base-code" ,
69+ 768 ,
70+ ) ,
71+ "bge-base-en" | "baai/bge-base-en-v1.5" => {
72+ ( EmbeddingModel :: BGEBaseENV15 , "BAAI/bge-base-en-v1.5" , 768 )
73+ }
74+ "bge-large-en" | "baai/bge-large-en-v1.5" => (
75+ EmbeddingModel :: BGELargeENV15 ,
76+ "BAAI/bge-large-en-v1.5" ,
77+ 1024 ,
78+ ) ,
79+ _ => (
80+ EmbeddingModel :: BGESmallENV15 ,
81+ "BAAI/bge-small-en-v1.5" ,
82+ DEFAULT_DIMENSION ,
83+ ) ,
84+ }
85+ }
86+
5487/// Statistics from an incremental embedding sync.
5588#[ derive( Debug , Default ) ]
5689pub struct SyncStats {
@@ -69,25 +102,27 @@ impl EmbeddingIndex {
69102 /// Fingerprints are loaded from meta for incremental sync support.
70103 pub fn load_or_init ( project_root : & Path , graph_updated_at : & str ) -> Result < Self > {
71104 let rpg_dir = project_root. join ( ".rpg" ) ;
72- let model = init_model ( & rpg_dir) ?;
105+ let ( model, model_name , dimension ) = init_model ( & rpg_dir) ?;
73106
74107 let embeddings_path = rpg_dir. join ( "embeddings.bin" ) ;
75108 let meta_path = rpg_dir. join ( "embeddings.meta.json" ) ;
76109
77110 // Try loading existing index (resilient to corruption)
78111 if embeddings_path. exists ( ) && meta_path. exists ( ) {
79- match Self :: try_load_existing ( & meta_path, & embeddings_path) {
112+ match Self :: try_load_existing ( & meta_path, & embeddings_path, & model_name , dimension ) {
80113 Ok ( Some ( ( entities, fingerprints) ) ) => {
81114 return Ok ( Self {
82115 model,
116+ model_name,
117+ dimension,
83118 entities,
84119 rpg_dir,
85120 graph_updated_at : graph_updated_at. to_string ( ) ,
86121 fingerprints,
87122 } ) ;
88123 }
89124 Ok ( None ) => {
90- // Model/dimension mismatch — start fresh
125+ // Model/dimension mismatch (e.g. embedding_model changed) — rebuild fresh
91126 }
92127 Err ( e) => {
93128 // Corrupt on-disk data — delete and start fresh
@@ -101,6 +136,8 @@ impl EmbeddingIndex {
101136 // No valid index — start fresh
102137 Ok ( Self {
103138 model,
139+ model_name,
140+ dimension,
104141 entities : HashMap :: new ( ) ,
105142 rpg_dir,
106143 graph_updated_at : graph_updated_at. to_string ( ) ,
@@ -109,22 +146,25 @@ impl EmbeddingIndex {
109146 }
110147
111148 /// Try to load existing embedding data. Returns Ok(Some((entities, fingerprints)))
112- /// if valid, Ok(None) if model/dimension mismatch, Err if corrupt.
149+ /// if valid, Ok(None) if the stored model/dimension differs from the active one
150+ /// (forcing a rebuild), Err if corrupt.
113151 fn try_load_existing (
114152 meta_path : & Path ,
115153 embeddings_path : & Path ,
154+ expected_model : & str ,
155+ expected_dim : usize ,
116156 ) -> Result < Option < LoadedEmbeddings > > {
117157 let meta_json =
118158 std:: fs:: read_to_string ( meta_path) . context ( "failed to read embeddings meta" ) ?;
119159 let meta: EmbeddingMeta =
120160 serde_json:: from_str ( & meta_json) . context ( "failed to parse embeddings meta" ) ?;
121161
122- // Only reject on model/dimension mismatch — fingerprints handle staleness
123- if meta. model != "BAAI/bge-small-en-v1.5" || meta. dimension != DIMENSION as u32 {
162+ // Reject on model/dimension mismatch — fingerprints handle within-model staleness
163+ if meta. model != expected_model || meta. dimension != expected_dim as u32 {
124164 return Ok ( None ) ;
125165 }
126166
127- let entities = load_binary ( embeddings_path) ?;
167+ let entities = load_binary ( embeddings_path, expected_dim ) ?;
128168 Ok ( Some ( ( entities, meta. entity_fingerprints ) ) )
129169 }
130170
@@ -294,11 +334,15 @@ impl EmbeddingIndex {
294334 /// Save the index to disk (binary + meta sidecar with fingerprints).
295335 pub fn save ( & self ) -> Result < ( ) > {
296336 std:: fs:: create_dir_all ( & self . rpg_dir ) ?;
297- save_binary ( & self . rpg_dir . join ( "embeddings.bin" ) , & self . entities ) ?;
337+ save_binary (
338+ & self . rpg_dir . join ( "embeddings.bin" ) ,
339+ & self . entities ,
340+ self . dimension ,
341+ ) ?;
298342
299343 let meta = EmbeddingMeta {
300- model : "BAAI/bge-small-en-v1.5" . to_string ( ) ,
301- dimension : DIMENSION as u32 ,
344+ model : self . model_name . clone ( ) ,
345+ dimension : self . dimension as u32 ,
302346 version : FORMAT_VERSION ,
303347 graph_updated_at : self . graph_updated_at . clone ( ) ,
304348 entity_fingerprints : self . fingerprints . clone ( ) ,
@@ -328,19 +372,26 @@ fn compute_fingerprint(features: &[String]) -> String {
328372 format ! ( "{:016x}" , hasher. finish( ) )
329373}
330374
331- /// Initialize the fastembed model with cache in .rpg/models/.
332- fn init_model ( rpg_dir : & Path ) -> Result < TextEmbedding > {
375+ /// Initialize the fastembed model with cache in .rpg/models/. The model is chosen
376+ /// from `navigation.embedding_model` (config.toml / RPG_EMBED_MODEL / default).
377+ /// Returns the model plus its canonical name and dimension.
378+ fn init_model ( rpg_dir : & Path ) -> Result < ( TextEmbedding , String , usize ) > {
333379 let cache_dir = rpg_dir. join ( "models" ) ;
334380 std:: fs:: create_dir_all ( & cache_dir) ?;
335381
336- let options = fastembed:: TextInitOptions :: new ( EmbeddingModel :: BGESmallENV15 )
382+ let project_root = rpg_dir. parent ( ) . unwrap_or ( rpg_dir) ;
383+ let cfg = rpg_core:: config:: RpgConfig :: load ( project_root) . unwrap_or_default ( ) ;
384+ let ( model_kind, canonical, dimension) =
385+ resolve_embedding_model ( & cfg. navigation . embedding_model ) ;
386+
387+ let options = fastembed:: TextInitOptions :: new ( model_kind)
337388 . with_show_download_progress ( true )
338389 . with_cache_dir ( cache_dir) ;
339390
340391 let model = TextEmbedding :: try_new ( options)
341- . context ( "failed to initialize embedding model (BGE-small-en-v1.5)" ) ?;
392+ . with_context ( || format ! ( "failed to initialize embedding model ({canonical})" ) ) ?;
342393
343- Ok ( model)
394+ Ok ( ( model, canonical . to_string ( ) , dimension ) )
344395}
345396
346397/// Cosine similarity between two vectors.
@@ -365,14 +416,19 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
365416 dot / denom
366417}
367418
368- /// Save entity embeddings to binary format.
369- fn save_binary ( path : & Path , entities : & HashMap < String , EntityEmbeddings > ) -> Result < ( ) > {
419+ /// Save entity embeddings to binary format. `dimension` is written to the header
420+ /// (it must match the active model's dimension).
421+ fn save_binary (
422+ path : & Path ,
423+ entities : & HashMap < String , EntityEmbeddings > ,
424+ dimension : usize ,
425+ ) -> Result < ( ) > {
370426 let mut buf: Vec < u8 > = Vec :: new ( ) ;
371427
372428 // Header (16 bytes)
373429 buf. write_all ( & MAGIC . to_le_bytes ( ) ) ?;
374430 buf. write_all ( & FORMAT_VERSION . to_le_bytes ( ) ) ?;
375- buf. write_all ( & ( DIMENSION as u32 ) . to_le_bytes ( ) ) ?;
431+ buf. write_all ( & ( dimension as u32 ) . to_le_bytes ( ) ) ?;
376432 buf. write_all ( & ( entities. len ( ) as u32 ) . to_le_bytes ( ) ) ?;
377433
378434 // Per entity
@@ -403,8 +459,9 @@ fn save_binary(path: &Path, entities: &HashMap<String, EntityEmbeddings>) -> Res
403459 Ok ( ( ) )
404460}
405461
406- /// Load entity embeddings from binary format.
407- fn load_binary ( path : & Path ) -> Result < HashMap < String , EntityEmbeddings > > {
462+ /// Load entity embeddings from binary format. `expected_dim` is the active model's
463+ /// dimension; a header mismatch means the file was written by a different model.
464+ fn load_binary ( path : & Path , expected_dim : usize ) -> Result < HashMap < String , EntityEmbeddings > > {
408465 let data = std:: fs:: read ( path) . context ( "failed to read embeddings.bin" ) ?;
409466 let mut cursor = & data[ ..] ;
410467
@@ -414,7 +471,7 @@ fn load_binary(path: &Path) -> Result<HashMap<String, EntityEmbeddings>> {
414471 let version = read_u32 ( & mut cursor) ?;
415472 anyhow:: ensure!( version == FORMAT_VERSION , "unsupported embeddings version" ) ;
416473 let dimension = read_u32 ( & mut cursor) ? as usize ;
417- anyhow:: ensure!( dimension == DIMENSION , "dimension mismatch" ) ;
474+ anyhow:: ensure!( dimension == expected_dim , "dimension mismatch" ) ;
418475 let entity_count = read_u32 ( & mut cursor) ? as usize ;
419476
420477 let mut entities = HashMap :: with_capacity ( entity_count) ;
@@ -601,22 +658,57 @@ mod tests {
601658 entities. insert (
602659 "test:func" . to_string ( ) ,
603660 EntityEmbeddings {
604- vectors : vec ! [ vec![ 0.1 ; DIMENSION ] , vec![ 0.2 ; DIMENSION ] ] ,
661+ vectors : vec ! [ vec![ 0.1 ; DEFAULT_DIMENSION ] , vec![ 0.2 ; DEFAULT_DIMENSION ] ] ,
605662 } ,
606663 ) ;
607664
608665 let dir = tempfile:: tempdir ( ) . unwrap ( ) ;
609666 let path = dir. path ( ) . join ( "test.bin" ) ;
610667
611- save_binary ( & path, & entities) . unwrap ( ) ;
612- let loaded = load_binary ( & path) . unwrap ( ) ;
668+ save_binary ( & path, & entities, DEFAULT_DIMENSION ) . unwrap ( ) ;
669+ let loaded = load_binary ( & path, DEFAULT_DIMENSION ) . unwrap ( ) ;
613670
614671 assert_eq ! ( loaded. len( ) , 1 ) ;
615672 assert ! ( loaded. contains_key( "test:func" ) ) ;
616673 assert_eq ! ( loaded[ "test:func" ] . vectors. len( ) , 2 ) ;
617674 assert ! ( ( loaded[ "test:func" ] . vectors[ 0 ] [ 0 ] - 0.1 ) . abs( ) < 1e-6 ) ;
618675 }
619676
677+ #[ test]
678+ fn test_binary_dimension_mismatch_rejected ( ) {
679+ // A file written at 768d must not be read as the default 384d.
680+ let mut entities = HashMap :: new ( ) ;
681+ entities. insert (
682+ "x" . to_string ( ) ,
683+ EntityEmbeddings {
684+ vectors : vec ! [ vec![ 0.5 ; 768 ] ] ,
685+ } ,
686+ ) ;
687+ let dir = tempfile:: tempdir ( ) . unwrap ( ) ;
688+ let path = dir. path ( ) . join ( "e.bin" ) ;
689+ save_binary ( & path, & entities, 768 ) . unwrap ( ) ;
690+ assert ! ( load_binary( & path, DEFAULT_DIMENSION ) . is_err( ) ) ;
691+ assert ! ( load_binary( & path, 768 ) . is_ok( ) ) ;
692+ }
693+
694+ #[ test]
695+ fn test_resolve_embedding_model ( ) {
696+ assert_eq ! ( resolve_embedding_model( "bge-small-en" ) . 2 , 384 ) ;
697+ assert_eq ! ( resolve_embedding_model( "jina-code" ) . 2 , 768 ) ;
698+ assert_eq ! (
699+ resolve_embedding_model( "jina-code" ) . 1 ,
700+ "jinaai/jina-embeddings-v2-base-code"
701+ ) ;
702+ assert_eq ! ( resolve_embedding_model( "bge-large-en" ) . 2 , 1024 ) ;
703+ // unknown -> default
704+ assert_eq ! (
705+ resolve_embedding_model( "nonexistent" ) . 1 ,
706+ "BAAI/bge-small-en-v1.5"
707+ ) ;
708+ // case-insensitive
709+ assert_eq ! ( resolve_embedding_model( "JINA-CODE" ) . 2 , 768 ) ;
710+ }
711+
620712 #[ test]
621713 fn test_fingerprint_deterministic ( ) {
622714 let features = vec ! [ "validate input" . to_string( ) , "return result" . to_string( ) ] ;
0 commit comments