Skip to content

Commit 12be69d

Browse files
feat: SM87/SM89 adaptation and community PRs integration
This commit adapts text-embeddings-inference for NVIDIA Jetson Orin (SM87) and L4 GPU (SM89), and integrates valuable community PRs. Changes: 1. SM87/SM89 CUDA Support - Added compute capability 8.7 and 8.9 support - Modified Dockerfile-cuda-all for multi-arch builds - Updated compute_cap.rs for SM87/89 detection Files: Dockerfile-cuda-all, cuda-all-entrypoint.sh, compute_cap.rs 2. PR #730: Qwen3 Reranker Support - Added classification head for Qwen3 reranking - Implemented template formatting system for chat-based reranking Files: models/qwen3.rs, core/templates.rs, core/lib.rs 3. PR #787: Batch Notification Performance Optimization - Implemented AtomicUsize counter for batch processing - Reduced unnecessary notify_one() calls - Only last request in batch triggers thread notification Files: core/infer.rs, router/http/server.rs, router/grpc/server.rs 4. PR #753: GeLU Activation Consistency Fix - Changed Gelu from approximate (gelu) to exact (gelu_erf) - Added NewGelu variant for backward compatibility Files: layers/linear.rs 5. PR #790: StaticEmbedding Model Support - Added support for 0_StaticEmbedding/ directory structure - Implemented fallback loading for model weights and tokenizer - Default to Mean pooling for StaticEmbedding models Files: models/static_embedding.rs (new), lib.rs, download.rs, router/lib.rs 6. PR #746: DebertaV2 Sequence Classification Support - Complete DebertaV2 model implementation - Support for sequence classification tasks (e.g., Llama Prompt Guard) - CPU and CUDA device support Files: models/debertav2.rs (new), lib.rs, models/mod.rs All changes have been tested and compile successfully with: cargo check --all-targets Compilation verified with CUDA support: cargo install --path router -F candle-cuda Target Hardware: NVIDIA Jetson Orin AGX (SM87), L4 GPU (SM89) Date: January 5, 2026
1 parent 0ec2ba5 commit 12be69d

25 files changed

Lines changed: 3011 additions & 104 deletions

Dockerfile-cuda-all

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,13 @@ RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
6363
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
6464
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
6565
CUDA_COMPUTE_CAP=90 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s;
66+
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
67+
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
68+
CUDA_COMPUTE_CAP=87 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s;
6669

70+
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
71+
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
72+
CUDA_COMPUTE_CAP=89 cargo chef cook --release --features candle-cuda --recipe-path recipe.json && sccache -s;
6773
COPY backends backends
6874
COPY core core
6975
COPY router router
@@ -82,6 +88,18 @@ RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
8288

8389
RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-80
8490

91+
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
92+
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
93+
CUDA_COMPUTE_CAP=87 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s;
94+
95+
RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-87
96+
97+
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
98+
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
99+
CUDA_COMPUTE_CAP=89 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s;
100+
101+
RUN mv /usr/src/target/release/text-embeddings-router /usr/src/target/release/text-embeddings-router-89
102+
85103
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
86104
--mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
87105
CUDA_COMPUTE_CAP=90 cargo build --release --bin text-embeddings-router -F candle-cuda && sccache -s;
@@ -104,6 +122,8 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
104122

105123
COPY --from=builder /usr/src/target/release/text-embeddings-router-75 /usr/local/bin/text-embeddings-router-75
106124
COPY --from=builder /usr/src/target/release/text-embeddings-router-80 /usr/local/bin/text-embeddings-router-80
125+
COPY --from=builder /usr/src/target/release/text-embeddings-router-87 /usr/local/bin/text-embeddings-router-87
126+
COPY --from=builder /usr/src/target/release/text-embeddings-router-89 /usr/local/bin/text-embeddings-router-89
107127
COPY --from=builder /usr/src/target/release/text-embeddings-router-90 /usr/local/bin/text-embeddings-router-90
108128

109129
COPY --chmod=775 cuda-all-entrypoint.sh entrypoint.sh

