
Commit 4747a92

[Python][Relax] Fix YaRN correction dim calculation (#18661)
Precompute `inv_theta_log_scale`
1 parent: 4b1bd6d

2 files changed: 40 additions & 9 deletions
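In short: `yarn_find_correction_dim` used to divide by `2 * math.log(theta)` on every call. This commit precomputes the reciprocal once as `inv_theta_log_scale = 1 / (2 * log(theta))` in the KV-cache constructors, stores it in the `rope_scaling` dict, and multiplies by it inside the frequency computation. Judging by the widened annotation `theta: Union[float, tir.PrimExpr]` in the diff, this also keeps `math.log` from being applied to a value that may now be a `tir.PrimExpr` rather than a Python float (an inference from the type changes; the commit message itself only says "precompute").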


`python/tvm/relax/frontend/nn/llm/kv_cache.py`

Lines changed: 19 additions & 0 deletions
```diff
@@ -297,6 +297,23 @@ def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
         # pylint: enable=protected-access
 
 
+def _prepare_yarn_rope_scaling(
+    rope_scaling: Optional[Dict[str, Any]],
+    rope_theta: Optional[float],
+) -> Optional[Dict[str, Any]]:
+    """Ensure Yarn-specific scaling configs include the theta metadata."""
+    if rope_scaling is None:
+        return None
+    if rope_scaling.get("rope_type") != "yarn":
+        return rope_scaling
+
+    rope_scaling_updated = dict(rope_scaling)
+    if "inv_theta_log_scale" not in rope_scaling_updated and rope_theta is not None:
+        theta_value = float(rope_theta)
+        rope_scaling_updated["inv_theta_log_scale"] = 1.0 / (2 * math.log(theta_value))
+    return rope_scaling_updated
+
+
 class FlashInferPagedKVCache(PagedKVCache):  # pylint: disable=too-few-public-methods
     """Paged KV cache using FlashInfer (CUDA) kernels."""
 
```

```diff
@@ -372,6 +389,7 @@ def __init__(  # pylint: disable=too-many-locals
             Whether to enable disaggregation in the KV cache.
         """
         assert rope_mode != RopeMode.INLINE, "FlashInfer RoPE does not support inline mode."
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
 
         attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
         if attn_kind_single == "mha_sliding":
@@ -561,6 +579,7 @@ def __init__(  # pylint: disable=too-many-locals
         target : Target
             The target to build the model to.
         """
+        rope_scaling = _prepare_yarn_rope_scaling(rope_scaling, rope_theta)
         attn_kind_single = attn_kind[0] if isinstance(attn_kind, List) else attn_kind
         if attn_kind_single == "mha_sliding":
             attn_kind_single = "mha"
```
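For illustration, here is a minimal sketch of what the new helper attaches to a YaRN config. The config values below are invented; only the key names and the formula come from the diff:

```python
import math

# Hypothetical YaRN rope_scaling config (illustrative values only).
rope_theta = 10000.0
rope_scaling = {
    "rope_type": "yarn",
    "factor": 4.0,
    "beta_fast": 32,
    "beta_slow": 1,
    "original_max_position_embeddings": 4096,
}

# _prepare_yarn_rope_scaling(rope_scaling, rope_theta) returns a copy of the
# dict with the reciprocal precomputed; the copy is equivalent to:
rope_scaling_updated = dict(rope_scaling)
rope_scaling_updated["inv_theta_log_scale"] = 1.0 / (2 * math.log(rope_theta))
print(rope_scaling_updated["inv_theta_log_scale"])  # ~0.054287
```

Doing this once in the cache constructors means the TIR kernels built downstream only ever see the ready-made multiplier.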

`python/tvm/relax/frontend/nn/llm/position_embedding.py`

Lines changed: 21 additions & 9 deletions
```diff
@@ -19,7 +19,7 @@
 
 import math
 from functools import partial
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 
 from tvm import tir
 from tvm.relax.frontend.nn import Tensor, op
```
```diff
@@ -180,38 +180,43 @@ def rope_freq_longrope(  # pylint: disable=too-many-arguments
 def yarn_find_correction_dim(
     num_rotations: int,
     d: tir.Var,
-    theta: float,
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Inverse dim formula to find dim based on number of rotations"""
-    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
-        2 * math.log(theta)
+    return (
+        d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale
     )
 
 
 def yarn_find_correction_range(
     low_rot: int,
     high_rot: int,
     d: tir.Var,
-    theta: float,
     max_position_embeddings: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):
     """Find the correction range based on the number of rotations"""
-    low = yarn_find_correction_dim(low_rot, d, theta, max_position_embeddings)
-    high = yarn_find_correction_dim(high_rot, d, theta, max_position_embeddings)
+    low = yarn_find_correction_dim(
+        low_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
+    high = yarn_find_correction_dim(
+        high_rot, d, max_position_embeddings, inv_theta_log_scale=inv_theta_log_scale
+    )
     return tir.max(low, 0), tir.min(high, d - 1)
 
 
 def rope_freq_yarn(
     s: tir.Var,
     d: tir.Var,
     d_range: int,
-    theta: float,
+    theta: Union[float, tir.PrimExpr],
     dtype: str,
     original_max_position_embeddings: int,
     scaling_factor: float,
     beta_fast: int,
     beta_slow: int,
+    inv_theta_log_scale: Optional[Union[float, tir.PrimExpr]] = None,
 ):  # pylint: disable=too-many-arguments, too-many-locals
     """Compute the inverse frequency of RoPE for yarn RoPE scaling."""
 
```
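For context (this is the standard YaRN derivation, not spelled out in the diff): with `d` rotary dimensions in total (`d_range` at the call site), the pair at index $i$ completes roughly $r(i) = L / \big(2\pi\,\theta^{2i/d}\big)$ rotations over a context of length $L$. `yarn_find_correction_dim` inverts that map to find the index at which a given rotation count occurs:

$$
i(r) = \frac{d\,\ln\!\big(L/(2\pi r)\big)}{2\ln\theta}
     = d\,\ln\!\big(L/(2\pi r)\big)\cdot\frac{1}{2\ln\theta},
$$

and the rightmost factor is exactly the precomputed `inv_theta_log_scale`.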

```diff
@@ -221,7 +226,11 @@ def rope_freq_yarn(
     freq_inter = tir.const(1, "float32") / (scaling_factor * freq_power)
 
     low, high = yarn_find_correction_range(
-        beta_fast, beta_slow, d_range, theta, original_max_position_embeddings
+        beta_fast,
+        beta_slow,
+        d_range,
+        original_max_position_embeddings,
+        inv_theta_log_scale=inv_theta_log_scale,
     )
     high = tir.if_then_else(low == high, high + 0.001, high)
     inv_freq_mask = tir.const(1, "float32") - tir.max(
```
```diff
@@ -266,12 +275,15 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable:
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
         )
     if rope_scaling["rope_type"] == "yarn":
+        inv_theta_log_scale = rope_scaling.get("inv_theta_log_scale")
+        assert inv_theta_log_scale is not None, "inv_theta_log_scale must be precomputed for YaRN"
         return partial(
             rope_freq_yarn,
             original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
             scaling_factor=rope_scaling["factor"],
             beta_fast=rope_scaling["beta_fast"],
             beta_slow=rope_scaling["beta_slow"],
+            inv_theta_log_scale=inv_theta_log_scale,
         )
     raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}')
```
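A quick numeric sanity check of the rewrite, using plain Python floats in place of `tir` expressions (the input values are assumed, not from the commit):

```python
import math

theta = 10000.0
inv_theta_log_scale = 1.0 / (2 * math.log(theta))  # the precomputed multiplier

def correction_dim_old(num_rotations, d, max_position_embeddings):
    # Formula on the removed lines: divide by 2*ln(theta) on every call.
    return (d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(theta)
    )

def correction_dim_new(num_rotations, d, max_position_embeddings):
    # Formula on the added lines: multiply by the precomputed reciprocal.
    return d * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) * inv_theta_log_scale

# The two forms agree to floating-point precision for representative YaRN inputs.
for rot in (1, 32):
    assert math.isclose(
        correction_dim_old(rot, 128, 4096), correction_dim_new(rot, 128, 4096)
    )
```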
