add Jetson Orin support #467
|
@HiddenPeak I am wondering if you could share reproducible steps for how you were able to run text-embeddings-inference on Jetson AGX Orin. It would be greatly appreciated 🙏. Unfortunately, I don't have deep enough knowledge to review your PR. |
|
@r0kk
I attempted to compile and deploy TEI on a Jetson AGX Orin 64G and found that it could not recognize SM87. Therefore, I modified the compute_cap_matching() function in backends/candle/src/compute_cap.rs to add support for the SM87 environment and architecture. Such modifications may not be effective in many cases, but fortunately, after making these changes, TEI ran on the Jetson AGX Orin 64G. It produced no errors and performed well: the rerank request logged below completed in about 177 ms in total.

curl 127.0.0.1:8080/rerank \
-X POST \
-d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
-H 'Content-Type: application/json'

Logs:

2025-01-04T20:38:18.706787Z INFO text_embeddings_backend_candle: backends/candle/src/lib.rs:292: Starting FlashBert model on Cuda(CudaDevice(DeviceId(1)))
2025-01-04T20:38:31.539445Z INFO text_embeddings_router: router/src/lib.rs:248: Warming up model
2025-01-04T20:38:32.189069Z INFO text_embeddings_router::http::server: router/src/http/server.rs:1812: Starting HTTP server: 0.0.0.0:8080
2025-01-04T20:38:32.189098Z INFO text_embeddings_router::http::server: router/src/http/server.rs:1813: Ready
2025-01-04T20:44:11.047170Z INFO rerank{total_time="177.15121ms" tokenization_time="727.783µs" queue_time="79.024583ms" inference_time="87.618256ms"}: text_embeddings_router::http::server: router/src/http/server.rs:459: Success

More screenshots are as follows:

Therefore, I created a branch and added test code. After testing it in my application, I submitted a merge request. |
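The compute-capability check described in this comment can be sketched roughly as follows. This is a minimal illustration, not the contents of backends/candle/src/compute_cap.rs: the signature, the integer encoding (87 for SM87), and the rule that an 8.x device accepts a binary built for an equal or lower 8.x target are all assumptions made for the example.

```rust
/// Hypothetical sketch of a runtime/compile-time compute capability check.
/// `runtime_cap` is what the device reports, `compile_cap` is what the
/// binary was built for; both are encoded as major * 10 + minor.
fn compute_cap_matching(runtime_cap: usize, compile_cap: usize) -> bool {
    match (runtime_cap, compile_cap) {
        // An exact match is always accepted (this covers 87 on 87).
        (r, c) if r == c => true,
        // Assumed rule: within the 8.x family, a device may run a binary
        // built for a lower minor version (e.g. SM87 device, SM80 build).
        (80..=89, 80..=89) => runtime_cap >= compile_cap,
        // Everything else is rejected.
        _ => false,
    }
}
```

With a rule of this shape, a Jetson AGX Orin reporting 87 would be accepted by a binary built for 87 or 80, while a binary built for 87 would still be rejected on an 80 device.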
|
@HiddenPeak |
It's very cool~ |
|
Hi, great work. How did you get it running on Jetson Orin? When I try to compile it, it throws this error. These are the steps I followed:
curl https://sh.rustup.rs/ -sSf | sh
Also, if CUDA inference is not possible, I would like to test using the CPU only. What are the steps to run this library on an ARM CPU? I looked into the Dockerfiles, and they all depend on Intel MKL libs. |
|
add -F dynamic-linking |
|
Following process worked for me: Add
|
This commit adapts text-embeddings-inference for NVIDIA Jetson Orin (SM87) and L4 GPU (SM89), and integrates valuable community PRs.

Changes:

1. SM87/SM89 CUDA Support
   - Added compute capability 8.7 and 8.9 support
   - Modified Dockerfile-cuda-all for multi-arch builds
   - Updated compute_cap.rs for SM87/89 detection
   Files: Dockerfile-cuda-all, cuda-all-entrypoint.sh, compute_cap.rs

2. PR huggingface#730: Qwen3 Reranker Support
   - Added classification head for Qwen3 reranking
   - Implemented template formatting system for chat-based reranking
   Files: models/qwen3.rs, core/templates.rs, core/lib.rs

3. PR huggingface#787: Batch Notification Performance Optimization
   - Implemented AtomicUsize counter for batch processing
   - Reduced unnecessary notify_one() calls
   - Only last request in batch triggers thread notification
   Files: core/infer.rs, router/http/server.rs, router/grpc/server.rs

4. PR huggingface#753: GeLU Activation Consistency Fix
   - Changed Gelu from approximate (gelu) to exact (gelu_erf)
   - Added NewGelu variant for backward compatibility
   Files: layers/linear.rs

5. PR huggingface#790: StaticEmbedding Model Support
   - Added support for 0_StaticEmbedding/ directory structure
   - Implemented fallback loading for model weights and tokenizer
   - Default to Mean pooling for StaticEmbedding models
   Files: models/static_embedding.rs (new), lib.rs, download.rs, router/lib.rs

6. PR huggingface#746: DebertaV2 Sequence Classification Support
   - Complete DebertaV2 model implementation
   - Support for sequence classification tasks (e.g., Llama Prompt Guard)
   - CPU and CUDA device support
   Files: models/debertav2.rs (new), lib.rs, models/mod.rs

All changes have been tested and compile successfully with:
cargo check --all-targets

Compilation verified with CUDA support:
cargo install --path router -F candle-cuda

Target Hardware: NVIDIA Jetson Orin AGX (SM87), L4 GPU (SM89)
Date: January 5, 2026
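To make the GeLU item concrete, here is a small standalone sketch of the two formulas usually meant by "exact" and "approximate" GeLU. It uses plain f64 helpers rather than candle tensors, and the names gelu_erf and gelu_tanh are chosen for this illustration only. Since Rust's standard library has no erf, the example assumes the Abramowitz-Stegun 7.1.26 polynomial approximation is accurate enough for the comparison.

```rust
use std::f64::consts::PI;

/// erf via the Abramowitz-Stegun 7.1.26 approximation
/// (absolute error on the order of 1e-7).
fn erf(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of the degree-5 polynomial in t.
    let poly = t
        * (0.254829592
            + t * (-0.284496736
                + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

/// Exact GeLU: 0.5 * x * (1 + erf(x / sqrt(2))).
fn gelu_erf(x: f64) -> f64 {
    0.5 * x * (1.0 + erf(x / 2.0_f64.sqrt()))
}

/// Tanh approximation:
/// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
fn gelu_tanh(x: f64) -> f64 {
    let inner = (2.0 / PI).sqrt() * (x + 0.044715 * x.powi(3));
    0.5 * x * (1.0 + inner.tanh())
}
```

At x = 1.0 the two variants differ by roughly 1.5e-4, which is small per activation but can be enough to shift embeddings or snapshot tests when a model was trained with one variant and served with the other.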
119e0e5 to 12be69d
|
Any news? |
|
Hey @thomas-hiddenpeak, thanks for this PR. Indeed, it'd be cool to support CUDA compute capability 8.7, as we don't support any of the Jetson devices yet. Given that this PR is now a mix of different PRs (I guess to make it usable for your use case), would you be willing to create a new PR with just the lines required to support that compute capability + changes on
P.S. If you're short on time or have other priorities, please let me know and I'll include this compute cap and add you as co-author due to earlier work, thanks! |
|
Thanks for the clarification. You’re right, the current PR ended up mixing several unrelated changes. I’ll prepare a new focused PR with only the CUDA compute capability 8.7 support needed for Jetson devices, plus the required CUDA Docker changes. |
|
Hi @alvarobartt, I've created the focused PR as requested: #844 (SM87 + SM110 compute capability support only). Looking forward to your review, thanks! |
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR expands model/runtime support across the router/core/backends to add StaticEmbedding + Qwen3 reranker templating, introduce batch-notify optimizations, and extend CUDA compute-cap targeting (SM87/SM89).
Changes:
- Add Qwen3 reranker template formatting with new HTTP/Proto request fields (`instruction`, `use_template`).
- Add StaticEmbedding model + artifact fallbacks, and update default pooling behavior.
- Optimize batching notifications via an `AtomicUsize` counter and add SM87/SM89 build/entrypoint routing.
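The batch-notify idea in the last bullet can be illustrated with the standard library alone. This is a sketch under assumptions, not the code in core/src/infer.rs: the helper names is_last_in_batch and count_notifications are hypothetical, and a real queue would call its notifier where this example only counts. It follows the convention visible in the diff, where a batch of size 1 carries no counter.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

/// Returns true only for the request that should wake the batching thread.
/// A request without a counter (batch of one) always notifies.
fn is_last_in_batch(counter: &Option<Arc<AtomicUsize>>) -> bool {
    match counter {
        None => true,
        // fetch_sub returns the previous value, so seeing 1 means this
        // call just decremented the counter to zero.
        Some(c) => c.fetch_sub(1, Ordering::AcqRel) == 1,
    }
}

/// Simulates enqueueing `batch_size` requests and counts how many of them
/// would trigger a notification.
fn count_notifications(batch_size: usize) -> usize {
    let counter = if batch_size == 1 {
        None
    } else {
        Some(Arc::new(AtomicUsize::new(batch_size)))
    };
    (0..batch_size)
        .filter(|_| is_last_in_batch(&counter))
        .count()
}
```

Under this scheme a batch of 8 requests produces one wake-up instead of eight, and because fetch_sub is atomic, the same holds when the requests are enqueued from different tasks.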
Reviewed changes
Copilot reviewed 23 out of 25 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| router/src/lib.rs | Adjust pooling fallback logic for BERT/static-embedding types |
| router/src/http/types.rs | Add rerank instruction and use_template request fields |
| router/src/http/server.rs | Thread batch counters into infer calls; add Qwen3 rerank templating path |
| router/src/grpc/server.rs | Extend infer API calls with new batch_counter parameter |
| proto/tei.proto | Add optional instruction and use_template fields to rerank requests |
| cuda-all-entrypoint.sh | Route to SM87/SM89-specific router binaries |
| core/src/templates.rs | Introduce template formatting abstraction + Qwen3 reranker formatter |
| core/src/lib.rs | Export new templates module |
| core/src/infer.rs | Add batch-notify optimization using AtomicUsize gating notify_one() |
| core/src/download.rs | Add tokenizer fallback to 0_StaticEmbedding/tokenizer.json |
| backends/src/lib.rs | Add StaticEmbedding safetensors fallback download |
| backends/candle/tests/test_qwen3.rs | Add snapshot test for Qwen3 reranker classifier output |
| backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap | Store reranker snapshot output |
| backends/candle/src/models/static_embedding.rs | Add Candle StaticEmbedding model implementation |
| backends/candle/src/models/qwen3.rs | Add Qwen3 classification head + predict() implementation |
| backends/candle/src/models/mod.rs | Register StaticEmbedding model module/exports |
| backends/candle/src/models/flash_qwen3.rs | Add classification head support for FlashQwen3 + pooling index fix |
| backends/candle/src/models/debertav2.rs | Refactor/extend DebertaV2 modeling and activation deserialization |
| backends/candle/src/lib.rs | Wire StaticEmbedding + DebertaV2 config/device handling and weight fallbacks |
| backends/candle/src/layers/linear.rs | Change GeLU behavior + introduce NewGelu variant |
| backends/candle/src/compute_cap.rs | Adjust runtime/compile compute-cap matching; add SM87/SM89/SM110 cases |
| README_MODIFICATIONS.md | Document integrated changes and SM87/SM89 adaptation notes |
| Dockerfile-cuda-all | Build additional CUDA variants (87/89) and copy binaries |
| build_cuda.log | Add build log artifact (likely unintended) |
| build.log | Add build log artifact (likely unintended) |
Comments suppressed due to low confidence (2)
backends/candle/src/models/debertav2.rs:1
- This block appears to contain a partial refactor/merge artifact (`bail!(` immediately followed by `candle::bail!(`) that will not compile. Please remove the duplicated/leftover code and keep a single error path (either `bail!` or `candle::bail!`) with correct parentheses and braces.
use std::collections::HashMap;
backends/candle/src/layers/linear.rs:1
- Both `Gelu` and `NewGelu` declare the same serde aliases, which makes deserialization ambiguous (and may fail or pick an unintended variant). Make the serialized names/aliases unique (e.g., map `"gelu"` to the exact behavior and `"gelu_new"`/`"gelu_pytorch_tanh"` to `NewGelu`), and update the doc comment because `Gelu` now uses `gelu_erf()` (the comment currently claims the opposite).
use crate::layers::cublaslt::get_cublas_lt_wrapper;
| Bert(BertConfigWrapper), | ||
| DebertaV2(DebertaV2Config), | ||
| Camembert(BertConfig), | ||
| #[serde(rename(deserialize = "deberta-v2"))] | ||
| DebertaV2(DebertaV2Config), |
| pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result<Self> { | ||
| let config = config.clone(); | ||
|
|
||
| let embedding_size = config.embedding_size.unwrap_or(config.hidden_size); | ||
|
|
||
| let word_embeddings = | ||
| embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?; | ||
|
|
||
| let position_embeddings = if config.position_biased_input { | ||
| Some(embedding( | ||
| config.max_position_embeddings, | ||
| embedding_size, | ||
| vb.pp("position_embeddings"), | ||
| )?) | ||
| let embedding_size = config.embedding_size.unwrap_or(config.hidden_size); |
| pub fn forward( | ||
| &self, | ||
| _hidden_states: &Tensor, | ||
| _residual_states: &Tensor, | ||
| _input_mask: &Tensor, | ||
| ) -> Result<Tensor> { | ||
| todo!("Need a model that contains a conv layer to test against.") | ||
| } |
| permit: Option<OwnedSemaphorePermit>, | ||
| _batch_counter: Option<Arc<AtomicUsize>>| async move { |
| .predict( | ||
| inputs, | ||
| truncate, | ||
| req.truncation_direction.into(), | ||
| req.raw_scores, | ||
| permit, | ||
| None, | ||
| ) |
| let batch_counter = if batch_size == 1 { | ||
| None | ||
| } else { | ||
| Some(Arc::new(AtomicUsize::new(batch_size))) | ||
| }; |
| futures.push(predict_inner( | ||
| input, | ||
| truncate, | ||
| local_infer.0, | ||
| local_info.0, | ||
| None, | ||
| local_batch_counter, | ||
| )) |
| RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \ | ||
| --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \ | ||
| CUDA_COMPUTE_CAP=89 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; | ||
| CUDA_COMPUTE_CAP=100 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s; |
| fn predict(&self, batch: Batch) -> Result<Tensor> { | ||
| match &self.classification_head { | ||
| None => candle::bail!("`predict` is not implemented for this model"), | ||
| Some(classification_head) => { | ||
| let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?; | ||
| let pooled_embeddings = | ||
| pooled_embeddings.expect("pooled_embeddings is empty. This is a bug."); | ||
| classification_head.forward(&pooled_embeddings) | ||
| } | ||
| } | ||
| } |




Motivation and Context
NVIDIA Jetson Orin devices have a compute capability of 8.7, which is not currently supported in the compute_cap_matching function. This PR ensures that these devices can be used with the library by adding the necessary support.
What does this PR do?
This PR adds support for NVIDIA Jetson Orin devices by including the compute capability 8.7 in the compute_cap_matching function and updating the tests to ensure the new capability is correctly supported.
Fixes #466
Checklist
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@OlivierDehaene OR @Narsil