Skip to content

Commit 37dfad2

Browse files
authored
feat(nav): selectable embedding model + dynamic dimension (#99)
* feat(nav): selectable embedding model + dynamic dimension

  The embedding model was hardcoded to BGE-small-en-v1.5 (384d, general English text — not code-specialized). Make it configurable via navigation.embedding_model (config.toml / RPG_EMBED_MODEL), resolved to a fastembed model + dimension by resolve_embedding_model.

  Known values: bge-small-en (default, 384), jina-code (768, code-specialized), bge-base-en (768), bge-large-en (1024); unknown values fall back to the default.

  Dimension is now carried on the index + binary header (no longer a const), so switching models rebuilds the index (model/dimension mismatch in meta forces a fresh embed). Zero-config default unchanged; no API key required.

* style(nav): rustfmt resolve_embedding_model + test (1.96 formatting)
1 parent 397d0e5 commit 37dfad2

2 files changed

Lines changed: 126 additions & 27 deletions

File tree

crates/rpg-core/src/config.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ pub struct EncodingConfig {
5757
pub struct NavigationConfig {
5858
/// Maximum number of search results returned.
5959
pub search_result_limit: usize,
60+
/// Embedding model used for semantic search. Resolved by the navigation layer
61+
/// to a concrete model + dimension. Known values: "bge-small-en" (default, 384d,
62+
/// general-text), "jina-code" (768d, code-specialized), "bge-base-en" (768d),
63+
/// "bge-large-en" (1024d). Unknown values fall back to the default.
64+
pub embedding_model: String,
6065
}
6166

6267
impl Default for EncodingConfig {
@@ -78,6 +83,7 @@ impl Default for NavigationConfig {
7883
fn default() -> Self {
7984
Self {
8085
search_result_limit: 10,
86+
embedding_model: "bge-small-en".to_string(),
8187
}
8288
}
8389
}
@@ -127,6 +133,7 @@ impl RpgConfig {
127133
"RPG_SEARCH_LIMIT",
128134
&mut config.navigation.search_result_limit,
129135
);
136+
env_override("RPG_EMBED_MODEL", &mut config.navigation.embedding_model);
130137

131138
// Validate drift thresholds
132139
if config.encoding.drift_ignore_threshold >= config.encoding.drift_auto_threshold {

crates/rpg-nav/src/embeddings.rs

Lines changed: 119 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
//! Embedding-based semantic search using fastembed.
22
//!
33
//! Stores per-feature embeddings for each entity, enabling max-cosine similarity
4-
//! search that preserves multi-role entity semantics. Uses BGE-small-en-v1.5
5-
//! (384 dimensions) via the fastembed crate.
4+
//! search that preserves multi-role entity semantics. The embedding model is
5+
//! configurable via `navigation.embedding_model` (default BGE-small-en-v1.5, 384d;
6+
//! "jina-code" selects a 768d code-specialized model) — see `resolve_embedding_model`.
67
78
use anyhow::{Context, Result, ensure};
89
use fastembed::{EmbeddingModel, TextEmbedding};
@@ -14,7 +15,10 @@ use std::path::{Path, PathBuf};
1415
/// Magic bytes for the binary embedding file format.
1516
const MAGIC: u32 = 0x5250_4745; // "RPGE"
1617
const FORMAT_VERSION: u32 = 1;
17-
const DIMENSION: usize = 384;
18+
/// Dimension of the default model (BGE-small-en-v1.5). Other models report their
19+
/// own dimension via `resolve_embedding_model`; the active dimension is stored on
20+
/// the index and in the binary header.
21+
const DEFAULT_DIMENSION: usize = 384;
1822

1923
/// Loaded embeddings: entity map + fingerprints from disk.
2024
type LoadedEmbeddings = (HashMap<String, EntityEmbeddings>, BTreeMap<String, String>);
@@ -41,6 +45,10 @@ struct EntityEmbeddings {
4145
/// In-memory embedding index for semantic search.
4246
pub struct EmbeddingIndex {
4347
model: TextEmbedding,
48+
/// Canonical name of the active embedding model (e.g. "BAAI/bge-small-en-v1.5").
49+
model_name: String,
50+
/// Vector dimension of the active model.
51+
dimension: usize,
4452
/// Map from entity_id → feature-level embeddings.
4553
entities: HashMap<String, EntityEmbeddings>,
4654
/// Path to the .rpg directory for persistence.
@@ -51,6 +59,31 @@ pub struct EmbeddingIndex {
5159
fingerprints: BTreeMap<String, String>,
5260
}
5361

62+
/// Resolve a configured model name to its fastembed model, canonical name, and
63+
/// vector dimension. Unknown names fall back to the default (BGE-small-en-v1.5).
64+
fn resolve_embedding_model(name: &str) -> (EmbeddingModel, &'static str, usize) {
65+
match name.trim().to_lowercase().as_str() {
66+
"jina-code" | "jinaai/jina-embeddings-v2-base-code" => (
67+
EmbeddingModel::JinaEmbeddingsV2BaseCode,
68+
"jinaai/jina-embeddings-v2-base-code",
69+
768,
70+
),
71+
"bge-base-en" | "baai/bge-base-en-v1.5" => {
72+
(EmbeddingModel::BGEBaseENV15, "BAAI/bge-base-en-v1.5", 768)
73+
}
74+
"bge-large-en" | "baai/bge-large-en-v1.5" => (
75+
EmbeddingModel::BGELargeENV15,
76+
"BAAI/bge-large-en-v1.5",
77+
1024,
78+
),
79+
_ => (
80+
EmbeddingModel::BGESmallENV15,
81+
"BAAI/bge-small-en-v1.5",
82+
DEFAULT_DIMENSION,
83+
),
84+
}
85+
}
86+
5487
/// Statistics from an incremental embedding sync.
5588
#[derive(Debug, Default)]
5689
pub struct SyncStats {
@@ -69,25 +102,27 @@ impl EmbeddingIndex {
69102
/// Fingerprints are loaded from meta for incremental sync support.
70103
pub fn load_or_init(project_root: &Path, graph_updated_at: &str) -> Result<Self> {
71104
let rpg_dir = project_root.join(".rpg");
72-
let model = init_model(&rpg_dir)?;
105+
let (model, model_name, dimension) = init_model(&rpg_dir)?;
73106

74107
let embeddings_path = rpg_dir.join("embeddings.bin");
75108
let meta_path = rpg_dir.join("embeddings.meta.json");
76109

77110
// Try loading existing index (resilient to corruption)
78111
if embeddings_path.exists() && meta_path.exists() {
79-
match Self::try_load_existing(&meta_path, &embeddings_path) {
112+
match Self::try_load_existing(&meta_path, &embeddings_path, &model_name, dimension) {
80113
Ok(Some((entities, fingerprints))) => {
81114
return Ok(Self {
82115
model,
116+
model_name,
117+
dimension,
83118
entities,
84119
rpg_dir,
85120
graph_updated_at: graph_updated_at.to_string(),
86121
fingerprints,
87122
});
88123
}
89124
Ok(None) => {
90-
// Model/dimension mismatch — start fresh
125+
// Model/dimension mismatch (e.g. embedding_model changed) — rebuild fresh
91126
}
92127
Err(e) => {
93128
// Corrupt on-disk data — delete and start fresh
@@ -101,6 +136,8 @@ impl EmbeddingIndex {
101136
// No valid index — start fresh
102137
Ok(Self {
103138
model,
139+
model_name,
140+
dimension,
104141
entities: HashMap::new(),
105142
rpg_dir,
106143
graph_updated_at: graph_updated_at.to_string(),
@@ -109,22 +146,25 @@ impl EmbeddingIndex {
109146
}
110147

111148
/// Try to load existing embedding data. Returns Ok(Some((entities, fingerprints)))
112-
/// if valid, Ok(None) if model/dimension mismatch, Err if corrupt.
149+
/// if valid, Ok(None) if the stored model/dimension differs from the active one
150+
/// (forcing a rebuild), Err if corrupt.
113151
fn try_load_existing(
114152
meta_path: &Path,
115153
embeddings_path: &Path,
154+
expected_model: &str,
155+
expected_dim: usize,
116156
) -> Result<Option<LoadedEmbeddings>> {
117157
let meta_json =
118158
std::fs::read_to_string(meta_path).context("failed to read embeddings meta")?;
119159
let meta: EmbeddingMeta =
120160
serde_json::from_str(&meta_json).context("failed to parse embeddings meta")?;
121161

122-
// Only reject on model/dimension mismatch — fingerprints handle staleness
123-
if meta.model != "BAAI/bge-small-en-v1.5" || meta.dimension != DIMENSION as u32 {
162+
// Reject on model/dimension mismatch — fingerprints handle within-model staleness
163+
if meta.model != expected_model || meta.dimension != expected_dim as u32 {
124164
return Ok(None);
125165
}
126166

127-
let entities = load_binary(embeddings_path)?;
167+
let entities = load_binary(embeddings_path, expected_dim)?;
128168
Ok(Some((entities, meta.entity_fingerprints)))
129169
}
130170

@@ -294,11 +334,15 @@ impl EmbeddingIndex {
294334
/// Save the index to disk (binary + meta sidecar with fingerprints).
295335
pub fn save(&self) -> Result<()> {
296336
std::fs::create_dir_all(&self.rpg_dir)?;
297-
save_binary(&self.rpg_dir.join("embeddings.bin"), &self.entities)?;
337+
save_binary(
338+
&self.rpg_dir.join("embeddings.bin"),
339+
&self.entities,
340+
self.dimension,
341+
)?;
298342

299343
let meta = EmbeddingMeta {
300-
model: "BAAI/bge-small-en-v1.5".to_string(),
301-
dimension: DIMENSION as u32,
344+
model: self.model_name.clone(),
345+
dimension: self.dimension as u32,
302346
version: FORMAT_VERSION,
303347
graph_updated_at: self.graph_updated_at.clone(),
304348
entity_fingerprints: self.fingerprints.clone(),
@@ -328,19 +372,26 @@ fn compute_fingerprint(features: &[String]) -> String {
328372
format!("{:016x}", hasher.finish())
329373
}
330374

331-
/// Initialize the fastembed model with cache in .rpg/models/.
332-
fn init_model(rpg_dir: &Path) -> Result<TextEmbedding> {
375+
/// Initialize the fastembed model with cache in .rpg/models/. The model is chosen
376+
/// from `navigation.embedding_model` (config.toml / RPG_EMBED_MODEL / default).
377+
/// Returns the model plus its canonical name and dimension.
378+
fn init_model(rpg_dir: &Path) -> Result<(TextEmbedding, String, usize)> {
333379
let cache_dir = rpg_dir.join("models");
334380
std::fs::create_dir_all(&cache_dir)?;
335381

336-
let options = fastembed::TextInitOptions::new(EmbeddingModel::BGESmallENV15)
382+
let project_root = rpg_dir.parent().unwrap_or(rpg_dir);
383+
let cfg = rpg_core::config::RpgConfig::load(project_root).unwrap_or_default();
384+
let (model_kind, canonical, dimension) =
385+
resolve_embedding_model(&cfg.navigation.embedding_model);
386+
387+
let options = fastembed::TextInitOptions::new(model_kind)
337388
.with_show_download_progress(true)
338389
.with_cache_dir(cache_dir);
339390

340391
let model = TextEmbedding::try_new(options)
341-
.context("failed to initialize embedding model (BGE-small-en-v1.5)")?;
392+
.with_context(|| format!("failed to initialize embedding model ({canonical})"))?;
342393

343-
Ok(model)
394+
Ok((model, canonical.to_string(), dimension))
344395
}
345396

346397
/// Cosine similarity between two vectors.
@@ -365,14 +416,19 @@ fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
365416
dot / denom
366417
}
367418

368-
/// Save entity embeddings to binary format.
369-
fn save_binary(path: &Path, entities: &HashMap<String, EntityEmbeddings>) -> Result<()> {
419+
/// Save entity embeddings to binary format. `dimension` is written to the header
420+
/// (it must match the active model's dimension).
421+
fn save_binary(
422+
path: &Path,
423+
entities: &HashMap<String, EntityEmbeddings>,
424+
dimension: usize,
425+
) -> Result<()> {
370426
let mut buf: Vec<u8> = Vec::new();
371427

372428
// Header (16 bytes)
373429
buf.write_all(&MAGIC.to_le_bytes())?;
374430
buf.write_all(&FORMAT_VERSION.to_le_bytes())?;
375-
buf.write_all(&(DIMENSION as u32).to_le_bytes())?;
431+
buf.write_all(&(dimension as u32).to_le_bytes())?;
376432
buf.write_all(&(entities.len() as u32).to_le_bytes())?;
377433

378434
// Per entity
@@ -403,8 +459,9 @@ fn save_binary(path: &Path, entities: &HashMap<String, EntityEmbeddings>) -> Res
403459
Ok(())
404460
}
405461

406-
/// Load entity embeddings from binary format.
407-
fn load_binary(path: &Path) -> Result<HashMap<String, EntityEmbeddings>> {
462+
/// Load entity embeddings from binary format. `expected_dim` is the active model's
463+
/// dimension; a header mismatch means the file was written by a different model.
464+
fn load_binary(path: &Path, expected_dim: usize) -> Result<HashMap<String, EntityEmbeddings>> {
408465
let data = std::fs::read(path).context("failed to read embeddings.bin")?;
409466
let mut cursor = &data[..];
410467

@@ -414,7 +471,7 @@ fn load_binary(path: &Path) -> Result<HashMap<String, EntityEmbeddings>> {
414471
let version = read_u32(&mut cursor)?;
415472
anyhow::ensure!(version == FORMAT_VERSION, "unsupported embeddings version");
416473
let dimension = read_u32(&mut cursor)? as usize;
417-
anyhow::ensure!(dimension == DIMENSION, "dimension mismatch");
474+
anyhow::ensure!(dimension == expected_dim, "dimension mismatch");
418475
let entity_count = read_u32(&mut cursor)? as usize;
419476

420477
let mut entities = HashMap::with_capacity(entity_count);
@@ -601,22 +658,57 @@ mod tests {
601658
entities.insert(
602659
"test:func".to_string(),
603660
EntityEmbeddings {
604-
vectors: vec![vec![0.1; DIMENSION], vec![0.2; DIMENSION]],
661+
vectors: vec![vec![0.1; DEFAULT_DIMENSION], vec![0.2; DEFAULT_DIMENSION]],
605662
},
606663
);
607664

608665
let dir = tempfile::tempdir().unwrap();
609666
let path = dir.path().join("test.bin");
610667

611-
save_binary(&path, &entities).unwrap();
612-
let loaded = load_binary(&path).unwrap();
668+
save_binary(&path, &entities, DEFAULT_DIMENSION).unwrap();
669+
let loaded = load_binary(&path, DEFAULT_DIMENSION).unwrap();
613670

614671
assert_eq!(loaded.len(), 1);
615672
assert!(loaded.contains_key("test:func"));
616673
assert_eq!(loaded["test:func"].vectors.len(), 2);
617674
assert!((loaded["test:func"].vectors[0][0] - 0.1).abs() < 1e-6);
618675
}
619676

677+
#[test]
678+
fn test_binary_dimension_mismatch_rejected() {
679+
// A file written at 768d must not be read as the default 384d.
680+
let mut entities = HashMap::new();
681+
entities.insert(
682+
"x".to_string(),
683+
EntityEmbeddings {
684+
vectors: vec![vec![0.5; 768]],
685+
},
686+
);
687+
let dir = tempfile::tempdir().unwrap();
688+
let path = dir.path().join("e.bin");
689+
save_binary(&path, &entities, 768).unwrap();
690+
assert!(load_binary(&path, DEFAULT_DIMENSION).is_err());
691+
assert!(load_binary(&path, 768).is_ok());
692+
}
693+
694+
#[test]
695+
fn test_resolve_embedding_model() {
696+
assert_eq!(resolve_embedding_model("bge-small-en").2, 384);
697+
assert_eq!(resolve_embedding_model("jina-code").2, 768);
698+
assert_eq!(
699+
resolve_embedding_model("jina-code").1,
700+
"jinaai/jina-embeddings-v2-base-code"
701+
);
702+
assert_eq!(resolve_embedding_model("bge-large-en").2, 1024);
703+
// unknown -> default
704+
assert_eq!(
705+
resolve_embedding_model("nonexistent").1,
706+
"BAAI/bge-small-en-v1.5"
707+
);
708+
// case-insensitive
709+
assert_eq!(resolve_embedding_model("JINA-CODE").2, 768);
710+
}
711+
620712
#[test]
621713
fn test_fingerprint_deterministic() {
622714
let features = vec!["validate input".to_string(), "return result".to_string()];

0 commit comments

Comments
 (0)