@@ -268,7 +268,7 @@ def forward(
268268 return query , key
269269
270270
271- class GptOssScalingRotaryEmbedding :
271+ class YarnScalingRotaryEmbedding :
272272 def __init__ (
273273 self ,
274274 rotary_dim ,
@@ -345,10 +345,29 @@ def get_rope_impl(
345345 rotary_emb_layer = QwenRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
346346 rotary_emb = rotary_emb_layer (position_ids )
347347 elif architecture .startswith ("Glm" ):
348- rotary_emb_layer = GlmRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
348+ rope_scaling = getattr (model_config , "rope_scaling" , None )
349+ if (
350+ rope_scaling is not None
351+ and isinstance (rope_scaling , dict )
352+ and rope_scaling .get ("rope_type" , rope_scaling .get ("type" , "" )) == "yarn"
353+ and "factor" in rope_scaling
354+ ):
355+ yarn_rotary_dim = int (rotary_dim * partial_rotary_factor ) if partial_rotary_factor < 1.0 else rotary_dim
356+ rotary_emb_layer = YarnScalingRotaryEmbedding (
357+ rotary_dim = yarn_rotary_dim ,
358+ base = base ,
359+ original_max_position_embeddings = rope_scaling ["original_max_position_embeddings" ],
360+ scale = rope_scaling ["factor" ],
361+ mscale = rope_scaling .get ("mscale" , 1.0 ),
362+ beta_fast = rope_scaling .get ("beta_fast" , 32 ),
363+ beta_slow = rope_scaling .get ("beta_slow" , 1 ),
364+ use_neox_rotary_style = False ,
365+ )
366+ else :
367+ rotary_emb_layer = GlmRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
349368 rotary_emb = rotary_emb_layer (position_ids )
350369 elif architecture .startswith ("GptOss" ):
351- rotary_emb_layer = GptOssScalingRotaryEmbedding (
370+ rotary_emb_layer = YarnScalingRotaryEmbedding (
352371 rotary_dim = model_config .head_dim ,
353372 base = model_config .rope_theta ,
354373 original_max_position_embeddings = model_config .rope_scaling ["original_max_position_embeddings" ],
0 commit comments