@@ -289,7 +289,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
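Why the added "factor" guard matters: a rope_scaling config that names a rope type but carries no "factor" entry previously entered this branch and raised KeyError at self.rope_scaling["factor"]; with the guard, such configs simply skip the YaRN mscale adjustment. A minimal standalone sketch of the guarded branch, using a hypothetical config dict and the usual YaRN mscale formula, which is assumed (not verified here) to match the class helper of the same name:

    import math

    def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
        # standard YaRN mscale; assumed to mirror the class helper used above
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    rope_scaling = {"rope_type": "dynamic"}  # hypothetical config with no "factor" key
    if rope_scaling and "factor" in rope_scaling:
        mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
        scaling_factor = rope_scaling["factor"]
        mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
    else:
        # the old `if rope_scaling:` check would have reached rope_scaling["factor"]
        # here and raised KeyError for this kind of config
        pass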
@@ -344,8 +344,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """

@@ -363,7 +361,7 @@ def forward(
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

         key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)

         compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

@@ -400,7 +398,7 @@ def forward(
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
-            fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
+            fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
             fmha_out = fmha_out_prefill

         if need_do_decode:  # max_dec_len_this_time
@@ -617,7 +615,7 @@ def __init__(
         # self.buffer = paddle.zeros([2048 * 2048], dtype=paddle.uint8)

     def forward(
-        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, positions, rotary_emb
+        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, rotary_emb
     ) -> paddle.Tensor:
         self.indexer_cache = forward_meta.caches[2 * self.layer_id + 1]

@@ -629,7 +627,7 @@ def forward(
         k, _ = self.k_norm(k)
         k_pe, k_nope = paddle.split(k, [self.rope_dim, self.index_head_dim - self.rope_dim], axis=-1)

-        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
+        q_pe, k_pe = rotary_emb(forward_meta.position_ids, q_pe, k_pe.unsqueeze(1))
         q_pe = q_pe.reshape(-1, self.index_n_heads, self.rope_dim)
         k_pe = k_pe.reshape(-1, 1, self.rope_dim)

@@ -853,7 +851,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -926,8 +924,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
@@ -940,15 +936,13 @@ def forward(
         query = self.q_a_layernorm(query)[0]

         # DSA indexer
-        indexer_top_k = self.indexer(
-            forward_meta, hidden_states, query, position_ids, rotary_emb=self.indexer_rotary_emb
-        )
+        indexer_top_k = self.indexer(forward_meta, hidden_states, query, rotary_emb=self.indexer_rotary_emb)

         query = self.q_b_proj(query)
         query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)
         q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]).contiguous(), proj_type="k")
         q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)

@@ -1044,16 +1038,14 @@ def forward(
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
         residual: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         if hidden_states.shape[0] > 0:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual_input=residual, forward_meta=forward_meta
             )

-            hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
+            hidden_states = self.self_attn(forward_meta, hidden_states)

             hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         else:
@@ -1108,8 +1100,6 @@ def forward(
         self,
         ids_remove_padding: paddle.Tensor,
        forward_meta: ForwardMeta,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
@@ -1120,8 +1110,6 @@ def forward(
                 forward_meta,
                 hidden_states,
                 residual,
-                position_ids,
-                mask_encoder_batch,
             )
         out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]

@@ -1297,12 +1285,10 @@ def forward(
         forward_meta: ForwardMeta,
     ):
         ids_remove_padding = inputs["ids_remove_padding"]
-        forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
         hidden_states = self.model(
             ids_remove_padding=ids_remove_padding,
             forward_meta=forward_meta,
-            position_ids=forward_meta.position_ids,
-            mask_encoder_batch=mask_encoder_batch,
         )
         return hidden_states

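The plumbing change in one place: pre_process still returns the (position_ids, mask_encoder_batch) pair, but both are now stored on ForwardMeta, so the decoder layers, the MLA attention, and the DSA indexer read them from forward_meta instead of receiving them as extra positional arguments. A minimal, self-contained sketch of the new calling convention (ForwardMetaSketch, the pre_process stub, and attention_forward below are illustrative stand-ins, not the real FastDeploy APIs):

    from dataclasses import dataclass
    from typing import Any, Optional

    @dataclass
    class ForwardMetaSketch:
        # hypothetical stand-in for ForwardMeta; only the two fields touched here are shown
        position_ids: Optional[Any] = None
        mask_encoder_batch: Optional[Any] = None

    def pre_process(meta: ForwardMetaSketch):
        # placeholder for the real self.pre_process(forward_meta)
        return "position_ids", "mask_encoder_batch"

    def attention_forward(meta: ForwardMetaSketch, hidden_states):
        # inner modules read the shared fields instead of taking extra arguments
        assert meta.position_ids is not None and meta.mask_encoder_batch is not None
        return hidden_states

    def model_forward(hidden_states, meta: ForwardMetaSketch):
        # attach both tensors once at the top-level forward ...
        meta.position_ids, meta.mask_encoder_batch = pre_process(meta)
        # ... and every layer below only needs forward_meta
        return attention_forward(meta, hidden_states)

    print(model_forward("hidden_states", ForwardMetaSketch()))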
@@ -1353,3 +1339,29 @@ class DeepSeekV32PretrainedModel(DeepSeekV3PretrainedModel):
     @classmethod
     def arch_name(self):
         return "DeepseekV32ForCausalLM"
+
+
+@ModelRegistry.register_model_class(
+    architecture="Glm4MoeLiteForCausalLM",
+    module_name="deepseek_v3",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
+class Glm4MoeLiteForCausalLM(DeepseekV3ForCausalLM):
+    """
+    Glm4MoeLiteForCausalLM
+    """
+
+    @classmethod
+    def name(cls):
+        return "Glm4MoeLiteForCausalLM"
+
+
+class Glm4MoeLitePretrainedModel(DeepSeekV3PretrainedModel):
+    """
+    Glm4MoeLite
+    """
+
+    @classmethod
+    def arch_name(self):
+        return "Glm4MoeLiteForCausalLM"