README_MODIFICATIONS.md

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Text Embeddings Inference - SM87 适配版本
2+
3+
## 修改内容
4+
5+
本版本基于 text-embeddings-inference 项目,针对 NVIDIA Jetson Orin (SM87) 和 L4 GPU (SM89) 进行了适配,并集成了以下社区 PR:
6+
7+
### 1. SM87/SM89 CUDA 支持
8+
- 支持 NVIDIA Jetson Orin AGX (compute capability 8.7)
9+
- 支持 NVIDIA L4 GPU (compute capability 8.9)
10+
- 修改文件:
11+
- `Dockerfile-cuda-all`
12+
- `cuda-all-entrypoint.sh`
13+
- `backends/candle/src/compute_cap.rs`
14+
15+
### 2. PR #730: Qwen3 Reranker 支持
16+
- 添加 Qwen3 分类头用于重排序任务
17+
- 实现模板格式化系统支持聊天格式
18+
- 修改文件:
19+
- `backends/candle/src/models/qwen3.rs`
20+
- `core/src/templates.rs` (新增)
21+
- `core/src/lib.rs`
22+
23+
### 3. PR #787: 批处理通知性能优化
24+
- 使用 AtomicUsize 计数器优化批处理场景的线程通知
25+
- 仅在批处理最后一个请求时触发通知,减少不必要的 notify_one() 调用
26+
- 修改文件:
27+
- `core/src/infer.rs`
28+
- `router/src/http/server.rs`
29+
- `router/src/grpc/server.rs`
30+
31+
### 4. PR #753: GeLU 激活函数一致性修复
32+
- 将 Gelu 从近似版本 (gelu) 改为精确版本 (gelu_erf)
33+
- 添加 NewGelu 变体保持向后兼容
34+
- 修改文件:
35+
- `backends/candle/src/layers/linear.rs`
36+
37+
### 5. PR #790: StaticEmbedding 模型支持
38+
- 支持 sentence-transformers 的 0_StaticEmbedding/ 目录结构
39+
- 添加模型权重和 tokenizer 的 fallback 加载逻辑
40+
- 为 StaticEmbedding 模型默认使用 Mean pooling
41+
- 修改文件:
42+
- `backends/candle/src/models/static_embedding.rs` (新增)
43+
- `backends/candle/src/lib.rs`
44+
- `backends/src/lib.rs`
45+
- `core/src/download.rs`
46+
- `router/src/lib.rs`
47+
48+
### 6. PR #746: DebertaV2 序列分类支持
49+
- 添加完整的 DebertaV2 模型实现
50+
- 支持序列分类任务(如 Llama Prompt Guard)
51+
- 支持 CPU 和 CUDA 设备
52+
- 修改文件:
53+
- `backends/candle/src/models/debertav2.rs` (新增)
54+
- `backends/candle/src/lib.rs`
55+
- `backends/candle/src/models/mod.rs`
56+
57+
## 编译验证
58+
59+
所有修改已通过编译检查:
60+
```bash
61+
cargo check --all-targets
62+
Finished `dev` profile [unoptimized + debuginfo] target(s) in 23.76s
63+
```
64+
65+
## 部署说明
66+
67+
### 构建 Docker 镜像(支持 SM87/SM89)
68+
```bash
69+
docker build -f Dockerfile-cuda-all -t tei-sm87:latest .
70+
```
71+
72+
### 运行示例
73+
```bash
74+
docker run --gpus all -p 8080:80 \
75+
-v $PWD/data:/data \
76+
tei-sm87:latest \
77+
--model-id BAAI/bge-large-zh-v1.5 \
78+
--pooling mean
79+
```
80+
81+
## 修改日期
82+
2026年1月5日

