1 change: 1 addition & 0 deletions README.md
@@ -89,6 +89,7 @@ Below are some examples of the currently supported models:
| N/A | 475M-A305M | NomicBERT | [nomic-ai/nomic-embed-text-v2-moe](https://hf.co/nomic-ai/nomic-embed-text-v2-moe) |
| N/A | 434M | Alibaba GTE | [Alibaba-NLP/gte-large-en-v1.5](https://hf.co/Alibaba-NLP/gte-large-en-v1.5) |
| N/A | 396M | ModernBERT | [answerdotai/ModernBERT-large](https://hf.co/answerdotai/ModernBERT-large) |
| N/A | 340M | Qwen3 | [voyageai/voyage-4-nano](https://hf.co/voyageai/voyage-4-nano) |
| N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-en](https://hf.co/jinaai/jina-embeddings-v2-base-en) |
| N/A | 137M | JinaBERT | [jinaai/jina-embeddings-v2-base-code](https://hf.co/jinaai/jina-embeddings-v2-base-code) |

44 changes: 37 additions & 7 deletions backends/candle/src/models/flash_qwen3.rs
@@ -20,6 +20,7 @@ struct Qwen3Attention {
attention_head_size: usize,

softmax_scale: f32,
use_bidirectional_attention: bool,

span: tracing::Span,
}
@@ -98,6 +99,7 @@ impl Qwen3Attention {
num_key_value_heads,
attention_head_size,
softmax_scale,
use_bidirectional_attention: config.use_bidirectional_attention.unwrap_or(false),
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}
@@ -158,7 +160,7 @@ impl Qwen3Attention {
max_s,
max_s,
self.softmax_scale,
- true,
+ !self.use_bidirectional_attention,
None,
None,
)?;
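
In the call above, the hard-coded `true` becomes `!self.use_bidirectional_attention`, so causal masking stays the default and only configs that set `use_bidirectional_attention: true` switch to full bidirectional attention. A rough, standalone sketch of what that flag means in additive-bias terms (the helper below is illustrative only and is not part of this PR or of candle):

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Illustrative helper: build the additive attention bias for one sequence of
/// length `seq_len`. Bidirectional attention masks nothing (all zeros); causal
/// attention puts a large negative value at every position j > i so token i
/// cannot attend to later tokens.
fn attention_bias(seq_len: usize, bidirectional: bool, device: &Device) -> Result<Tensor> {
    if bidirectional {
        return Tensor::zeros((seq_len, seq_len), DType::F32, device);
    }
    let mask: Vec<f32> = (0..seq_len)
        .flat_map(|i| (0..seq_len).map(move |j| if j > i { f32::NEG_INFINITY } else { 0.0 }))
        .collect();
    Tensor::from_vec(mask, (seq_len, seq_len), device)
}

fn main() -> Result<()> {
    let bias = attention_bias(4, false, &Device::Cpu)?;
    println!("causal bias for 4 tokens:\n{bias}");
    Ok(())
}
```

The flash path folds this choice into the single `causal` boolean, while the non-flash `qwen3.rs` path below keeps the bias tensor explicit and skips `get_causal_attention_bias` when the flag is set.
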
@@ -285,6 +287,7 @@ pub struct FlashQwen3Model {
embeddings: Embedding,
layers: Vec<Qwen3Layer>,
norm: RMSNorm,
projection: Option<Linear>,
cos_cache: Tensor,
sin_cache: Tensor,
pool: Pool,
@@ -313,23 +316,42 @@ impl FlashQwen3Model {

// The Qwen3-Reranker models contain the `model` key
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
- let vb = if vb.contains_tensor("model.embed_tokens.weight") {
- vb.pp("model")
+ let model_prefix = if vb.contains_tensor("model.embed_tokens.weight") {
+ "model."
} else {
- vb
+ ""
};

let embeddings = Embedding::new(
- vb.pp("embed_tokens")
+ vb.pp(format!("{model_prefix}embed_tokens"))
.get((config.vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
);

let layers = (0..config.num_hidden_layers)
- .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
+ .map(|index| Qwen3Layer::load(vb.pp(format!("{model_prefix}layers.{index}")), config))
.collect::<Result<Vec<_>>>()?;

- let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
+ let norm = RMSNorm::load(
+ vb.pp(format!("{model_prefix}norm")),
+ config.hidden_size,
+ config.rms_norm_eps,
+ )?;

let projection = if let Some(num_labels) = config.num_labels {
if vb.contains_tensor("linear.weight") {
let projection_weight =
vb.get((num_labels, config.hidden_size), "linear.weight")?;
Some(Linear::new(projection_weight, None, None))
} else {
tracing::warn!(
"num_labels is set but linear.weight not found, skipping projection layer"
);
None
}
} else {
None
};
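
The loading code above handles two checkpoint layouts: Qwen3-Reranker checkpoints nest their weights under a `model.` prefix, plain embedding checkpoints store them at the top level, and the optional `linear.weight` projection is looked up without the prefix in both cases. A self-contained sketch of that name resolution over plain tensor-name sets (the helper and both example layouts are hypothetical, for illustration only):

```rust
use std::collections::HashSet;

// Hypothetical helper mirroring the prefix logic above.
fn resolve(tensor_names: &HashSet<&str>) -> (&'static str, bool) {
    let model_prefix = if tensor_names.contains("model.embed_tokens.weight") {
        "model."
    } else {
        ""
    };
    let has_projection = tensor_names.contains("linear.weight");
    (model_prefix, has_projection)
}

fn main() {
    // Reranker-style checkpoint: weights nested under `model.`, no projection.
    let reranker: HashSet<&str> = ["model.embed_tokens.weight", "model.norm.weight"]
        .into_iter()
        .collect();
    assert_eq!(resolve(&reranker), ("model.", false));

    // Embedding-style checkpoint with a top-level projection head.
    let embedder: HashSet<&str> = ["embed_tokens.weight", "norm.weight", "linear.weight"]
        .into_iter()
        .collect();
    assert_eq!(resolve(&embedder), ("", true));
}
```
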

let inv_freqs = get_inv_freqs(
layers[0].attention.attention_head_size,
@@ -348,6 +370,7 @@ impl FlashQwen3Model {
embeddings,
layers,
norm,
projection,
cos_cache,
sin_cache,
pool,
@@ -392,6 +415,13 @@ impl FlashQwen3Model {

let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;

// NOTE: `projection` required by https://huggingface.co/voyageai/voyage-4-nano
let outputs = if let Some(ref projection) = self.projection {
projection.forward(&outputs)?
} else {
outputs
};
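
When a projection is loaded, it is applied to the normalized hidden states right before pooling, taking each token vector from `hidden_size` down to `num_labels` dimensions. A quick shape check of what a bias-free linear layer over the `linear.weight` tensor amounts to (sizes are made up, and this sketch uses a raw matmul rather than the crate's own `Linear` wrapper):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Made-up sizes, chosen only to make the shapes easy to follow.
    let (n_tokens, hidden_size, num_labels) = (4usize, 8, 2);

    // Final hidden states after the last norm: one row per token.
    let hidden = Tensor::randn(0f32, 1f32, (n_tokens, hidden_size), &device)?;
    // `linear.weight` is stored as (num_labels, hidden_size), so a bias-free
    // linear layer is a matmul with its transpose:
    // (n_tokens, hidden_size) x (hidden_size, num_labels) -> (n_tokens, num_labels)
    let weight = Tensor::randn(0f32, 1f32, (num_labels, hidden_size), &device)?;

    let projected = hidden.matmul(&weight.t()?)?;
    assert_eq!(projected.dims(), &[n_tokens, num_labels]);
    Ok(())
}
```

The non-flash `qwen3.rs` forward pass below applies the same optional projection after its final norm.
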

let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();

54 changes: 47 additions & 7 deletions backends/candle/src/models/qwen3.rs
@@ -24,6 +24,12 @@ pub struct Qwen3Config {
pub sliding_window: Option<usize>,
pub use_sliding_window: bool,
pub eos_token_id: usize,
// TODO(alvarobartt): Migrate to `is_causal` instead
// https://github.com/huggingface/transformers/pull/43705
#[serde(default)]
pub use_bidirectional_attention: Option<bool>,
#[serde(default)]
pub num_labels: Option<usize>,
}
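
Both new fields are optional and default-initialized, so a `config.json` that predates them still deserializes unchanged. A minimal sketch with a stripped-down stand-in for `Qwen3Config` (field names come from this PR; the JSON values are purely illustrative and not taken from any released checkpoint):

```rust
use serde::Deserialize;

// Stripped-down stand-in for Qwen3Config, keeping only the two new fields.
#[derive(Debug, Deserialize)]
struct MiniQwen3Config {
    #[serde(default)]
    use_bidirectional_attention: Option<bool>,
    #[serde(default)]
    num_labels: Option<usize>,
}

fn main() -> serde_json::Result<()> {
    // A config.json without the new keys still parses; both fields fall back to None.
    let plain: MiniQwen3Config = serde_json::from_str("{}")?;
    assert_eq!(plain.use_bidirectional_attention, None);
    assert_eq!(plain.num_labels, None);

    // A checkpoint can opt in to bidirectional attention and declare a
    // projection output size explicitly.
    let json = r#"{"use_bidirectional_attention": true, "num_labels": 2}"#;
    let custom: MiniQwen3Config = serde_json::from_str(json)?;
    assert_eq!(custom.use_bidirectional_attention, Some(true));
    assert_eq!(custom.num_labels, Some(2));
    Ok(())
}
```
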

struct Qwen3Attention {
Expand Down Expand Up @@ -379,11 +385,14 @@ pub struct Qwen3Model {
embeddings: Embedding,
layers: Vec<Qwen3Layer>,
norm: RMSNorm,
// TODO(alvarobartt): Eventually extend Qwen3 for Voyage instead of adding `projection` here
projection: Option<Linear>,
rotary_cache: (Tensor, Tensor),
rotary_dim: usize,
pool: Pool,
num_attention_heads: usize,
pad_token_id: u32,
use_bidirectional_attention: bool,

dtype: DType,
device: Device,
@@ -402,23 +411,44 @@ impl Qwen3Model {

// The Qwen3-Reranker models contain the `model` key
// https://huggingface.co/collections/Qwen/qwen3-reranker-6841b22d0192d7ade9cdefea
- let vb = if vb.contains_tensor("model.embed_tokens.weight") {
- vb.pp("model")
+ let model_prefix = if vb.contains_tensor("model.embed_tokens.weight") {
+ "model."
} else {
- vb
+ ""
};

let embeddings = Embedding::new(
- vb.pp("embed_tokens")
+ vb.pp(format!("{model_prefix}embed_tokens"))
.get((config.vocab_size, config.hidden_size), "weight")?,
config.hidden_size,
);

let layers = (0..config.num_hidden_layers)
- .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
+ .map(|index| Qwen3Layer::load(vb.pp(format!("{model_prefix}layers.{index}")), config))
.collect::<Result<Vec<_>>>()?;

- let norm = RMSNorm::load(vb.pp("norm"), config.hidden_size, config.rms_norm_eps)?;
+ let norm = RMSNorm::load(
+ vb.pp(format!("{model_prefix}norm")),
+ config.hidden_size,
+ config.rms_norm_eps,
+ )?;

let projection = if let Some(num_labels) = config.num_labels {
if vb.contains_tensor("linear.weight") {
let projection_weight =
vb.get((num_labels, config.hidden_size), "linear.weight")?;
Some(Linear::new(projection_weight, None, None))
} else {
tracing::warn!(
"num_labels is set but linear.weight not found, skipping projection layer"
);
None
}
} else {
None
};

let use_bidirectional_attention = config.use_bidirectional_attention.unwrap_or(false);

let rotary_dim = config
.head_dim
@@ -433,11 +463,13 @@ impl Qwen3Model {
embeddings,
layers,
norm,
projection,
rotary_cache,
rotary_dim,
pool,
pad_token_id: config.eos_token_id as u32,
num_attention_heads: config.num_attention_heads,
use_bidirectional_attention,
dtype: vb.dtype(),
device: vb.device().clone(),
span: tracing::span!(tracing::Level::TRACE, "model"),
@@ -555,7 +587,9 @@ impl Qwen3Model {
(input_ids, position_ids, input_lengths, Some(attention_bias))
};

- let attention_bias = if let Some(attn_bias) = attention_bias {
+ let attention_bias = if self.use_bidirectional_attention {
+ attention_bias
+ } else if let Some(attn_bias) = attention_bias {
Some(self.get_causal_attention_bias(attn_bias)?)
} else {
None
@@ -581,6 +615,12 @@ impl Qwen3Model {

let (outputs, _) = self.norm.forward(&hidden_states, None)?;

let outputs = if let Some(ref projection) = self.projection {
projection.forward(&outputs)?
} else {
outputs
};

let has_pooling_requests = !batch.pooled_indices.is_empty();
let has_raw_requests = !batch.raw_indices.is_empty();
