Skip to content

Commit 2827b13

Browse files
sanikolaev authored and donhardman committed
feat(embeddings): support passthrough remote model ids
1. Allow explicit provider-prefixed passthrough model ids for remote endpoints
   - keep the existing slash-prefixed forms (openai/..., voyage/..., jina/...) working as before
   - add explicit colon-prefixed forms (openai:..., voyage:..., jina:...)
   - when the colon form is used, pass the model id through after stripping only the provider prefix
   - this allows OpenAI-compatible custom endpoints to receive full upstream model ids unchanged, for example:
     - openai:openai/text-embedding-ada-002
     - openai:jinaai/jina-embeddings-v3
   - preserve strict built-in validation for default provider endpoints while allowing passthrough mode for custom API_URL-based setups
2. Allow CMake to pass optional cargo features to the embeddings crate
   - add EMBEDDINGS_CARGO_FEATURE_ARGS in cmake/build_embeddings.cmake
   - if EMBEDDINGS_CARGO_FEATURES is set, convert it to a valid cargo CLI fragment: --features <value>
   - this makes it possible to configure builds such as download-ort from the CMake side without hard-coding the flag in the build script

Additional remote-model adjustment:
   - cache inferred embedding dimensionality in remote providers so passthrough/custom models can learn their vector dimension from a successful response instead of requiring a built-in static mapping
   - apply that caching approach consistently across OpenAI, Voyage, and Jina
1 parent 0b37507 commit 2827b13

7 files changed

Lines changed: 305 additions & 72 deletions

File tree

cmake/build_embeddings.cmake

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,27 @@ function(build_embeddings_lib)
5050
set(ENV{GIT_COMMIT_ID} "${GIT_COMMIT_ID}")
5151
set(ENV{GIT_TIMESTAMP_ID} "${GIT_TIMESTAMP_ID}")
5252

53-
# Enable platform-specific BLAS acceleration for candle when available
54-
set(EMBEDDINGS_CARGO_FEATURES "")
55-
if(APPLE)
56-
set(EMBEDDINGS_CARGO_FEATURES "--features" "accelerate")
57-
elseif(UNIX)
58-
# MKL provides multi-threaded BLAS on Linux; skip if not available
59-
execute_process(COMMAND pkg-config --exists mkl-dynamic-lp64-seq RESULT_VARIABLE MKL_FOUND OUTPUT_QUIET ERROR_QUIET)
60-
if(MKL_FOUND EQUAL 0)
61-
set(EMBEDDINGS_CARGO_FEATURES "--features" "mkl")
62-
endif()
63-
endif()
53+
# EMBEDDINGS_CARGO_FEATURES may be set externally (e.g., parent CMake) to inject
54+
# extra cargo features. If unset, default to platform-specific BLAS acceleration
55+
# for candle: accelerate on macOS, mkl on Linux when available.
56+
if (NOT DEFINED EMBEDDINGS_CARGO_FEATURES OR "${EMBEDDINGS_CARGO_FEATURES}" STREQUAL "")
57+
if (APPLE)
58+
set(EMBEDDINGS_CARGO_FEATURES "accelerate")
59+
elseif (UNIX)
60+
execute_process(COMMAND pkg-config --exists mkl-dynamic-lp64-seq RESULT_VARIABLE MKL_FOUND OUTPUT_QUIET ERROR_QUIET)
61+
if (MKL_FOUND EQUAL 0)
62+
set(EMBEDDINGS_CARGO_FEATURES "mkl")
63+
endif ()
64+
endif ()
65+
endif ()
66+
67+
set(EMBEDDINGS_CARGO_FEATURE_ARGS "")
68+
if (NOT "${EMBEDDINGS_CARGO_FEATURES}" STREQUAL "")
69+
set(EMBEDDINGS_CARGO_FEATURE_ARGS --features ${EMBEDDINGS_CARGO_FEATURES})
70+
endif ()
6471

