Skip to content

add Jetson Orin support#467

Closed
thomas-hiddenpeak wants to merge 3 commits into
huggingface:mainfrom
thomas-hiddenpeak:Jetson-support
Closed

add Jetson Orin support#467
thomas-hiddenpeak wants to merge 3 commits into
huggingface:mainfrom
thomas-hiddenpeak:Jetson-support

Conversation

@thomas-hiddenpeak

@thomas-hiddenpeak thomas-hiddenpeak commented Jan 4, 2025

Copy link
Copy Markdown

Motivation and Context

NVIDIA Jetson Orin devices have a compute capability of 8.7, which is not currently supported in the compute_cap_matching function. This PR ensures that these devices can be used with the library by adding the necessary support.

What does this PR do?

This PR adds support for NVIDIA Jetson Orin devices by including the compute capability 8.7 in the compute_cap_matching function and updating the tests to ensure the new capability is correctly supported.

Fixes #466

Checklist

  • I have read the contributor guidelines.
  • I have added tests to verify my changes.
  • I have tagged the appropriate reviewers.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@OlivierDehaene OR @Narsil

@r0kk

r0kk commented Jan 16, 2025

Copy link
Copy Markdown

@HiddenPeak I am wondering if you could share reproducible steps how were you able to run text-embeddings-inference on Jetson AGX Orin. It would be greatly appreciated🙏.

Unfortunately I don't have deep enough knowledge to review you PR.

@thomas-hiddenpeak

thomas-hiddenpeak commented Jan 20, 2025

Copy link
Copy Markdown
Author

@r0kk
The Jetson Orin series uses the CUDA architecture SM8.7, which is part of the Ampere architecture. Theoretically, it should be compatible with TEI. However, in practical applications, there are many incompatibilities, making direct support generally impossible. During the process of attempting to use it, I encountered the following issues:

  1. The compute_cap_matching() function does not support the SM87 architecture, so I modified the source code and recompiled it.
  2. It is necessary to ensure that the GPU driver, CUDA runtime, and CUDA compiler are correctly installed and available in the environment variable paths.(on Jetpack 6.1 with cuda 12.6)
  3. The compilation process is extremely long, and memory usage exceeds 90% (60GB).

I attempted to compile and deploy TEI on a Jetson AGX Orin 64G and found that it could not recognize SM87. Therefore, I modified the compute_cap_matching() function in backends/candle/src/compute_cap.rs to add support for the SM87 environment and architecture. Such modifications may not be effective in many cases, but fortunately, after making these changes, I was able to achieve support on the Jetson AGX Orin 64G. Not only did it not produce any errors, but it also showed excellent performance.

curl 127.0.0.1:8080/rerank \
    -X POST \
    -d '{"query": "What is Deep Learning?", "texts": ["Deep Learning is not...", "Deep learning is..."]}' \
    -H 'Content-Type: application/json'

logs

2025-01-04T20:38:18.706787Z  INFO text_embeddings_backend_candle: backends/candle/src/lib.rs:292: Starting FlashBert model on Cuda(CudaDevice(DeviceId(1)))
2025-01-04T20:38:31.539445Z  INFO text_embeddings_router: router/src/lib.rs:248: Warming up model
2025-01-04T20:38:32.189069Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1812: Starting HTTP server: 0.0.0.0:8080
2025-01-04T20:38:32.189098Z  INFO text_embeddings_router::http::server: router/src/http/server.rs:1813: Ready
2025-01-04T20:44:11.047170Z  INFO rerank{total_time="177.15121ms" tokenization_time="727.783µs" queue_time="79.024583ms" inference_time="87.618256ms"}: text_embeddings_router::http::server: router/src/http/server.rs:459: Success

More screenshots are as follows:
74f61043ee2b7a9da75e93547f397b7
6519f72d3f88800fd8e89705b2d5fe1
fc2f46532bbe2545f4f7f2c55388190

Therefore, I created a branch and added test code. After testing it in my application, I submitted a merge request.
Additionally, I also tried other embedding and rerank models, which ran well.

@r0kk

