@@ -20,7 +20,6 @@ def __init__(
2020 ) -> None :
2121 super ().__init__ ()
2222 self .head_size = head_size
23- assert rotary_dim == head_size
2423 inv_freq = 1.0 / (base ** (torch .arange (0 , rotary_dim , 2 , dtype = torch .float ) / rotary_dim ))
2524 if post_process is not None :
2625 inv_freq = post_process (inv_freq )
@@ -30,8 +29,8 @@ def __init__(
3029 sin = freqs .sin ()
3130 # buffer, so don't load/save
3231 self ._cos_sin_cache = torch .cat ((cos , sin ), dim = - 1 )
33- assert self .head_size in [64 , 128 , 256 , 512 ]
3432
33+ assert self .head_size in [64 , 128 , 256 , 512 ]
3534 from flashinfer import apply_rope_with_cos_sin_cache_inplace
3635
3736 self .apply_rope_with_cos_sin_cache_inplace = apply_rope_with_cos_sin_cache_inplace
@@ -97,15 +96,20 @@ def post_process(inv_freq: torch.Tensor) -> torch.Tensor:
9796 orig_max_pos : int = rope_scaling ["original_max_position_embeddings" ]
9897
9998 def _find_correction_dim (num_rotations : float ) -> float :
100- return rotary_dim * math .log (orig_max_pos / (num_rotations * 2 * math .pi )) / (2 * math .log (base ))
99+ return (
100+ rotary_dim
101+ * math .log (orig_max_pos / (num_rotations * 2 * math .pi ))
102+ / (2 * math .log (base ))
103+ )
101104
102105 low = max (math .floor (_find_correction_dim (beta_fast )), 0 )
103106 high = min (math .ceil (_find_correction_dim (beta_slow )), rotary_dim // 2 - 1 )
104107
105108 def post_process (inv_freq : torch .Tensor ) -> torch .Tensor :
106109 ramp = torch .clamp (
107110 (torch .arange (rotary_dim // 2 , dtype = torch .float32 ) - low ) / max (high - low , 1 ),
108- 0 , 1 ,
111+ 0 ,
112+ 1 ,
109113 )
110114 return (inv_freq / factor ) * ramp + inv_freq * (1 - ramp )
111115
@@ -143,4 +147,4 @@ def get_rope(
143147 return _get_rope (head_dim , rotary_dim , max_position , base , rope_map )
144148
145149
146- __all__ = ["get_rope" , "RotaryEmbedding" , "set_rope_device" ]
150+ __all__ = ["get_rope" , "RotaryEmbedding" , "set_rope_device" ]
0 commit comments