1919
2020import math
2121from functools import partial
22- from typing import Any , Callable , Dict , Optional , Tuple
22+ from typing import Any , Callable , Dict , Optional , Tuple , Union
2323
2424from tvm import tir
2525from tvm .relax .frontend .nn import Tensor , op
@@ -180,38 +180,43 @@ def rope_freq_longrope( # pylint: disable=too-many-arguments
def yarn_find_correction_dim(
    num_rotations: int,
    d: tir.Var,
    max_position_embeddings: int,
    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
    """Inverse dim formula to find dim based on number of rotations.

    Parameters
    ----------
    num_rotations : int
        Number of rotations to solve the dimension for.
    d : tir.Var
        Rotary dimension the result is scaled by.
    max_position_embeddings : int
        Context length used in the logarithm term.
    inv_theta_log_scale : Optional[Union[float, tir.PrimExpr]]
        Precomputed multiplier applied to the result. It replaces an earlier
        division by ``2 * math.log(theta)``, so it is presumably
        ``1 / (2 * log(theta))`` -- confirm against the caller that
        computes it. Although it defaults to ``None`` it is required.

    Returns
    -------
    ``d * log(max_position_embeddings / (2 * pi * num_rotations)) * inv_theta_log_scale``

    Raises
    ------
    ValueError
        If ``inv_theta_log_scale`` is not provided.
    """
    # Fail with a clear message instead of an opaque "unsupported operand"
    # error from multiplying by None further down.
    if inv_theta_log_scale is None:
        raise ValueError("inv_theta_log_scale must be provided for YaRN correction dim")
    rotation_log = math.log(max_position_embeddings / (num_rotations * 2 * math.pi))
    return d * rotation_log * inv_theta_log_scale
190190
191191
def yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    d: tir.Var,
    max_position_embeddings: int,
    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
):
    """Find the correction range based on the number of rotations.

    Parameters
    ----------
    low_rot : int
        Rotation count used for the lower bound.
    high_rot : int
        Rotation count used for the upper bound.
    d : tir.Var
        Rotary dimension; the upper bound is clamped to ``d - 1``.
    max_position_embeddings : int
        Context length forwarded to ``yarn_find_correction_dim``.
    inv_theta_log_scale : Optional[Union[float, tir.PrimExpr]]
        Precomputed multiplier forwarded to ``yarn_find_correction_dim``.
        Although it defaults to ``None`` it is required.

    Returns
    -------
    A ``(low, high)`` pair: the lower bound clamped to be at least 0 and the
    upper bound clamped to be at most ``d - 1``.

    Raises
    ------
    ValueError
        If ``inv_theta_log_scale`` is not provided.
    """
    # Validate here so the failure points at this call rather than surfacing
    # from arithmetic inside the helper.
    if inv_theta_log_scale is None:
        raise ValueError("inv_theta_log_scale must be provided for YaRN correction range")
    bounds = [
        yarn_find_correction_dim(
            rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
        )
        for rot in (low_rot, high_rot)
    ]
    low, high = bounds
    return tir.max(low, 0), tir.min(high, d - 1)
203207
204208
205209def rope_freq_yarn (
206210 s : tir .Var ,
207211 d : tir .Var ,
208212 d_range : int ,
209- theta : float ,
213+ theta : Union [ float , tir . PrimExpr ] ,
210214 dtype : str ,
211215 original_max_position_embeddings : int ,
212216 scaling_factor : float ,
213217 beta_fast : int ,
214218 beta_slow : int ,
219+ inv_theta_log_scale : Optional [Union [float , tir .PrimExpr ]] = None ,
215220): # pylint: disable=too-many-arguments, too-many-locals
216221 """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
217222
@@ -221,7 +226,11 @@ def rope_freq_yarn(
221226 freq_inter = tir .const (1 , "float32" ) / (scaling_factor * freq_power )
222227
223228 low , high = yarn_find_correction_range (
224- beta_fast , beta_slow , d_range , theta , original_max_position_embeddings
229+ beta_fast ,
230+ beta_slow ,
231+ d_range ,
232+ original_max_position_embeddings ,
233+ inv_theta_log_scale = inv_theta_log_scale ,
225234 )
226235 high = tir .if_then_else (low == high , high + 0.001 , high )
227236 inv_freq_mask = tir .const (1 , "float32" ) - tir .max (
@@ -266,12 +275,15 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
266275 original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ],
267276 )
268277 if rope_scaling ["rope_type" ] == "yarn" :
278+ inv_theta_log_scale = rope_scaling .get ("inv_theta_log_scale" )
279+ assert inv_theta_log_scale is not None , "inv_theta_log_scale must be precomputed for YaRN"
269280 return partial (
270281 rope_freq_yarn ,
271282 original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ],
272283 scaling_factor = rope_scaling ["factor" ],
273284 beta_fast = rope_scaling ["beta_fast" ],
274285 beta_slow = rope_scaling ["beta_slow" ],
286+ inv_theta_log_scale = inv_theta_log_scale ,
275287 )
276288 raise ValueError (f'Unsupported RoPE scaling type: { rope_scaling ["rope_type" ]} ' )
277289
0 commit comments