r0kk commented Feb 6, 2025

Copy link
Copy Markdown

@HiddenPeak
I can confirm that this is working on Jetson AGX 64GB. Thank you very much 🙏.

@thomas-hiddenpeak

Copy link
Copy Markdown
Author

@HiddenPeak I can confirm that this is working on Jetson AGX 64GB. Thank you very much 🙏.

It's very cool~

@taresh18-ag

taresh18-ag commented Jul 8, 2025

Copy link
Copy Markdown

Hi, great work.

How did you get it running on jetson orin? When I try to compile it, it throws this error

image

these are the steps I followed:

curl https://sh.rustup.rs/ -sSf | sh
sudo apt-get install libssl-dev gcc -y
git clone https://github.com/huggingface/text-embeddings-inference.git
cd text-embeddings-inference
cargo install --path router -F candle-cuda -F http --no-default-features # getting error here

Also if cuda inference is not possible, I would like to test using cpu only. What are the steps to run this library on an arm cpu? I looked into dockerfiles, they are all dependent on intel mkl libs

@thomas-hiddenpeak

Copy link
Copy Markdown
Author

add -F dynamic-linking

@r0kk

r0kk commented Jul 9, 2025

Copy link
Copy Markdown

Following process worked for me:

Add NVCC to the path

NVIDIA's NVCC (NVIDIA CUDA Compiler) is a compiler driver used to compile CUDA (Compute Unified Device Architecture) code, which allows developers to write programs that run on NVIDIA GPUs. It translates CUDA code into executable binaries for GPU acceleration.

  1. Check if nvcc exists
ls /usr/local/cuda/bin/nvcc
  1. Update .env variables

    • Open .bashrc

      nano ~/.bashrc
    • add nvcc paths

      export PATH=/usr/local/cuda/bin:$PATH
      export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
    • restart terminal

      source ~/.bashrc
    • check nvcc version

      nvcc --version

Build Process (you can skip if build exists)

We prepare the build and can be found in current repository. If it doesn't exist, you can follow instructions below:

  1. Install Rust

    curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
    source $HOME/.cargo/env
  2. Clone the Project
    It is the branch of the original project, because at the time of the writing official release for Jetson family didn't exist.

    git clone https://github.com/HiddenPeak/text-embeddings-inference.git
    cd text-embeddings-inference
  3. Install ssl (sometimes openssl problem might appear when building)

    sudo apt install libssl-dev
  4. Build

  • move inside router dir inside the project
  • to use less space on Jetson, set --target-dir to external disc
    cd router/
    cargo build --release --features=candle-cuda --target-dir <target dir for generated artifact>

This commit adapts text-embeddings-inference for NVIDIA Jetson Orin (SM87)
and L4 GPU (SM89), and integrates valuable community PRs.

Changes:

1. SM87/SM89 CUDA Support
   - Added compute capability 8.7 and 8.9 support
   - Modified Dockerfile-cuda-all for multi-arch builds
   - Updated compute_cap.rs for SM87/89 detection
   Files: Dockerfile-cuda-all, cuda-all-entrypoint.sh, compute_cap.rs

2. PR huggingface#730: Qwen3 Reranker Support
   - Added classification head for Qwen3 reranking
   - Implemented template formatting system for chat-based reranking
   Files: models/qwen3.rs, core/templates.rs, core/lib.rs

3. PR huggingface#787: Batch Notification Performance Optimization
   - Implemented AtomicUsize counter for batch processing
   - Reduced unnecessary notify_one() calls
   - Only last request in batch triggers thread notification
   Files: core/infer.rs, router/http/server.rs, router/grpc/server.rs

4. PR huggingface#753: GeLU Activation Consistency Fix
   - Changed Gelu from approximate (gelu) to exact (gelu_erf)
   - Added NewGelu variant for backward compatibility
   Files: layers/linear.rs

5. PR huggingface#790: StaticEmbedding Model Support
   - Added support for 0_StaticEmbedding/ directory structure
   - Implemented fallback loading for model weights and tokenizer
   - Default to Mean pooling for StaticEmbedding models
   Files: models/static_embedding.rs (new), lib.rs, download.rs, router/lib.rs

