
Commit b59f754

Add bidirectional attention and projection layer support for Qwen3-based models (#808)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent 45354f1 · commit b59f754

10 files changed: 16,600 additions & 14 deletions

README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -89,6 +89,7 @@ Below are some examples of the currently supported models:
 | N/A | 475M-A305M | NomicBERT | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
 | N/A | 434M | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
 | N/A | 396M | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
+| N/A | 340M | Qwen3 | [voyageai/voyage-4-nano](https://hf.co/voyageai/voyage-4-nano) |
 | N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
 | N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |
```

backends/candle/src/models/flash_qwen3.rs

Lines changed: 37 additions & 7 deletions

```diff
@@ -20,6 +20,7 @@ struct Qwen3Attention {
     attention_head_size: usize,
 
     softmax_scale: f32,
+    use_bidirectional_attention: bool,
 
     span: tracing::Span,
 }
@@ -98,6 +99,7 @@ impl Qwen3Attention {
             num_key_value_heads,
             attention_head_size,
             softmax_scale,
+            use_bidirectional_attention: config.use_bidirectional_attention.unwrap_or(false),
             span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
@@ -158,7 +160,7 @@ impl Qwen3Attention {
             max_s,
             max_s,
             self.softmax_scale,
-            true,
+            !self.use_bidirectional_attention,
             None,
             None,
         )?;
@@ -285,6 +287,7 @@ pub struct FlashQwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    projection: Option<Linear>,
     cos_cache: Tensor,
     sin_cache: Tensor,
     pool: Pool,
@@ -313,23 +316,42 @@ impl FlashQwen3Model {
 
         // The Qwen3-Reranker models contain the `model` key
         // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
-        let vb = if vb.contains_tensor("model.embed_tokens.weight") {
-            vb.pp("model")
+        let model_prefix = if vb.contains_tensor("model.embed_tokens.weight") {
+            "model."
         } else {
-            vb
+            ""
         };
 
         let embeddings = Embedding::new(
-            vb.pp("embed_tokens")
+            vb.pp(format!("{model_prefix}embed_tokens"))
                 .get((config.vocab_size, config.hidden_size), "weight")?,
             config.hidden_size,
         );
 
         let layers = (0..config.num_hidden_layers)
-            .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
+            .map(|index| Qwen3Layer::load(vb.pp(format!("{model_prefix}layers.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
 
-        let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
+        let norm = RMSNorm::load(
+            vb.pp(format!("{model_prefix}norm")),
+            config.hidden_size,
+            config.rms_norm_eps,
+        )?;
+
+        let projection = if let Some(num_labels) = config.num_labels {
+            if vb.contains_tensor("linear.weight") {
+                let projection_weight =
+                    vb.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!(
+                    "num_labels is set but linear.weight not found, skipping projection layer"
+                );
+                None
+            }
+        } else {
+            None
+        };
 
         let inv_freqs = get_inv_freqs(
             layers[0].attention.attention_head_size,
@@ -348,6 +370,7 @@ impl FlashQwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             cos_cache,
             sin_cache,
             pool,
@@ -392,6 +415,13 @@ impl FlashQwen3Model {
 
         let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
 
+        // NOTE: `projection` required by https://huggingface.co/voyageai/voyage-4-nano
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
```
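The projection is applied after the final `RMSNorm` and before pooling, so pooled embeddings come out with `num_labels` dimensions instead of `hidden_size`. Below is a minimal standalone sketch of that shape change, using `candle_nn::Linear` in place of TEI's internal `Linear` and made-up sizes:

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (num_tokens, hidden_size, num_labels) = (4, 1024, 512); // illustrative sizes

    // Normalized hidden states, as they leave the final RMSNorm.
    let outputs = Tensor::randn(0f32, 1f32, (num_tokens, hidden_size), &device)?;

    // `linear.weight` is stored as (num_labels, hidden_size) and loaded without a bias.
    let weight = Tensor::randn(0f32, 1f32, (num_labels, hidden_size), &device)?;
    let projection = Linear::new(weight, None);

    // Applying the projection before pooling changes the embedding width.
    let projected = projection.forward(&outputs)?;
    assert_eq!(projected.dims(), &[num_tokens, num_labels]);
    Ok(())
}
```

Since there is no bias, the layer is a pure matrix projection of each token's hidden state; pooling then averages (or selects) vectors that are already in the projected space.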
backends/candle/src/models/qwen3.rs

Lines changed: 47 additions & 7 deletions

```diff
@@ -24,6 +24,12 @@ pub struct Qwen3Config {
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
     pub eos_token_id: usize,
+    // TODO(alvarobartt): Migrate to `is_causal` instead
+    // https://github.com/huggingface/transformers/pull/43705
+    #[serde(default)]
+    pub use_bidirectional_attention: Option<bool>,
+    #[serde(default)]
+    pub num_labels: Option<usize>,
 }
 
 struct Qwen3Attention {
```
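Both new `Qwen3Config` fields are optional, so existing `config.json` files keep deserializing unchanged. A self-contained sketch of the serde behavior (the struct below is a hypothetical reduction of `Qwen3Config`, and the field values are made up for illustration):

```rust
use serde::Deserialize;

// Hypothetical reduction of `Qwen3Config` to the two new fields.
#[derive(Deserialize)]
struct PartialQwen3Config {
    #[serde(default)]
    use_bidirectional_attention: Option<bool>,
    #[serde(default)]
    num_labels: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    // Older checkpoints omit both keys entirely; they deserialize to `None`,
    // so causal attention and no projection layer remain the defaults.
    let legacy: PartialQwen3Config = serde_json::from_str("{}")?;
    assert_eq!(legacy.use_bidirectional_attention, None);
    assert_eq!(legacy.num_labels, None);

    // A checkpoint that opts in (values are illustrative, not voyage-4-nano's).
    let json = r#"{"use_bidirectional_attention": true, "num_labels": 512}"#;
    let config: PartialQwen3Config = serde_json::from_str(json)?;
    assert_eq!(config.use_bidirectional_attention, Some(true));
    assert_eq!(config.num_labels, Some(512));
    Ok(())
}
```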

```diff
@@ -379,11 +385,14 @@ pub struct Qwen3Model {
     embeddings: Embedding,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
+    // TODO(alvarobartt): Eventually extend Qwen3 for Voyage instead of adding `projection` here
+    projection: Option<Linear>,
     rotary_cache: (Tensor, Tensor),
     rotary_dim: usize,
     pool: Pool,
     num_attention_heads: usize,
     pad_token_id: u32,
+    use_bidirectional_attention: bool,
 
     dtype: DType,
     device: Device,
```

```diff
@@ -402,23 +411,44 @@ impl Qwen3Model {
 
         // The Qwen3-Reranker models contain the `model` key
         // https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
-        let vb = if vb.contains_tensor("model.embed_tokens.weight") {
-            vb.pp("model")
+        let model_prefix = if vb.contains_tensor("model.embed_tokens.weight") {
+            "model."
         } else {
-            vb
+            ""
        };
 
         let embeddings = Embedding::new(
-            vb.pp("embed_tokens")
+            vb.pp(format!("{model_prefix}embed_tokens"))
                 .get((config.vocab_size, config.hidden_size), "weight")?,
             config.hidden_size,
         );
 
         let layers = (0..config.num_hidden_layers)
-            .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
+            .map(|index| Qwen3Layer::load(vb.pp(format!("{model_prefix}layers.{index}")), config))
             .collect::<Result<Vec<_>>>()?;
 
-        let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
+        let norm = RMSNorm::load(
+            vb.pp(format!("{model_prefix}norm")),
+            config.hidden_size,
+            config.rms_norm_eps,
+        )?;
+
+        let projection = if let Some(num_labels) = config.num_labels {
+            if vb.contains_tensor("linear.weight") {
+                let projection_weight =
+                    vb.get((num_labels, config.hidden_size), "linear.weight")?;
+                Some(Linear::new(projection_weight, None, None))
+            } else {
+                tracing::warn!(
+                    "num_labels is set but linear.weight not found, skipping projection layer"
+                );
+                None
+            }
+        } else {
+            None
+        };
+
+        let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);
 
         let rotary_dim = config
             .head_dim
```
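Note the switch from re-rooting the `VarBuilder` with `vb.pp("model")` to carrying a string prefix: presumably this keeps root-level tensors such as `linear.weight` reachable from the same builder even when the transformer weights sit under `model.*`. A hypothetical helper illustrating the key-resolution rule:

```rust
// Hypothetical helper mirroring the `model_prefix` logic above; not part of TEI.
fn weight_key(model_prefix: &str, name: &str) -> String {
    format!("{model_prefix}{name}")
}

fn main() {
    // Qwen3-Reranker-style checkpoints nest everything under `model.*` ...
    assert_eq!(weight_key("model.", "layers.0"), "model.layers.0");
    // ... while flat embedding checkpoints do not.
    assert_eq!(weight_key("", "norm"), "norm");
    // Root-level tensors like the projection are always looked up without a prefix.
    assert_eq!(weight_key("", "linear.weight"), "linear.weight");
}
```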

```diff
@@ -433,11 +463,13 @@ impl Qwen3Model {
             embeddings,
             layers,
             norm,
+            projection,
             rotary_cache,
             rotary_dim,
             pool,
             pad_token_id: config.eos_token_id as u32,
             num_attention_heads: config.num_attention_heads,
+            use_bidirectional_attention,
             dtype: vb.dtype(),
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -555,7 +587,9 @@ impl Qwen3Model {
             (input_ids, position_ids, input_lengths, Some(attention_bias))
         };
 
-        let attention_bias = if let Some(attn_bias) = attention_bias {
+        let attention_bias = if self.use_bidirectional_attention {
+            attention_bias
+        } else if let Some(attn_bias) = attention_bias {
             Some(self.get_causal_attention_bias(attn_bias)?)
         } else {
             None
```
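In this non-flash path the flag is applied to the attention bias rather than to a kernel argument: with `use_bidirectional_attention` set, the padding-only bias passes through untouched; otherwise the causal term is added as before. A rough candle sketch of that distinction (my own simplification, not TEI's `get_causal_attention_bias`):

```rust
use candle_core::{Result, Tensor};

// Simplified stand-in for the logic above: with bidirectional attention the
// padding-only bias is used as-is; otherwise an upper-triangular masking term
// is added so each token attends only to itself and earlier positions.
fn resolve_attention_bias(padding_bias: &Tensor, bidirectional: bool) -> Result<Tensor> {
    if bidirectional {
        return Ok(padding_bias.clone());
    }
    let (_batch, _heads, seq, _) = padding_bias.dims4()?; // (batch, 1, seq, seq)
    // mask[i][j] = f32::MIN wherever j > i, i.e. future positions.
    let mask: Vec<f32> = (0..seq)
        .flat_map(|i| (0..seq).map(move |j| if j > i { f32::MIN } else { 0.0 }))
        .collect();
    let causal = Tensor::from_vec(mask, (1, 1, seq, seq), padding_bias.device())?;
    padding_bias.broadcast_add(&causal)
}
```

This mirrors the flash path, where the same flag flips the `causal` argument of the attention kernel instead of materializing a mask.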

```diff
@@ -581,6 +615,12 @@ impl Qwen3Model {
 
         let (outputs, _) = self.norm.forward(&hidden_states, None)?;
 
+        let outputs = if let Some(ref projection) = self.projection {
+            projection.forward(&outputs)?
+        } else {
+            outputs
+        };
+
         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
```