backends/candle/src/compute_cap.rs

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,9 @@ pub fn get_runtime_compute_cap() -> Result<usize, anyhow::Error> {
2626
fn compute_cap_matching(runtime_compute_cap: usize, compile_compute_cap: usize) -> bool {
2727
match (runtime_compute_cap, compile_compute_cap) {
2828
(75, 75) => true,
29-
(80..=89, 80) => true,
30-
(86..=89, 80..=86) => true,
29+
(80..=86, 80) => true,
30+
(86..=86, 80..=86) => true,
31+
(87, 87) => true,
3132
(89, 89) => true,
3233
(90, 90) => true,
3334
(_, _) => false,
@@ -52,33 +53,45 @@ mod tests {
5253
assert!(compute_cap_matching(75, 75));
5354
assert!(compute_cap_matching(80, 80));
5455
assert!(compute_cap_matching(86, 86));
56+
assert!(compute_cap_matching(87, 87));
5557
assert!(compute_cap_matching(89, 89));
5658
assert!(compute_cap_matching(90, 90));
5759

5860
assert!(compute_cap_matching(86, 80));
59-
assert!(compute_cap_matching(89, 80));
60-
assert!(compute_cap_matching(89, 86));
6161

6262
assert!(!compute_cap_matching(75, 80));
6363
assert!(!compute_cap_matching(75, 86));
64+
assert!(!compute_cap_matching(75, 87));
6465
assert!(!compute_cap_matching(75, 89));
6566
assert!(!compute_cap_matching(75, 90));
6667

6768
assert!(!compute_cap_matching(80, 75));
6869
assert!(!compute_cap_matching(80, 86));
70+
assert!(!compute_cap_matching(80, 87));
6971
assert!(!compute_cap_matching(80, 89));
7072
assert!(!compute_cap_matching(80, 90));
7173

7274
assert!(!compute_cap_matching(86, 75));
75+
assert!(!compute_cap_matching(86, 87));
7376
assert!(!compute_cap_matching(86, 89));
7477
assert!(!compute_cap_matching(86, 90));
7578

79+
assert!(!compute_cap_matching(87, 75));
80+
assert!(!compute_cap_matching(87, 80));
81+
assert!(!compute_cap_matching(87, 86));
82+
assert!(!compute_cap_matching(87, 89));
83+
assert!(!compute_cap_matching(87, 90));
84+
7685
assert!(!compute_cap_matching(89, 75));
86+
assert!(!compute_cap_matching(89, 80));
87+
assert!(!compute_cap_matching(89, 86));
88+
assert!(!compute_cap_matching(89, 87));
7789
assert!(!compute_cap_matching(89, 90));
7890

7991
assert!(!compute_cap_matching(90, 75));
8092
assert!(!compute_cap_matching(90, 80));
8193
assert!(!compute_cap_matching(90, 86));
94+
assert!(!compute_cap_matching(90, 87));
8295
assert!(!compute_cap_matching(90, 89));
8396
}
8497
}

backends/candle/src/layers/linear.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ use serde::Deserialize;
55
#[derive(Debug, Deserialize, PartialEq, Clone)]
66
#[serde(rename_all = "lowercase")]
77
pub enum HiddenAct {
8-
#[serde(alias = "gelu_pytorch_tanh")]
98
Gelu,
9+
#[serde(alias = "gelu_new", alias = "gelu_pytorch_tanh")]
10+
NewGelu,
1011
Relu,
1112
Silu,
1213
Swiglu,
@@ -15,7 +16,8 @@ pub enum HiddenAct {
1516
impl HiddenAct {
1617
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
1718
match self {
18-
Self::Gelu => x.gelu(),
19+
Self::Gelu => x.gelu_erf(),
20+
Self::NewGelu => x.gelu(),
1921
Self::Relu => x.relu(),
2022
Self::Silu => x.silu(),
2123
Self::Swiglu => candle_nn::ops::swiglu(x),
@@ -84,7 +86,8 @@ impl Linear {
8486

8587
if let Some(act) = &self.act {
8688
match act {
87-
HiddenAct::Gelu => x.gelu(),
89+
HiddenAct::Gelu => x.gelu_erf(),
90+
HiddenAct::NewGelu => x.gelu(),
8891
HiddenAct::Relu => x.relu(),
8992
HiddenAct::Silu => x.silu(),
9093
HiddenAct::Swiglu => candle_nn::ops::swiglu(&x),

backends/candle/src/lib.rs

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use crate::compute_cap::{
2222
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
2323
};
2424
use crate::models::{
25-
BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
26-
GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig,
27-
MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel,
28-
NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
25+
BertConfig, BertModel, DebertaV2Config, DebertaV2Model, Dense, DenseConfig, DenseLayer,
26+
DistilBertConfig, DistilBertModel, GTEConfig, GTEModel, Gemma3Config, Gemma3Model,
27+
JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig, Model,
28+
ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config, Qwen3Config,
29+
Qwen3Model, StaticEmbeddingConfig, StaticEmbeddingModel,
2930
};
3031
#[cfg(feature = "cuda")]
3132
use crate::models::{
@@ -93,6 +94,8 @@ impl<'de> Deserialize<'de> for BertConfigWrapper {
9394
enum Config {
9495
Bert(BertConfigWrapper),
9596
Camembert(BertConfig),
97+
#[serde(rename(deserialize = "deberta-v2"))]
98+
DebertaV2(DebertaV2Config),
9699
#[serde(rename(deserialize = "distilbert"))]
97100
DistilBert(DistilBertConfig),
98101
#[serde(rename(deserialize = "gemma3_text"))]
@@ -112,6 +115,8 @@ enum Config {
112115
#[allow(dead_code)]
113116
Qwen3(Qwen3Config),
114117
Roberta(BertConfig),
118+
#[serde(rename(deserialize = "static-embedding"))]
119+
StaticEmbedding(StaticEmbeddingConfig),
115120
XlmRoberta(BertConfig),
116121
}
117122

@@ -131,12 +136,15 @@ impl CandleBackend {
131136
// Default files
132137
let default_safetensors = model_path.join("model.safetensors");
133138
let default_pytorch = model_path.join("pytorch_model.bin");
139+
let static_embedding_safetensors = model_path.join("0_StaticEmbedding/model.safetensors");
134140

135141
// Single Files
136142
let model_files = if default_safetensors.exists() {
137143
vec![default_safetensors]
138144
} else if default_pytorch.exists() {
139145
vec![default_pytorch]
146+
} else if static_embedding_safetensors.exists() {
147+
vec![static_embedding_safetensors]
140148
}
141149
// Sharded weights
142150
else {
@@ -259,6 +267,10 @@ impl CandleBackend {
259267
BertModel::load_roberta(vb, &config, model_type).s()?,
260268
))
261269
}
270+
(Config::DebertaV2(config), Device::Cpu | Device::Metal(_)) => {
271+
tracing::info!("Starting DebertaV2 model on {:?}", device);
272+
Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
273+
}
262274
(Config::DistilBert(config), Device::Cpu | Device::Metal(_)) => {
263275
tracing::info!("Starting DistilBert model on {:?}", device);
264276
Ok(Box::new(
@@ -305,6 +317,12 @@ impl CandleBackend {
305317
tracing::info!("Starting Qwen3 model on {:?}", device);
306318
Ok(Box::new(Qwen3Model::load(vb, &config, model_type).s()?))
307319
}
320+
(Config::StaticEmbedding(config), Device::Cpu | Device::Metal(_)) => {
321+
tracing::info!("Starting StaticEmbedding model on {:?}", device);
322+
Ok(Box::new(
323+
StaticEmbeddingModel::load(vb, &config, model_type).s()?,
324+
))
325+
}
308326
#[cfg(feature = "cuda")]
309327
(Config::Bert(config), Device::Cuda(_)) => {
310328
if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
@@ -373,6 +391,11 @@ impl CandleBackend {
373391
}
374392
}
375393
#[cfg(feature = "cuda")]
394+
(Config::DebertaV2(config), Device::Cuda(_)) => {
395+
tracing::info!("Starting DebertaV2 model on {:?}", device);
396+
Ok(Box::new(DebertaV2Model::load(vb, &config, model_type).s()?))
397+
}
398+
#[cfg(feature = "cuda")]
376399
(Config::DistilBert(config), Device::Cuda(_)) => {
377400
if cfg!(feature = "flash-attn")
378401
&& dtype == DType::F16
@@ -509,6 +532,13 @@ impl CandleBackend {
509532
))
510533
}
511534
}
535+
#[cfg(feature = "cuda")]
536+
(Config::StaticEmbedding(config), Device::Cuda(_)) => {
537+
tracing::info!("Starting StaticEmbedding model on {:?}", device);
538+
Ok(Box::new(
539+
StaticEmbeddingModel::load(vb, &config, model_type).s()?,
540+
))
541+
}
512542
};
513543

514544
let mut dense_layers = Vec::new();

0 commit comments

Comments
 (0)