From 964c22b70ed8b74e4dec74e819d6e37544fccf00 Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 18 Dec 2025 00:35:55 +0000 Subject: [PATCH 1/5] add device id --- backends/candle/src/lib.rs | 5 +++-- backends/src/lib.rs | 4 ++++ router/src/lib.rs | 2 ++ router/src/main.rs | 6 ++++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index ff824f555..143d34c04 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -127,6 +127,7 @@ impl CandleBackend { dtype: String, model_type: ModelType, dense_paths: Option>, + device_id: usize, ) -> Result { // Default files let default_safetensors = model_path.join("model.safetensors"); @@ -188,7 +189,7 @@ impl CandleBackend { let device = if candle::utils::cuda_is_available() { #[cfg(feature = "cuda")] match compatible_compute_cap() { - Ok(true) => Device::new_cuda(0), + Ok(true) => Device::new_cuda(device_id), Ok(false) => { return Err(BackendError::Start(format!( "Runtime compute cap {} is not compatible with compile time compute cap {}", @@ -205,7 +206,7 @@ impl CandleBackend { #[cfg(not(feature = "cuda"))] Ok(Device::Cpu) } else if candle::utils::metal_is_available() { - Device::new_metal(0) + Device::new_metal(device_id) } else { Ok(Device::Cpu) } diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 7bf20f163..1f0648873 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -90,6 +90,7 @@ impl Backend { uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + device_id: usize, ) -> Result { let (backend_sender, backend_receiver) = mpsc::channel(8); @@ -102,6 +103,7 @@ impl Backend { uds_path, otlp_endpoint, otlp_service_name, + device_id, ) .await?; let padded_model = backend.is_padded(); @@ -362,6 +364,7 @@ async fn init_backend( uds_path: String, otlp_endpoint: Option, otlp_service_name: String, + device_id: usize, ) -> Result, BackendError> { let mut backend_start_failed = false; let api_repo = 
api_repo.map(Arc::new); @@ -475,6 +478,7 @@ async fn init_backend( dtype.to_string(), model_type.clone(), dense_paths, + device_id, ); match backend { Ok(b) => return Ok(Box::new(b)), diff --git a/router/src/lib.rs b/router/src/lib.rs index fe60a0c87..bc410744c 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -66,6 +66,7 @@ pub async fn run( otlp_service_name: String, prometheus_port: u16, cors_allow_origin: Option>, + device_id: usize, ) -> Result<()> { let model_id_path = Path::new(&model_id); let (model_root, api_repo) = if model_id_path.exists() && model_id_path.is_dir() { @@ -277,6 +278,7 @@ pub async fn run( uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()), otlp_endpoint.clone(), otlp_service_name.clone(), + device_id, ) .await .context("Could not create backend")?; diff --git a/router/src/main.rs b/router/src/main.rs index afee836e6..2fa793ff8 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -122,6 +122,11 @@ struct Args { #[clap(long, env)] dense_path: Option, + /// The device ID to use for CUDA/Metal devices. Defaults to 0. + /// Only used with the candle backend. 
+ #[clap(long, env, default_value = "0")] + device_id: usize, + /// [DEPRECATED IN FAVOR OF `--hf-token`] Your Hugging Face Hub token #[clap(long, env, hide = true)] #[redact(partial)] @@ -250,6 +255,7 @@ async fn main() -> Result<()> { args.otlp_service_name, args.prometheus_port, args.cors_allow_origin, + args.device_id, ) .await?; From 24ab9aabaa6f737589bb619c6f4908d2acc4ad8d Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 18 Dec 2025 00:52:20 +0000 Subject: [PATCH 2/5] add device id --- backends/candle/tests/test_bert.rs | 4 ++++ backends/candle/tests/test_dense.rs | 2 ++ backends/candle/tests/test_flash_bert.rs | 4 ++++ backends/candle/tests/test_flash_gte.rs | 2 ++ backends/candle/tests/test_flash_jina.rs | 1 + backends/candle/tests/test_flash_jina_code.rs | 1 + backends/candle/tests/test_flash_mistral.rs | 1 + backends/candle/tests/test_flash_nomic.rs | 2 ++ backends/candle/tests/test_flash_qwen2.rs | 1 + backends/candle/tests/test_flash_qwen3.rs | 1 + backends/candle/tests/test_gemma3.rs | 1 + backends/candle/tests/test_gte.rs | 4 ++++ backends/candle/tests/test_jina.rs | 2 ++ backends/candle/tests/test_jina_code.rs | 1 + backends/candle/tests/test_modernbert.rs | 4 ++++ backends/candle/tests/test_mpnet.rs | 2 ++ backends/candle/tests/test_nomic.rs | 2 ++ backends/candle/tests/test_qwen3.rs | 1 + 18 files changed, 36 insertions(+) diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs index c18479e4b..d496b146c 100644 --- a/backends/candle/tests/test_bert.rs +++ b/backends/candle/tests/test_bert.rs @@ -17,6 +17,7 @@ fn test_bert() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -78,6 +79,7 @@ fn test_bert_pooled_raw() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -149,6 +151,7 @@ fn test_emotions() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let 
input_batch = batch( @@ -205,6 +208,7 @@ fn test_bert_classification() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_dense.rs b/backends/candle/tests/test_dense.rs index b523708d8..2173e18a0 100644 --- a/backends/candle/tests/test_dense.rs +++ b/backends/candle/tests/test_dense.rs @@ -17,6 +17,7 @@ fn test_stella_en_400m_v5_default_dense() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), dense_paths, // This will default to `2_Dense_1024/` as defined in `modules.json` + 0, )?; let input_batch = batch( @@ -74,6 +75,7 @@ fn test_stella_en_400m_v5_dense_768() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), dense_paths, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs index 96e77ce80..54532d3d9 100644 --- a/backends/candle/tests/test_flash_bert.rs +++ b/backends/candle/tests/test_flash_bert.rs @@ -23,6 +23,7 @@ fn test_flash_mini() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -88,6 +89,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -163,6 +165,7 @@ fn test_flash_emotions() -> Result<()> { "float16".to_string(), ModelType::Classifier, None, + 0, )?; let input_batch = batch( @@ -223,6 +226,7 @@ fn test_flash_bert_classification() -> Result<()> { "float16".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_flash_gte.rs b/backends/candle/tests/test_flash_gte.rs index bc4f40502..da8982434 100644 --- a/backends/candle/tests/test_flash_gte.rs +++ b/backends/candle/tests/test_flash_gte.rs @@ -19,6 +19,7 @@ fn test_flash_gte() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let 
input_batch = batch( @@ -69,6 +70,7 @@ fn test_flash_gte_classification() -> Result<()> { "float16".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs index e0046bc50..21a884484 100644 --- a/backends/candle/tests/test_flash_jina.rs +++ b/backends/candle/tests/test_flash_jina.rs @@ -19,6 +19,7 @@ fn test_flash_jina_small() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs index 29941c2b0..57b105f9b 100644 --- a/backends/candle/tests/test_flash_jina_code.rs +++ b/backends/candle/tests/test_flash_jina_code.rs @@ -19,6 +19,7 @@ fn test_flash_jina_code_base() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_mistral.rs b/backends/candle/tests/test_flash_mistral.rs index 5e521a63a..364539295 100644 --- a/backends/candle/tests/test_flash_mistral.rs +++ b/backends/candle/tests/test_flash_mistral.rs @@ -19,6 +19,7 @@ fn test_flash_mistral() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs index 74a36c76d..4db1dccb1 100644 --- a/backends/candle/tests/test_flash_nomic.rs +++ b/backends/candle/tests/test_flash_nomic.rs @@ -19,6 +19,7 @@ fn test_flash_nomic_small() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -65,6 +66,7 @@ fn test_flash_nomic_moe() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_qwen2.rs 
b/backends/candle/tests/test_flash_qwen2.rs index fc47b4b36..2658b2b1a 100644 --- a/backends/candle/tests/test_flash_qwen2.rs +++ b/backends/candle/tests/test_flash_qwen2.rs @@ -44,6 +44,7 @@ fn test_flash_qwen2() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::LastToken), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_flash_qwen3.rs b/backends/candle/tests/test_flash_qwen3.rs index 51f8031ba..e3378ebc3 100644 --- a/backends/candle/tests/test_flash_qwen3.rs +++ b/backends/candle/tests/test_flash_qwen3.rs @@ -19,6 +19,7 @@ fn test_flash_qwen3() -> Result<()> { "float16".to_string(), ModelType::Embedding(Pool::LastToken), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_gemma3.rs b/backends/candle/tests/test_gemma3.rs index bff702ce1..e292e6c4f 100644 --- a/backends/candle/tests/test_gemma3.rs +++ b/backends/candle/tests/test_gemma3.rs @@ -17,6 +17,7 @@ fn test_gemma3() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), dense_paths, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_gte.rs b/backends/candle/tests/test_gte.rs index ccfad4303..452bf66db 100644 --- a/backends/candle/tests/test_gte.rs +++ b/backends/candle/tests/test_gte.rs @@ -17,6 +17,7 @@ fn test_alibaba_gte() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -62,6 +63,7 @@ fn test_alibaba_gte_new() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -108,6 +110,7 @@ fn test_snowflake_gte() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -154,6 +157,7 @@ fn test_gte_classification() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs index 
548c51cac..0da43f529 100644 --- a/backends/candle/tests/test_jina.rs +++ b/backends/candle/tests/test_jina.rs @@ -16,6 +16,7 @@ fn test_jina_small() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -62,6 +63,7 @@ fn test_jina_rerank() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs index 09b751c69..6035bef55 100644 --- a/backends/candle/tests/test_jina_code.rs +++ b/backends/candle/tests/test_jina_code.rs @@ -16,6 +16,7 @@ fn test_jina_code_base() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_modernbert.rs b/backends/candle/tests/test_modernbert.rs index 77f05aef3..3497a0343 100644 --- a/backends/candle/tests/test_modernbert.rs +++ b/backends/candle/tests/test_modernbert.rs @@ -19,6 +19,7 @@ fn test_modernbert() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -88,6 +89,7 @@ fn test_modernbert_pooled_raw() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( @@ -184,6 +186,7 @@ fn test_modernbert_classification() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( @@ -223,6 +226,7 @@ fn test_modernbert_classification_mean_pooling() -> Result<()> { "float32".to_string(), ModelType::Classifier, None, + 0, )?; let input_single = batch( diff --git a/backends/candle/tests/test_mpnet.rs b/backends/candle/tests/test_mpnet.rs index ebfd3e675..d5be18e9b 100644 --- a/backends/candle/tests/test_mpnet.rs +++ b/backends/candle/tests/test_mpnet.rs @@ -18,6 +18,7 @@ fn test_mpnet() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = 
batch( @@ -80,6 +81,7 @@ fn test_mpnet_pooled_raw() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Cls), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs index cd5c07bd3..32effddab 100644 --- a/backends/candle/tests/test_nomic.rs +++ b/backends/candle/tests/test_nomic.rs @@ -16,6 +16,7 @@ fn test_nomic_small() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( @@ -60,6 +61,7 @@ fn test_nomic_moe() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::Mean), None, + 0, )?; let input_batch = batch( diff --git a/backends/candle/tests/test_qwen3.rs b/backends/candle/tests/test_qwen3.rs index 0e6f01155..378ff52a6 100644 --- a/backends/candle/tests/test_qwen3.rs +++ b/backends/candle/tests/test_qwen3.rs @@ -17,6 +17,7 @@ fn test_qwen3() -> Result<()> { "float32".to_string(), ModelType::Embedding(Pool::LastToken), None, + 0, )?; let input_batch = batch( From e04dd6367531b87049d8e22f66a2b979ed22e83b Mon Sep 17 00:00:00 2001 From: michaelfeil Date: Thu, 18 Dec 2025 01:09:35 +0000 Subject: [PATCH 3/5] another test --- router/tests/common.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/router/tests/common.rs b/router/tests/common.rs index 476211764..40a9ab182 100644 --- a/router/tests/common.rs +++ b/router/tests/common.rs @@ -70,6 +70,7 @@ pub async fn start_server(model_id: String, revision: Option, dtype: DTy "text-embeddings-inference.server".to_owned(), 9000, None, + 0, // device_id ) }); From f320868e68237eb5f658504a696ba1e657b6caac Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sat, 27 Dec 2025 05:22:27 +0100 Subject: [PATCH 4/5] Update router/src/main.rs Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> --- router/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/router/src/main.rs 
b/router/src/main.rs index 2fa793ff8..37d9e9de5 100644 --- a/router/src/main.rs +++ b/router/src/main.rs @@ -122,7 +122,7 @@ struct Args { #[clap(long, env)] dense_path: Option, - /// The device ID to use for CUDA/Metal devices. Defaults to 0. + /// The CUDA device ID where the model will be loaded. Defaults to 0, i.e., the first available device. /// Only used with the candle backend. #[clap(long, env, default_value = "0")] device_id: usize, From 81bfc6d1fdaacf08c7e3cc52aea6a9cae5c2eae0 Mon Sep 17 00:00:00 2001 From: Michael Feil <63565275+michaelfeil@users.noreply.github.com> Date: Sun, 28 Dec 2025 09:04:37 +0100 Subject: [PATCH 5/5] Update backends/candle/src/lib.rs Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> --- backends/candle/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs index 143d34c04..168b0d630 100644 --- a/backends/candle/src/lib.rs +++ b/backends/candle/src/lib.rs @@ -206,7 +206,7 @@ impl CandleBackend { #[cfg(not(feature = "cuda"))] Ok(Device::Cpu) } else if candle::utils::metal_is_available() { - Device::new_metal(device_id) + Device::new_metal(0) } else { Ok(Device::Cpu) }