@@ -219,7 +219,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
             is_tp=True,
         )
 
-        hidden_activation = config.hidden_activation
+        hidden_activation = getattr(config, 'hidden_activation', None)
         if hidden_activation is None:
             hidden_activation = 'gelu_pytorch_tanh'
         assert hidden_activation == 'gelu_pytorch_tanh'
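The first hunk only swaps the direct attribute read for getattr with a None default, so configs that never define hidden_activation take the same 'gelu_pytorch_tanh' fallback that the existing if/assert already applied to an explicit None, instead of raising AttributeError. A minimal sketch of that fallback, assuming a bare SimpleNamespace as a stand-in config (resolve_hidden_activation is a hypothetical helper, not part of the patch):

from types import SimpleNamespace

def resolve_hidden_activation(config):
    # mirrors the patched lines: missing attribute -> None -> default activation
    hidden_activation = getattr(config, 'hidden_activation', None)
    if hidden_activation is None:
        hidden_activation = 'gelu_pytorch_tanh'
    assert hidden_activation == 'gelu_pytorch_tanh'
    return hidden_activation

print(resolve_hidden_activation(SimpleNamespace()))  # missing attribute -> 'gelu_pytorch_tanh'
print(resolve_hidden_activation(SimpleNamespace(hidden_activation=None)))  # explicit None -> same fallback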
@@ -381,16 +381,47 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device:
         self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device)
 
         # build rotary embedding
-        self.rotary_emb = build_rotary_embedding_from_config(config)
+        self.build_rope_emb(config)
 
-        if self.model_type == 'gemma3_text':
-            rope_dim = config.head_dim
-            rope_max_pos_emb = config.max_position_embeddings
+    def build_rope_emb(self, config: PretrainedConfig):
+        rope_dim = config.head_dim
+        rope_max_pos_emb = config.max_position_embeddings
+
+        if self.model_type != 'gemma3_text':
+            self.rotary_emb = build_rotary_embedding_from_config(config)
+            return
+
+        # for gemma3
+        if hasattr(config, 'rope_local_base_freq'):
             rope_base = config.rope_local_base_freq
+            self.rotary_emb = build_rotary_embedding_from_config(config)
+
+            if self.model_type == 'gemma3_text':
+                self.rotary_emb_local = build_rotary_embedding(
+                    rope_dim,
+                    rope_max_pos_emb,
+                    rope_base,
+                    emb_type=RopeType.Default,
+                )
+        else:
+            # for transformers>=5
+            rope_dim = config.head_dim
+            from lmdeploy.pytorch.nn.rotary_embedding import get_rope_parameters
+            rope_parameters = get_rope_parameters(config)
+            full_attention = rope_parameters['full_attention']
+            sliding_attention = rope_parameters['sliding_attention']
+            # note that emb type has been fixed.
+            self.rotary_emb = build_rotary_embedding(
+                rope_dim,
+                rope_max_pos_emb,
+                base=full_attention['rope_theta'],
+                scaling_factor=full_attention['factor'],
+                emb_type=RopeType.LinearScaling,
+            )
             self.rotary_emb_local = build_rotary_embedding(
                 rope_dim,
                 rope_max_pos_emb,
-                rope_base,
+                base=sliding_attention['rope_theta'],
                 emb_type=RopeType.Default,
             )
 
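The second hunk moves rotary-embedding construction into a build_rope_emb helper: non-gemma3_text models keep the config-driven path, gemma3_text configs that still expose rope_local_base_freq keep the old global + local pair, and otherwise the parameters are read per attention type via get_rope_parameters for transformers>=5. A standalone sketch of that dispatch, assuming hypothetical SimpleNamespace configs and illustrative theta/factor values; the 'full_attention'/'sliding_attention' key names mirror the diff, and pick_rope_source is not part of the patch:

from types import SimpleNamespace

def pick_rope_source(config):
    # Which parameter source a gemma3_text config would use in build_rope_emb.
    if hasattr(config, 'rope_local_base_freq'):
        # legacy config: local base frequency stored directly on the config
        return ('local_base_freq', config.rope_local_base_freq)
    # transformers>=5 style: per-attention-type parameters
    # (the patch obtains these via get_rope_parameters(config))
    params = config.rope_parameters
    return ('rope_parameters',
            params['full_attention']['rope_theta'],
            params['sliding_attention']['rope_theta'])

legacy_cfg = SimpleNamespace(rope_local_base_freq=10000.0)
v5_cfg = SimpleNamespace(rope_parameters={
    'full_attention': {'rope_theta': 1000000.0, 'factor': 8.0},
    'sliding_attention': {'rope_theta': 10000.0},
})
print(pick_rope_source(legacy_cfg))  # ('local_base_freq', 10000.0)
print(pick_rope_source(v5_cfg))      # ('rope_parameters', 1000000.0, 10000.0)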