6572
execute_process (
66-
COMMAND cargo build --manifest-path ${CMAKE_SOURCE_DIR}/embeddings/Cargo.toml --lib --release ${EMBEDDINGS_CARGO_FEATURES} --target-dir ${CMAKE_CURRENT_BINARY_DIR}/embeddings
73+
COMMAND cargo build --manifest-path ${CMAKE_SOURCE_DIR}/embeddings/Cargo.toml --lib --release ${EMBEDDINGS_CARGO_FEATURE_ARGS} --target-dir ${CMAKE_CURRENT_BINARY_DIR}/embeddings
6774
RESULT_VARIABLE CMD_RESULT
6875
)
6976

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
use super::{create_model, Model, ModelOptions};
2+
3+
#[test]
4+
fn test_create_model_allows_custom_openai_model_when_custom_api_url_is_set() {
5+
let model = create_model(ModelOptions {
6+
model_id: "openai/rubert-tiny-turbo".to_string(),
7+
cache_path: None,
8+
api_key: Some("test-key".to_string()),
9+
api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
10+
api_timeout: None,
11+
use_gpu: None,
12+
});
13+
14+
assert!(model.is_ok());
15+
16+
match model.unwrap() {
17+
Model::OpenAI(model) => assert_eq!(model.model, "rubert-tiny-turbo"),
18+
_ => panic!("expected OpenAI model"),
19+
}
20+
}
21+
22+
#[test]
23+
fn test_create_model_with_custom_url_still_uses_prefixed_jina_as_remote_signal() {
24+
let model = create_model(ModelOptions {
25+
model_id: "jina/custom-model".to_string(),
26+
cache_path: None,
27+
api_key: Some("test-key".to_string()),
28+
api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
29+
api_timeout: None,
30+
use_gpu: None,
31+
});
32+
33+
assert!(model.is_ok());
34+
35+
match model.unwrap() {
36+
Model::Jina(model) => assert_eq!(model.model, "custom-model"),
37+
_ => panic!("expected Jina model"),
38+
}
39+
}
40+
41+
#[test]
42+
fn test_create_model_supports_explicit_openai_colon_syntax() {
43+
let model = create_model(ModelOptions {
44+
model_id: "openai:openai/text-embedding-ada-002".to_string(),
45+
cache_path: None,
46+
api_key: Some("test-key".to_string()),
47+
api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
48+
api_timeout: None,
49+
use_gpu: None,
50+
});
51+
52+
assert!(model.is_ok());
53+
54+
match model.unwrap() {
55+
Model::OpenAI(model) => assert_eq!(model.model, "openai/text-embedding-ada-002"),
56+
_ => panic!("expected OpenAI model"),
57+
}
58+
}
59+
60+
#[test]
61+
fn test_create_model_supports_explicit_openai_colon_syntax_with_simple_model() {
62+
let model = create_model(ModelOptions {
63+
model_id: "openai:text-embedding-ada-002".to_string(),
64+
cache_path: None,
65+
api_key: Some("test-key".to_string()),
66+
api_url: Some("http://localhost:8080/v1/embeddings".to_string()),
67+
api_timeout: None,
68+
use_gpu: None,
69+
});
70+
71+
assert!(model.is_ok());
72+
73+
match model.unwrap() {
74+
Model::OpenAI(model) => assert_eq!(model.model, "text-embedding-ada-002"),
75+
_ => panic!("expected OpenAI model"),
76+
}
77+
}

embeddings/src/model/jina.rs

Lines changed: 54 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
1-
use super::TextModel;
1+
use super::{ModelValidationMode, TextModel};
22
use crate::LibError;
33
use reqwest::blocking::Client;
4+
use std::sync::Mutex;
45

56
#[derive(Debug)]
67
pub struct JinaModel {
78
pub client: Client,
89
pub model: String,
910
pub api_key: String,
1011
pub api_url: Option<String>,
12+
hidden_size_cache: Mutex<Option<usize>>,
1113
}
1214

