Skip to content

Commit 13eaea0

Browse files
authored
supoort glm yarn rope (#7893)
1 parent 8a4ac65 commit 13eaea0

1 file changed

Lines changed: 22 additions & 3 deletions

File tree

fastdeploy/model_executor/layers/rotary_embedding.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def forward(
268268
return query, key
269269

270270

271-
class GptOssScalingRotaryEmbedding:
271+
class YarnScalingRotaryEmbedding:
272272
def __init__(
273273
self,
274274
rotary_dim,
@@ -345,10 +345,29 @@ def get_rope_impl(
345345
rotary_emb_layer = QwenRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
346346
rotary_emb = rotary_emb_layer(position_ids)
347347
elif architecture.startswith("Glm"):
348-
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
348+
rope_scaling = getattr(model_config, "rope_scaling", None)
349+
if (
350+
rope_scaling is not None
351+
and isinstance(rope_scaling, dict)
352+
and rope_scaling.get("rope_type", rope_scaling.get("type", "")) == "yarn"
353+
and "factor" in rope_scaling
354+
):
355+
yarn_rotary_dim = int(rotary_dim * partial_rotary_factor) if partial_rotary_factor < 1.0 else rotary_dim
356+
rotary_emb_layer = YarnScalingRotaryEmbedding(
357+
rotary_dim=yarn_rotary_dim,
358+
base=base,
359+
original_max_position_embeddings=rope_scaling["original_max_position_embeddings"],
360+
scale=rope_scaling["factor"],
361+
mscale=rope_scaling.get("mscale", 1.0),
362+
beta_fast=rope_scaling.get("beta_fast", 32),
363+
beta_slow=rope_scaling.get("beta_slow", 1),
364+
use_neox_rotary_style=False,
365+
)
366+
else:
367+
rotary_emb_layer = GlmRotaryEmbedding(rotary_dim, base, partial_rotary_factor)
349368
rotary_emb = rotary_emb_layer(position_ids)
350369
elif architecture.startswith("GptOss"):
351-
rotary_emb_layer = GptOssScalingRotaryEmbedding(
370+
rotary_emb_layer = YarnScalingRotaryEmbedding(
352371
rotary_dim=model_config.head_dim,
353372
base=model_config.rope_theta,
354373
original_max_position_embeddings=model_config.rope_scaling["original_max_position_embeddings"],

0 commit comments

Comments
 (0)