6. PR huggingface#746: DebertaV2 Sequence Classification Support
   - Complete DebertaV2 model implementation
   - Support for sequence classification tasks (e.g., Llama Prompt Guard)
   - CPU and CUDA device support
   Files: models/debertav2.rs (new), lib.rs, models/mod.rs

All changes have been tested and compile successfully with:
  cargo check --all-targets

Compilation verified with CUDA support:
  cargo install --path router -F candle-cuda

Target Hardware: NVIDIA Jetson Orin AGX (SM87), L4 GPU (SM89)
Date: January 5, 2026
@didlawowo

Copy link
Copy Markdown

any new ?
is this working on nano super ?

@alvarobartt

Copy link
Copy Markdown
Member

Hey @thomas-hiddenpeak thanks for this PR, indeed it'd be cool to support CUDA compute capability 8.7 as we don't support any of the Jetson devices yet, given that this PR is now a mix of different PRs I guess to make this PR usable for your use case, would you be willing to create a new PR with just the required lines to support such compute capability + changes on Dockerfile-cuda too? Thanks in advance 🤗

P.S. If you're short on time or have other priorities, please let me know and I'll include this compute cap and add you as co-author due to earlier work, thanks!

@thomas-hiddenpeak

Copy link
Copy Markdown
Author

Thanks for the clarification. You’re right, the current PR ended up mixing several unrelated changes. I’ll prepare a new focused PR with only the CUDA compute capability 8.7 support needed for Jetson devices, plus the required CUDA Docker changes.

@thomas-hiddenpeak

Copy link
Copy Markdown
Author

Hi @alvarobartt, I've created the focused PR as requested: #844 (SM87 + SM110 compute capability support only). Looking forward to your review, thanks!

Copilot AI review requested due to automatic review settings June 4, 2026 10:32

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR expands model/runtime support across the router/core/backends to add StaticEmbedding + Qwen3 reranker templating, introduce batch-notify optimizations, and extend CUDA compute-cap targeting (SM87/SM89).

Changes:

  • Add Qwen3 reranker template formatting with new HTTP/Proto request fields (instruction, use_template).
  • Add StaticEmbedding model + artifact fallbacks, and update default pooling behavior.
  • Optimize batching notifications via an AtomicUsize counter and add SM87/SM89 build/entrypoint routing.

Reviewed changes

Copilot reviewed 23 out of 25 changed files in this pull request and generated 9 comments.

Show a summary per file
File Description
router/src/lib.rs Adjust pooling fallback logic for BERT/static-embedding types
router/src/http/types.rs Add rerank instruction and use_template request fields
router/src/http/server.rs Thread batch counters into infer calls; add Qwen3 rerank templating path
router/src/grpc/server.rs Extend infer API calls with new batch_counter parameter
proto/tei.proto Add optional instruction and use_template fields to rerank requests
cuda-all-entrypoint.sh Route to SM87/SM89-specific router binaries
core/src/templates.rs Introduce template formatting abstraction + Qwen3 reranker formatter
core/src/lib.rs Export new templates module
core/src/infer.rs Add batch-notify optimization using AtomicUsize gating notify_one()
core/src/download.rs Add tokenizer fallback to 0_StaticEmbedding/tokenizer.json
backends/src/lib.rs Add StaticEmbedding safetensors fallback download
backends/candle/tests/test_qwen3.rs Add snapshot test for Qwen3 reranker classifier output
backends/candle/tests/snapshots/test_qwen3__qwen3_reranker_single.snap Store reranker snapshot output
backends/candle/src/models/static_embedding.rs Add Candle StaticEmbedding model implementation
backends/candle/src/models/qwen3.rs Add Qwen3 classification head + predict() implementation
backends/candle/src/models/mod.rs Register StaticEmbedding model module/exports
backends/candle/src/models/flash_qwen3.rs Add classification head support for FlashQwen3 + pooling index fix
backends/candle/src/models/debertav2.rs Refactor/extend DebertaV2 modeling and activation deserialization
backends/candle/src/lib.rs Wire StaticEmbedding + DebertaV2 config/device handling and weight fallbacks
backends/candle/src/layers/linear.rs Change GeLU behavior + introduce NewGelu variant
backends/candle/src/compute_cap.rs Adjust runtime/compile compute-cap matching; add SM87/SM89/SM110 cases
README_MODIFICATIONS.md Document integrated changes and SM87/SM89 adaptation notes
Dockerfile-cuda-all Build additional CUDA variants (87/89) and copy binaries
build_cuda.log Add build log artifact (likely unintended)
build.log Add build log artifact (likely unintended)
Comments suppressed due to low confidence (2)