1315
pub fn validate_model(model: &str) -> Result<(), String> {
@@ -50,8 +52,32 @@ impl JinaModel {
5052
api_url: Option<&str>,
5153
api_timeout: Option<u64>,
5254
) -> Result<Self, Box<dyn std::error::Error>> {
53-
let model = model_id.trim_start_matches("jina/").to_string();
54-
validate_model(&model).map_err(|_| LibError::RemoteUnsupportedModel { status: None })?;
55+
let validation_mode = if api_url.is_some() {
56+
ModelValidationMode::Passthrough
57+
} else {
58+
ModelValidationMode::StrictBuiltInList
59+
};
60+
61+
Self::new_with_validation_mode(model_id, api_key, api_url, api_timeout, validation_mode)
62+
}
63+
64+
pub fn new_with_validation_mode(
65+
model_id: &str,
66+
api_key: &str,
67+
api_url: Option<&str>,
68+
api_timeout: Option<u64>,
69+
validation_mode: ModelValidationMode,
70+
) -> Result<Self, Box<dyn std::error::Error>> {
71+
let model = if let Some(model) = model_id.strip_prefix("jina:") {
72+
model.to_string()
73+
} else {
74+
model_id.trim_start_matches("jina/").to_string()
75+
};
76+
77+
if validation_mode == ModelValidationMode::StrictBuiltInList {
78+
validate_model(&model)
79+
.map_err(|_| LibError::RemoteUnsupportedModel { status: None })?;
80+
}
5581
// Only validate basic requirements (non-empty, no whitespace)
5682
// Real validation happens via actual API request in validate_api_key()
5783
validate_api_key_basic(api_key)
@@ -62,8 +88,26 @@ impl JinaModel {
6288
model,
6389
api_key: api_key.to_string(),
6490
api_url: api_url.map(|s| s.to_string()),
91+
hidden_size_cache: Mutex::new(None),
6592
})
6693
}
94+
95+
fn known_hidden_size(&self) -> Option<usize> {
96+
match self.model.as_str() {
97+
"jina-embeddings-v4" => Some(2048), // 32K context, 2048 dimensions
98+
"jina-clip-v2" => Some(1024), // 8K context, 1024 dimensions, multimodal
99+
"jina-embeddings-v3" => Some(1024), // 8K context, 1024 dimensions
100+
"jina-colbert-v2" => Some(128), // Multi-vector model, 8K context
101+
"jina-clip-v1" => Some(768), // 8K context, 768 dimensions, multimodal
102+
"jina-colbert-v1-en" => Some(128), // Multi-vector model, 8K context
103+
"jina-embeddings-v2-base-es" => Some(768), // 8K context, 768 dimensions
104+
"jina-embeddings-v2-base-code" => Some(768), // 8K context, 768 dimensions
105+
"jina-embeddings-v2-base-de" => Some(768), // 8K context, 768 dimensions
106+
"jina-embeddings-v2-base-zh" => Some(768), // 8K context, 768 dimensions
107+
"jina-embeddings-v2-base-en" => Some(768), // 8K context, 768 dimensions
108+
_ => None,
109+
}
110+
}
67111
}
68112

69113
impl TextModel for JinaModel {
@@ -254,15 +298,17 @@ impl TextModel for JinaModel {
254298
}));
255299
}
256300

