Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backends/candle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ impl CandleBackend {
dtype: String,
model_type: ModelType,
dense_paths: Option<Vec<String>>,
device_id: usize,
) -> Result<Self, BackendError> {
// Default files
let default_safetensors = model_path.join("model.safetensors");
Expand Down Expand Up @@ -188,7 +189,7 @@ impl CandleBackend {
let device = if candle::utils::cuda_is_available() {
#[cfg(feature = "cuda")]
match compatible_compute_cap() {
Ok(true) => Device::new_cuda(0),
Ok(true) => Device::new_cuda(device_id),
Comment thread
michaelfeil marked this conversation as resolved.
Ok(false) => {
return Err(BackendError::Start(format!(
"Runtime compute cap {} is not compatible with compile time compute cap {}",
Expand Down
4 changes: 4 additions & 0 deletions backends/candle/tests/test_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn test_bert() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -78,6 +79,7 @@ fn test_bert_pooled_raw() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -149,6 +151,7 @@ fn test_emotions() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -205,6 +208,7 @@ fn test_bert_classification() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_dense.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn test_stella_en_400m_v5_default_dense() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
dense_paths, // This will default to `2_Dense_1024/` as defined in `modules.json`
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -74,6 +75,7 @@ fn test_stella_en_400m_v5_dense_768() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
dense_paths,
0,
)?;

let input_batch = batch(
Expand Down
4 changes: 4 additions & 0 deletions backends/candle/tests/test_flash_bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ fn test_flash_mini() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -88,6 +89,7 @@ fn test_flash_mini_pooled_raw() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -163,6 +165,7 @@ fn test_flash_emotions() -> Result<()> {
"float16".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -223,6 +226,7 @@ fn test_flash_bert_classification() -> Result<()> {
"float16".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_flash_gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_gte() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -69,6 +70,7 @@ fn test_flash_gte_classification() -> Result<()> {
"float16".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_flash_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_jina_small() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_flash_jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_jina_code_base() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_flash_mistral.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_mistral() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_flash_nomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_nomic_small() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -65,6 +66,7 @@ fn test_flash_nomic_moe() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_flash_qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ fn test_flash_qwen2() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::LastToken),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_flash_qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_flash_qwen3() -> Result<()> {
"float16".to_string(),
ModelType::Embedding(Pool::LastToken),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_gemma3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn test_gemma3() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
dense_paths,
0,
)?;

let input_batch = batch(
Expand Down
4 changes: 4 additions & 0 deletions backends/candle/tests/test_gte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn test_alibaba_gte() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -62,6 +63,7 @@ fn test_alibaba_gte_new() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -108,6 +110,7 @@ fn test_snowflake_gte() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -154,6 +157,7 @@ fn test_gte_classification() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_jina.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn test_jina_small() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -62,6 +63,7 @@ fn test_jina_rerank() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_jina_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn test_jina_code_base() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
4 changes: 4 additions & 0 deletions backends/candle/tests/test_modernbert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn test_modernbert() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -88,6 +89,7 @@ fn test_modernbert_pooled_raw() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -184,6 +186,7 @@ fn test_modernbert_classification() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down Expand Up @@ -223,6 +226,7 @@ fn test_modernbert_classification_mean_pooling() -> Result<()> {
"float32".to_string(),
ModelType::Classifier,
None,
0,
)?;

let input_single = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_mpnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ fn test_mpnet() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -80,6 +81,7 @@ fn test_mpnet_pooled_raw() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Cls),
None,
0,
)?;

let input_batch = batch(
Expand Down
2 changes: 2 additions & 0 deletions backends/candle/tests/test_nomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ fn test_nomic_small() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down Expand Up @@ -60,6 +61,7 @@ fn test_nomic_moe() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::Mean),
None,
0,
)?;

let input_batch = batch(
Expand Down
1 change: 1 addition & 0 deletions backends/candle/tests/test_qwen3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ fn test_qwen3() -> Result<()> {
"float32".to_string(),
ModelType::Embedding(Pool::LastToken),
None,
0,
)?;

let input_batch = batch(
Expand Down
4 changes: 4 additions & 0 deletions backends/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ impl Backend {
uds_path: String,
otlp_endpoint: Option<String>,
otlp_service_name: String,
device_id: usize,
) -> Result<Self, BackendError> {
let (backend_sender, backend_receiver) = mpsc::channel(8);

Expand All @@ -102,6 +103,7 @@ impl Backend {
uds_path,
otlp_endpoint,
otlp_service_name,
device_id,
)
.await?;
let padded_model = backend.is_padded();
Expand Down Expand Up @@ -362,6 +364,7 @@ async fn init_backend(
uds_path: String,
otlp_endpoint: Option<String>,
otlp_service_name: String,
device_id: usize,
) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
let mut backend_start_failed = false;
let api_repo = api_repo.map(Arc::new);
Expand Down Expand Up @@ -475,6 +478,7 @@ async fn init_backend(
dtype.to_string(),
model_type.clone(),
dense_paths,
device_id,
);
match backend {
Ok(b) => return Ok(Box::new(b)),
Expand Down
2 changes: 2 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub async fn run(
otlp_service_name: String,
prometheus_port: u16,
cors_allow_origin: Option<Vec<String>>,
device_id: usize,
) -> Result<()> {
let model_id_path = Path::new(&model_id);
let (model_root, api_repo) = if model_id_path.exists() && model_id_path.is_dir() {
Expand Down Expand Up @@ -280,6 +281,7 @@ pub async fn run(
uds_path.unwrap_or("/tmp/text-embeddings-inference-server".to_string()),
otlp_endpoint.clone(),
otlp_service_name.clone(),
device_id,
)
.await
.context("Could not create backend")?;
Expand Down
6 changes: 6 additions & 0 deletions router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ struct Args {
#[clap(long, env)]
dense_path: Option<String>,

/// The CUDA device ID where the model will be loaded. Defaults to 0 i.e., the first available device.
/// Only used with the candle backend.
#[clap(long, env, default_value = "0")]
device_id: usize,

/// [DEPRECATED IN FAVOR OF `--hf-token`] Your Hugging Face Hub token
#[clap(long, env, hide = true)]
#[redact(partial)]
Expand Down Expand Up @@ -250,6 +255,7 @@ async fn main() -> Result<()> {
args.otlp_service_name,
args.prometheus_port,
args.cors_allow_origin,
args.device_id,
)
.await?;

Expand Down
1 change: 1 addition & 0 deletions router/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub async fn start_server(model_id: String, revision: Option<String>, dtype: DTy
"text-embeddings-inference.server".to_owned(),
9000,
None,
0, // device_id
)
});

Expand Down