diff --git a/backends/candle/src/lib.rs b/backends/candle/src/lib.rs
index ff824f555..168b0d630 100644
--- a/backends/candle/src/lib.rs
+++ b/backends/candle/src/lib.rs
@@ -127,6 +127,7 @@ impl CandleBackend {
         dtype: String,
         model_type: ModelType,
         dense_paths: Option<Vec<PathBuf>>,
+        device_id: usize,
     ) -> Result<Self, BackendError> {
         // Default files
         let default_safetensors = model_path.join("model.safetensors");
@@ -188,7 +189,7 @@ impl CandleBackend {
         let device = if candle::utils::cuda_is_available() {
             #[cfg(feature = "cuda")]
             match compatible_compute_cap() {
-                Ok(true) => Device::new_cuda(0),
+                Ok(true) => Device::new_cuda(device_id),
                 Ok(false) => {
                     return Err(BackendError::Start(format!(
                         "Runtime compute cap {} is not compatible with compile time compute cap {}",
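Note: the only functional change in the backend above is that the CUDA ordinal passed to `Device::new_cuda` is no longer hard-coded to `0`. For illustration, here is a minimal standalone sketch of the same selection logic, using the `candle-core` crate directly (the helper name `select_device` and the CPU fallback are illustrative and not part of this patch; TEI's `compatible_compute_cap` check is omitted):

```rust
use candle_core::Device;

// Illustrative helper, not part of this patch: pick the CUDA device with the
// given ordinal when CUDA is available, and fall back to the CPU otherwise.
fn select_device(device_id: usize) -> candle_core::Result<Device> {
    if candle_core::utils::cuda_is_available() {
        // `Device::new_cuda` takes the CUDA ordinal, i.e. the same index that
        // `nvidia-smi` reports for the GPU.
        Device::new_cuda(device_id)
    } else {
        Ok(Device::Cpu)
    }
}
```

The rest of the diff threads the new parameter through the constructors and updates every test call site to pass `0`, preserving the previous behavior.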
diff --git a/backends/candle/tests/test_bert.rs b/backends/candle/tests/test_bert.rs
index c18479e4b..d496b146c 100644
--- a/backends/candle/tests/test_bert.rs
+++ b/backends/candle/tests/test_bert.rs
@@ -17,6 +17,7 @@ fn test_bert() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -78,6 +79,7 @@ fn test_bert_pooled_raw() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -149,6 +151,7 @@ fn test_emotions() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -205,6 +208,7 @@ fn test_bert_classification() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_dense.rs b/backends/candle/tests/test_dense.rs
index b523708d8..2173e18a0 100644
--- a/backends/candle/tests/test_dense.rs
+++ b/backends/candle/tests/test_dense.rs
@@ -17,6 +17,7 @@ fn test_stella_en_400m_v5_default_dense() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         dense_paths, // This will default to `2_Dense_1024/` as defined in `modules.json`
+        0,
     )?;
 
     let input_batch = batch(
@@ -74,6 +75,7 @@ fn test_stella_en_400m_v5_dense_768() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         dense_paths,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_bert.rs b/backends/candle/tests/test_flash_bert.rs
index 96e77ce80..54532d3d9 100644
--- a/backends/candle/tests/test_flash_bert.rs
+++ b/backends/candle/tests/test_flash_bert.rs
@@ -23,6 +23,7 @@ fn test_flash_mini() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -88,6 +89,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -163,6 +165,7 @@ fn test_flash_emotions() -> Result<()> {
         "float16".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -223,6 +226,7 @@ fn test_flash_bert_classification() -> Result<()> {
         "float16".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_flash_gte.rs b/backends/candle/tests/test_flash_gte.rs
index bc4f40502..da8982434 100644
--- a/backends/candle/tests/test_flash_gte.rs
+++ b/backends/candle/tests/test_flash_gte.rs
@@ -19,6 +19,7 @@ fn test_flash_gte() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -69,6 +70,7 @@ fn test_flash_gte_classification() -> Result<()> {
         "float16".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_flash_jina.rs b/backends/candle/tests/test_flash_jina.rs
index e0046bc50..21a884484 100644
--- a/backends/candle/tests/test_flash_jina.rs
+++ b/backends/candle/tests/test_flash_jina.rs
@@ -19,6 +19,7 @@ fn test_flash_jina_small() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_jina_code.rs b/backends/candle/tests/test_flash_jina_code.rs
index 29941c2b0..57b105f9b 100644
--- a/backends/candle/tests/test_flash_jina_code.rs
+++ b/backends/candle/tests/test_flash_jina_code.rs
@@ -19,6 +19,7 @@ fn test_flash_jina_code_base() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_mistral.rs b/backends/candle/tests/test_flash_mistral.rs
index 5e521a63a..364539295 100644
--- a/backends/candle/tests/test_flash_mistral.rs
+++ b/backends/candle/tests/test_flash_mistral.rs
@@ -19,6 +19,7 @@ fn test_flash_mistral() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_nomic.rs b/backends/candle/tests/test_flash_nomic.rs
index 74a36c76d..4db1dccb1 100644
--- a/backends/candle/tests/test_flash_nomic.rs
+++ b/backends/candle/tests/test_flash_nomic.rs
@@ -19,6 +19,7 @@ fn test_flash_nomic_small() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -65,6 +66,7 @@ fn test_flash_nomic_moe() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_qwen2.rs b/backends/candle/tests/test_flash_qwen2.rs
index fc47b4b36..2658b2b1a 100644
--- a/backends/candle/tests/test_flash_qwen2.rs
+++ b/backends/candle/tests/test_flash_qwen2.rs
@@ -44,6 +44,7 @@ fn test_flash_qwen2() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::LastToken),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_flash_qwen3.rs b/backends/candle/tests/test_flash_qwen3.rs
index 51f8031ba..e3378ebc3 100644
--- a/backends/candle/tests/test_flash_qwen3.rs
+++ b/backends/candle/tests/test_flash_qwen3.rs
@@ -19,6 +19,7 @@ fn test_flash_qwen3() -> Result<()> {
         "float16".to_string(),
         ModelType::Embedding(Pool::LastToken),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_gemma3.rs b/backends/candle/tests/test_gemma3.rs
index bff702ce1..e292e6c4f 100644
--- a/backends/candle/tests/test_gemma3.rs
+++ b/backends/candle/tests/test_gemma3.rs
@@ -17,6 +17,7 @@ fn test_gemma3() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         dense_paths,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_gte.rs b/backends/candle/tests/test_gte.rs
index ccfad4303..452bf66db 100644
--- a/backends/candle/tests/test_gte.rs
+++ b/backends/candle/tests/test_gte.rs
@@ -17,6 +17,7 @@ fn test_alibaba_gte() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -62,6 +63,7 @@ fn test_alibaba_gte_new() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -108,6 +110,7 @@ fn test_snowflake_gte() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -154,6 +157,7 @@ fn test_gte_classification() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_jina.rs b/backends/candle/tests/test_jina.rs
index 548c51cac..0da43f529 100644
--- a/backends/candle/tests/test_jina.rs
+++ b/backends/candle/tests/test_jina.rs
@@ -16,6 +16,7 @@ fn test_jina_small() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -62,6 +63,7 @@ fn test_jina_rerank() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_jina_code.rs b/backends/candle/tests/test_jina_code.rs
index 09b751c69..6035bef55 100644
--- a/backends/candle/tests/test_jina_code.rs
+++ b/backends/candle/tests/test_jina_code.rs
@@ -16,6 +16,7 @@ fn test_jina_code_base() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_modernbert.rs b/backends/candle/tests/test_modernbert.rs
index 77f05aef3..3497a0343 100644
--- a/backends/candle/tests/test_modernbert.rs
+++ b/backends/candle/tests/test_modernbert.rs
@@ -19,6 +19,7 @@ fn test_modernbert() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -88,6 +89,7 @@ fn test_modernbert_pooled_raw() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -184,6 +186,7 @@ fn test_modernbert_classification() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
@@ -223,6 +226,7 @@ fn test_modernbert_classification_mean_pooling() -> Result<()> {
         "float32".to_string(),
         ModelType::Classifier,
         None,
+        0,
     )?;
 
     let input_single = batch(
diff --git a/backends/candle/tests/test_mpnet.rs b/backends/candle/tests/test_mpnet.rs
index ebfd3e675..d5be18e9b 100644
--- a/backends/candle/tests/test_mpnet.rs
+++ b/backends/candle/tests/test_mpnet.rs
@@ -18,6 +18,7 @@ fn test_mpnet() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -80,6 +81,7 @@ fn test_mpnet_pooled_raw() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Cls),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_nomic.rs b/backends/candle/tests/test_nomic.rs
index cd5c07bd3..32effddab 100644
--- a/backends/candle/tests/test_nomic.rs
+++ b/backends/candle/tests/test_nomic.rs
@@ -16,6 +16,7 @@ fn test_nomic_small() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
@@ -60,6 +61,7 @@ fn test_nomic_moe() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::Mean),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/candle/tests/test_qwen3.rs b/backends/candle/tests/test_qwen3.rs
index 0e6f01155..378ff52a6 100644
--- a/backends/candle/tests/test_qwen3.rs
+++ b/backends/candle/tests/test_qwen3.rs
@@ -17,6 +17,7 @@ fn test_qwen3() -> Result<()> {
         "float32".to_string(),
         ModelType::Embedding(Pool::LastToken),
         None,
+        0,
     )?;
 
     let input_batch = batch(
diff --git a/backends/src/lib.rs b/backends/src/lib.rs
index 7bf20f163..1f0648873 100644
--- a/backends/src/lib.rs
+++ b/backends/src/lib.rs
@@ -90,6 +90,7 @@ impl Backend {
         uds_path: String,
         otlp_endpoint: Option<String>,
         otlp_service_name: String,
+        device_id: usize,
     ) -> Result<Self, BackendError> {
         let (backend_sender, backend_receiver) = mpsc::channel(8);
@@ -102,6 +103,7 @@ impl Backend {
             uds_path,
             otlp_endpoint,
             otlp_service_name,
+            device_id,
         )
         .await?;
         let padded_model = backend.is_padded();
@@ -362,6 +364,7 @@ async fn init_backend(
     uds_path: String,
     otlp_endpoint: Option<String>,
     otlp_service_name: String,
+    device_id: usize,
 ) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
     let mut backend_start_failed = false;
     let api_repo = api_repo.map(Arc::new);
@@ -475,6 +478,7 @@ async fn init_backend(
             dtype.to_string(),
             model_type.clone(),
             dense_paths,
+            device_id,
         );
         match backend {
             Ok(b) => return Ok(Box::new(b)),
diff --git a/router/src/lib.rs b/router/src/lib.rs
index 1ad777ac0..d1bf7ce51 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -66,6 +66,7 @@ pub async fn run(
     otlp_service_name: String,
     prometheus_port: u16,
     cors_allow_origin: Option<Vec<String>>,
+    device_id: usize,
 ) -> Result<()> {
     let model_id_path = Path::new(&model_id);
     let (model_root, api_repo) = if model_id_path.exists() && model_id_path.is_dir() {
@@ -280,6 +281,7 @@ pub async fn run(
         uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()),
         otlp_endpoint.clone(),
         otlp_service_name.clone(),
+        device_id,
     )
     .await
     .context("Could not create backend")?;
diff --git a/router/src/main.rs b/router/src/main.rs
index afee836e6..37d9e9de5 100644
--- a/router/src/main.rs
+++ b/router/src/main.rs
@@ -122,6 +122,11 @@ struct Args {
     #[clap(long, env)]
     dense_path: Option<String>,
 
+    /// The CUDA device ID where the model will be loaded. Defaults to 0, i.e., the first available device.
+    /// Only used with the candle backend.
+    #[clap(long, env, default_value = "0")]
+    device_id: usize,
+
     /// [DEPRECATED IN FAVOR OF `--hf-token`] Your Hugging Face Hub token
     #[clap(long, env, hide = true)]
     #[redact(partial)]
@@ -250,6 +255,7 @@ async fn main() -> Result<()> {
         args.otlp_service_name,
         args.prometheus_port,
         args.cors_allow_origin,
+        args.device_id,
     )
     .await?;
 
diff --git a/router/tests/common.rs b/router/tests/common.rs
index 476211764..40a9ab182 100644
--- a/router/tests/common.rs
+++ b/router/tests/common.rs
@@ -70,6 +70,7 @@ pub async fn start_server(model_id: String, revision: Option<String>, dtype: DTy
         "text-embeddings-inference.server".to_owned(),
         9000,
         None,
+        0, // device_id
     )
 });
 
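With the new flag, each server instance can be pinned to a specific GPU at launch. A hypothetical invocation for illustration (the model ID and ports are placeholders; the `--device-id` flag and `DEVICE_ID` environment variable follow from the clap attributes above):

```shell
# Pin this instance to the second GPU (CUDA ordinal 1).
text-embeddings-router --model-id BAAI/bge-small-en-v1.5 --port 8080 --device-id 1

# Equivalent via the environment, e.g. to run one instance per GPU.
DEVICE_ID=0 text-embeddings-router --model-id BAAI/bge-small-en-v1.5 --port 8080 &
DEVICE_ID=1 text-embeddings-router --model-id BAAI/bge-small-en-v1.5 --port 8081 &
```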