Skip to content

Commit 882d027

Browse files
Mf/add-support-for-llama-3-and-nemotron (#805)
Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
1 parent ac69b50 commit 882d027

4 files changed

Lines changed: 71 additions & 16 deletions

File tree

backends/candle/src/layers/rotary.rs

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@ use candle::{DType, Device, Result, Tensor, D};
22
use serde::Deserialize;
33

44
#[derive(Debug, Clone, PartialEq, Deserialize)]
5-
pub struct NTKScaling {
6-
pub factor: f32,
7-
}
8-
9-
#[derive(Debug, Clone, PartialEq, Deserialize)]
10-
#[serde(tag = "type", rename_all = "kebab-case")]
5+
#[serde(untagged)]
116
pub enum RopeScaling {
12-
Ntk(NTKScaling),
7+
Llama3 {
8+
#[serde(alias = "type")]
9+
rope_type: String,
10+
factor: f32,
11+
high_freq_factor: f32,
12+
low_freq_factor: f32,
13+
original_max_position_embeddings: usize,
14+
},
15+
Ntk {
16+
#[serde(alias = "type")]
17+
rope_type: String,
18+
factor: f32,
19+
},
1320
}
1421

1522
pub fn get_inv_freqs(
@@ -29,9 +36,52 @@ pub fn get_inv_freqs(
2936

3037
if let Some(rope_scaling) = rope_scaling {
3138
match rope_scaling {
32-
RopeScaling::Ntk(ntk_scaling) => {
33-
let inv_freqs = get_inv_freqs_inner(dim, base * ntk_scaling.factor, device)?;
34-
let s = ntk_scaling.factor.powf(2.0 / dim as f32) as f64;
39+
RopeScaling::Llama3 {
40+
rope_type: _,
41+
factor,
42+
high_freq_factor,
43+
low_freq_factor,
44+
original_max_position_embeddings,
45+
} => {
46+
let old_context_len = *original_max_position_embeddings as f32;
47+
let low_freq_wavelen = old_context_len / low_freq_factor;
48+
let high_freq_wavelen = old_context_len / high_freq_factor;
49+
50+
let inv_freq: Vec<_> = (0..dim)
51+
.step_by(2)
52+
.map(|i| {
53+
let freq_idx = i as f32 / dim as f32;
54+
// Compute base inverse frequency
55+
let inv_freq_base = 1.0 / base.powf(freq_idx);
56+
57+
// Compute wavelength from inverse frequency
58+
let wavelen = 2.0 * std::f32::consts::PI / inv_freq_base;
59+
60+
// Apply Llama3 scaling logic
61+
if wavelen < high_freq_wavelen {
62+
// High frequency: no scaling
63+
inv_freq_base
64+
} else if wavelen > low_freq_wavelen {
65+
// Low frequency: scale by factor
66+
inv_freq_base / factor
67+
} else {
68+
// Medium frequency: smooth interpolation
69+
let smooth_factor = (old_context_len / wavelen - low_freq_factor)
70+
/ (high_freq_factor - low_freq_factor);
71+
let inv_freq_llama = inv_freq_base / factor;
72+
(1.0 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq_base
73+
}
74+
})
75+
.collect();
76+
let inv_freq_len = inv_freq.len();
77+
return Tensor::from_vec(inv_freq, (1, inv_freq_len), device);
78+
}
79+
RopeScaling::Ntk {
80+
rope_type: _,
81+
factor,
82+
} => {
83+
let inv_freqs = get_inv_freqs_inner(dim, base * factor, device)?;
84+
let s = factor.powf(2.0 / dim as f32) as f64;
3585
return inv_freqs / s;
3686
}
3787
}

backends/candle/src/lib.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ impl CandleBackend {
298298
tracing::info!("Starting MPNet model on {:?}", device);
299299
Ok(Box::new(MPNetModel::load(vb, &config, model_type).s()?))
300300
}
301-
(Config::Llama(_config), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
301+
(Config::Llama(_), Device::Cpu | Device::Metal(_)) => Err(BackendError::Start(
302302
"Llama is only supported on Cuda devices in fp16 with flash attention enabled"
303303
.to_string(),
304304
)),
@@ -531,8 +531,7 @@ impl CandleBackend {
531531
#[cfg(feature = "cuda")]
532532
(Config::Llama(config), Device::Cuda(_)) => {
533533
match config.rope_scaling {
534-
Some(ref _rope_scaling) => {
535-
// error, as no rope scaling is supported for FlashLlama yet
534+
Some(_) => {
536535
Err(BackendError::Start(
537536
"Rope scaling is not supported for FlashLlama yet".to_string(),
538537
))

backends/candle/src/models/flash_mistral.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ struct MistralAttention {
1111
o_proj: Linear,
1212

1313
window_size_left: Option<usize>,
14+
use_bidirectional_attention: bool,
1415

1516
num_attention_heads: usize,
1617
num_key_value_heads: usize,
@@ -24,6 +25,7 @@ struct MistralAttention {
2425
impl MistralAttention {
2526
pub fn load(vb: VarBuilder, config: &MistralConfig) -> Result<Self> {
2627
let window_size_left = config.sliding_window;
28+
let use_bidirectional_attention = config.use_bidirectional_attention;
2729
let num_attention_heads = config.num_attention_heads;
2830
let attention_head_size = config.hidden_size / config.num_attention_heads;
2931
let num_key_value_heads = config.num_key_value_heads;
@@ -54,6 +56,7 @@ impl MistralAttention {
5456
qkv_linear,
5557
o_proj,
5658
window_size_left,
59+
use_bidirectional_attention,
5760
num_attention_heads,
5861
num_key_value_heads,
5962
attention_head_size,
@@ -103,7 +106,7 @@ impl MistralAttention {
103106
max_s,
104107
max_s,
105108
self.softmax_scale,
106-
true,
109+
!self.use_bidirectional_attention,
107110
self.window_size_left,
108111
None,
109112
)?;
@@ -269,7 +272,7 @@ impl FlashMistralModel {
269272
layers[0].attention.attention_head_size,
270273
config.rope_theta,
271274
vb.device(),
272-
None,
275+
config.rope_scaling.as_ref(),
273276
)?;
274277
let (cos_cache, sin_cache) = get_cos_sin(
275278
config.max_position_embeddings,

backends/candle/src/models/mistral.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::layers::HiddenAct;
1+
use crate::layers::{HiddenAct, RopeScaling};
22
use serde::Deserialize;
33

44
#[derive(Debug, Clone, PartialEq, Deserialize)]
@@ -16,4 +16,7 @@ pub struct MistralConfig {
1616
pub model_type: Option<String>,
1717
pub rope_theta: f32,
1818
pub sliding_window: Option<usize>,
19+
pub rope_scaling: Option<RopeScaling>,
20+
#[serde(default)]
21+
pub use_bidirectional_attention: bool,
1922
}

0 commit comments

Comments
 (0)