301+
let inferred_dim = embeddings[0].len();
302+
*self.hidden_size_cache.lock().unwrap() = Some(inferred_dim);
303+
257304
// Validate embedding dimensions and handle empty individual embeddings
258-
let expected_dim = self.get_hidden_size();
259305
for embedding in embeddings.iter() {
260306
if embedding.is_empty() {
261307
return Err(Box::new(LibError::RemoteHttpError {
262308
status: status_code,
263309
}));
264310
}
265-
if embedding.len() != expected_dim {
311+
if embedding.len() != inferred_dim {
266312
// Some models might return different dimensions, but we should validate
267313
// For now, we'll be lenient but could add stricter validation later
268314
}
@@ -272,20 +318,9 @@ impl TextModel for JinaModel {
272318
}
273319

274320
fn get_hidden_size(&self) -> usize {
275-
match self.model.as_str() {
276-
"jina-embeddings-v4" => 2048, // 32K context, 2048 dimensions
277-
"jina-clip-v2" => 1024, // 8K context, 1024 dimensions, multimodal
278-
"jina-embeddings-v3" => 1024, // 8K context, 1024 dimensions
279-
"jina-colbert-v2" => 128, // Multi-vector model, 8K context
280-
"jina-clip-v1" => 768, // 8K context, 768 dimensions, multimodal
281-
"jina-colbert-v1-en" => 128, // Multi-vector model, 8K context
282-
"jina-embeddings-v2-base-es" => 768, // 8K context, 768 dimensions
283-
"jina-embeddings-v2-base-code" => 768, // 8K context, 768 dimensions
284-
"jina-embeddings-v2-base-de" => 768, // 8K context, 768 dimensions
285-
"jina-embeddings-v2-base-zh" => 768, // 8K context, 768 dimensions
286-
"jina-embeddings-v2-base-en" => 768, // 8K context, 768 dimensions
287-
_ => panic!("Unknown model"),
288-
}
321+
self.known_hidden_size()
322+
.or_else(|| *self.hidden_size_cache.lock().unwrap())
323+
.unwrap_or_else(|| panic!("Unknown model"))
289324
}
290325

291326
fn get_max_input_len(&self) -> usize {

embeddings/src/model/mod.rs

Lines changed: 50 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ mod local_test;
1919
#[cfg(test)]
2020
mod ffi_test;
2121

22+
#[cfg(test)]
23+
mod create_model_test;
24+
2225
use std::error::Error;
2326
use std::path::PathBuf;
2427

@@ -41,6 +44,12 @@ pub struct ModelOptions {
4144
pub use_gpu: Option<bool>,
4245
}
4346

47+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
48+
pub enum ModelValidationMode {
49+
StrictBuiltInList,
50+
Passthrough,
51+
}
52+
4453
/// Unified model enum
4554
///
4655
/// Architecture:
@@ -96,34 +105,55 @@ impl TextModel for Model {
96105

97106
pub fn create_model(options: ModelOptions) -> Result<Model, Box<dyn Error>> {
98107
let model_id = options.model_id.as_str();
108+
let api_key = options.api_key.unwrap_or_default();
109+
let api_url = options.api_url;
110+
let api_timeout = options.api_timeout;
99111

100112
// Remote providers (HTTP APIs)
101-
if model_id.starts_with("openai/") {
102-
let model = openai::OpenAIModel::new(
113+
if model_id.starts_with("openai:") {
114+
let model = openai::OpenAIModel::new_with_validation_mode(
103115
model_id,
104-
options.api_key.unwrap_or_default().as_str(),
105-
options.api_url.as_deref(),
106-
options.api_timeout,
116+
api_key.as_str(),
117+
api_url.as_deref(),
118+
api_timeout,
119+
ModelValidationMode::Passthrough,
107120
)?;
108121

109122
Ok(Model::OpenAI(Box::new(model)))
110-
} else if model_id.starts_with("voyage/") {
111-
let model = voyage::VoyageModel::new(
123+
} else if model_id.starts_with("openai/") {
124+
let model =
125+
openai::OpenAIModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;
126+
127+
Ok(Model::OpenAI(Box::new(model)))
128+
} else if model_id.starts_with("voyage:") {
129+
let model = voyage::VoyageModel::new_with_validation_mode(
112130
model_id,
113-
options.api_key.unwrap_or_default().as_str(),
114-
options.api_url.as_deref(),
115-
options.api_timeout,
131+
api_key.as_str(),
132+
api_url.as_deref(),
133+
api_timeout,
134+
ModelValidationMode::Passthrough,
116135
)?;
117136

118137
Ok(Model::Voyage(Box::new(model)))
119-
} else if model_id.starts_with("jina/") {
120-
let model = jina::JinaModel::new(
138+
} else if model_id.starts_with("voyage/") {
139+
let model =
140+
voyage::VoyageModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;
141+
142+
Ok(Model::Voyage(Box::new(model)))
143+
} else if model_id.starts_with("jina:") {
144+
let model = jina::JinaModel::new_with_validation_mode(
121145
model_id,
122-
options.api_key.unwrap_or_default().as_str(),
123-
options.api_url.as_deref(),
124-
options.api_timeout,
146+
api_key.as_str(),
147+
api_url.as_deref(),
148+
api_timeout,
149+
ModelValidationMode::Passthrough,
125150
)?;
126151

152+
Ok(Model::Jina(Box::new(model)))
153+
} else if model_id.starts_with("jina/") {
154+
let model =
155+
jina::JinaModel::new(model_id, api_key.as_str(), api_url.as_deref(), api_timeout)?;
156+
127157
Ok(Model::Jina(Box::new(model)))
128158
} else {
129159
// Local models - auto-detect architecture from config
@@ -135,7 +165,11 @@ pub fn create_model(options: ModelOptions) -> Result<Model, Box<dyn Error>> {
135165
.unwrap_or(String::from(".cache/manticore")),
136166
);
137167

138-
let hf_token = options.api_key.as_deref();
168+
let hf_token = if api_key.is_empty() {
169+
None
170+
} else {
171+
Some(api_key.as_str())
172+
};
139173
let model = local::LocalModel::new(
140174
model_id,
141175
cache_path,

0 commit comments

Comments
 (0)