@@ -2,14 +2,21 @@ use candle::{DType, Device, Result, Tensor, D};
22use serde:: Deserialize ;
33
44#[ derive( Debug , Clone , PartialEq , Deserialize ) ]
5- pub struct NTKScaling {
6- pub factor : f32 ,
7- }
8-
9- #[ derive( Debug , Clone , PartialEq , Deserialize ) ]
10- #[ serde( tag = "type" , rename_all = "kebab-case" ) ]
5+ #[ serde( untagged) ] // RoPE scaling config; untagged = no discriminator key, serde tries the variants in declaration order and keeps the first that deserializes
116pub enum RopeScaling {
12- Ntk ( NTKScaling ) ,
7+ Llama3 { // listed before Ntk: it needs a superset of Ntk's keys, so an Ntk-first order would presumably swallow Llama3 configs (serde ignores unknown keys by default) - verify before reordering
8+ #[ serde( alias = "type" ) ] // the key may be spelled either "rope_type" or "type"
9+ rope_type : String , // parsed but never inspected: the match in get_inv_freqs binds it to `_`
10+ factor : f32 , // divisor applied to low-frequency components (inv_freq_base / factor)
11+ high_freq_factor : f32 , // high_freq_wavelen = original_max_position_embeddings / high_freq_factor; shorter wavelengths stay unscaled
12+ low_freq_factor : f32 , // low_freq_wavelen = original_max_position_embeddings / low_freq_factor; longer wavelengths are fully scaled
13+ original_max_position_embeddings : usize , // cast to f32 and used as the pre-scaling context length
14+ } ,
15+ Ntk { // matches any remaining config that carries a type key and a factor
16+ #[ serde( alias = "type" ) ] // the key may be spelled either "rope_type" or "type"
17+ rope_type : String , // NOTE(review): value is not checked anywhere visible, so any type string with a factor selects NTK scaling - confirm this is intended
18+ factor : f32 , // multiplies the rope base; the resulting inverse frequencies are then divided by factor^(2/dim)
19+ } ,
1320}
1421
1522pub fn get_inv_freqs (
@@ -29,9 +36,52 @@ pub fn get_inv_freqs(
2936
3037 if let Some ( rope_scaling) = rope_scaling {
3138 match rope_scaling {
32- RopeScaling :: Ntk ( ntk_scaling) => {
33- let inv_freqs = get_inv_freqs_inner ( dim, base * ntk_scaling. factor , device) ?;
34- let s = ntk_scaling. factor . powf ( 2.0 / dim as f32 ) as f64 ;
39+ RopeScaling :: Llama3 {
40+ rope_type : _,
41+ factor,
42+ high_freq_factor,
43+ low_freq_factor,
44+ original_max_position_embeddings,
45+ } => {
46+ let old_context_len = * original_max_position_embeddings as f32 ;
47+ let low_freq_wavelen = old_context_len / low_freq_factor;
48+ let high_freq_wavelen = old_context_len / high_freq_factor;
49+
50+ let inv_freq: Vec < _ > = ( 0 ..dim)
51+ . step_by ( 2 )
52+ . map ( |i| {
53+ let freq_idx = i as f32 / dim as f32 ;
54+ // Compute base inverse frequency
55+ let inv_freq_base = 1.0 / base. powf ( freq_idx) ;
56+
57+ // Compute wavelength from inverse frequency
58+ let wavelen = 2.0 * std:: f32:: consts:: PI / inv_freq_base;
59+
60+ // Apply Llama3 scaling logic
61+ if wavelen < high_freq_wavelen {
62+ // High frequency: no scaling
63+ inv_freq_base
64+ } else if wavelen > low_freq_wavelen {
65+ // Low frequency: scale by factor
66+ inv_freq_base / factor
67+ } else {
68+ // Medium frequency: smooth interpolation
69+ let smooth_factor = ( old_context_len / wavelen - low_freq_factor)
70+ / ( high_freq_factor - low_freq_factor) ;
71+ let inv_freq_llama = inv_freq_base / factor;
72+ ( 1.0 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq_base
73+ }
74+ } )
75+ . collect ( ) ;
76+ let inv_freq_len = inv_freq. len ( ) ;
77+ return Tensor :: from_vec ( inv_freq, ( 1 , inv_freq_len) , device) ;
78+ }
79+ RopeScaling :: Ntk {
80+ rope_type : _,
81+ factor,
82+ } => {
83+ let inv_freqs = get_inv_freqs_inner ( dim, base * factor, device) ?;
84+ let s = factor. powf ( 2.0 / dim as f32 ) as f64 ;
3585 return inv_freqs / s;
3686 }
3787 }
0 commit comments