backends/candle/src/models/debertav2.rs:1

  • This block appears to contain a partial refactor/merge artifact (bail!( immediately followed by candle::bail!() that will not compile. Please remove the duplicated/leftover code and keep a single error path (either bail! or candle::bail!) with correct parentheses and braces.
use std::collections::HashMap;

backends/candle/src/layers/linear.rs:1

  • Both Gelu and NewGelu declare the same serde aliases, which makes deserialization ambiguous (and may fail or pick an unintended variant). Make the serialized names/aliases unique (e.g., map \"gelu\" to the exact behavior and \"gelu_new\"/\"gelu_pytorch_tanh\" to NewGelu), and update the doc comment because Gelu now uses gelu_erf() (the comment currently claims the opposite).
use crate::layers::cublaslt::get_cublas_lt_wrapper;

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 130 to +134
Bert(BertConfigWrapper),
DebertaV2(DebertaV2Config),
Camembert(BertConfig),
#[serde(rename(deserialize = "deberta-v2"))]
DebertaV2(DebertaV2Config),
Comment on lines 130 to 144
pub fn load(vb: VarBuilder, config: &DebertaV2Config) -> Result<Self> {
let config = config.clone();

let embedding_size = config.embedding_size.unwrap_or(config.hidden_size);

let word_embeddings =
embedding(config.vocab_size, embedding_size, vb.pp("word_embeddings"))?;

let position_embeddings = if config.position_biased_input {
Some(embedding(
config.max_position_embeddings,
embedding_size,
vb.pp("position_embeddings"),
)?)
let embedding_size = config.embedding_size.unwrap_or(config.hidden_size);
Comment on lines 1034 to +1041
pub fn forward(
&self,
_hidden_states: &Tensor,
_residual_states: &Tensor,
_input_mask: &Tensor,
) -> Result<Tensor> {
todo!("Need a model that contains a conv layer to test against.")
}
Comment thread router/src/http/server.rs
Comment on lines +123 to +124
permit: Option<OwnedSemaphorePermit>,
_batch_counter: Option<Arc<AtomicUsize>>| async move {
Comment thread router/src/http/server.rs
Comment on lines 131 to 138
.predict(
inputs,
truncate,
req.truncation_direction.into(),
req.raw_scores,
permit,
None,
)
Comment thread router/src/http/server.rs
Comment on lines +223 to +227
let batch_counter = if batch_size == 1 {
None
} else {
Some(Arc::new(AtomicUsize::new(batch_size)))
};
Comment thread router/src/http/server.rs
Comment on lines 236 to 243
futures.push(predict_inner(
input,
truncate,
local_infer.0,
local_info.0,
None,
local_batch_counter,
))
Comment thread Dockerfile-cuda-all
Comment on lines 71 to 74
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
CUDA_COMPUTE_CAP=89 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s;
CUDA_COMPUTE_CAP=100 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s;
Comment on lines +809 to +819
fn predict(&self, batch: Batch) -> Result<Tensor> {
match &self.classification_head {
None => candle::bail!("`predict` is not implemented for this model"),
Some(classification_head) => {
let (pooled_embeddings, _raw_embeddings) = self.forward(batch)?;
let pooled_embeddings =
pooled_embeddings.expect("pooled_embeddings is empty. This is a bug.");
classification_head.forward(&pooled_embeddings)
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Could not start backend on Jetson AGX Orin

6 participants