
Commit 1090f8b

[Models] support GLM4.7 Flash && Ernie_MLA (#7139)

* support GLM4.7 Flash && Ernie_MLA

1 parent 5f612a3 · commit 1090f8b

3 files changed: 65 additions & 25 deletions
fastdeploy/model_executor/forward_meta.py
Lines changed: 2 additions & 0 deletions

@@ -158,7 +158,9 @@ class ForwardMeta:
     # for prefill
     exist_prefill: bool = False

+    # for mla & dsa
     position_ids: Optional[paddle.Tensor] = None
+    mask_encoder_batch: Optional[paddle.Tensor] = None

     real_bsz: int = 0
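The new `mask_encoder_batch` field rides along on `ForwardMeta` instead of being threaded through every `forward` signature, which is exactly what the deepseek_v3.py changes below remove. A minimal sketch of the pattern, with a toy metadata class and consumer (all `Toy*` names are hypothetical, not FastDeploy APIs):

from dataclasses import dataclass
from typing import Optional

import paddle


@dataclass
class ToyForwardMeta:
    # Mirrors only the two fields this commit touches.
    position_ids: Optional[paddle.Tensor] = None
    mask_encoder_batch: Optional[paddle.Tensor] = None


def toy_attention(forward_meta: ToyForwardMeta, hidden_states: paddle.Tensor) -> paddle.Tensor:
    # Consumers read per-step tensors off the metadata object instead of
    # receiving them as extra positional arguments.
    out = hidden_states
    if forward_meta.mask_encoder_batch is not None:
        out = out * forward_meta.mask_encoder_batch.cast(out.dtype)
    return out


meta = ToyForwardMeta(
    position_ids=paddle.arange(4),
    mask_encoder_batch=paddle.ones([4, 1]),
)
print(toy_attention(meta, paddle.randn([4, 8])).shape)  # [4, 8]

This keeps call signatures stable as new per-step tensors appear: only the code that fills the metadata and the consumers that need a field have to change.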

fastdeploy/model_executor/layers/attention/mla_attention_backend.py
Lines changed: 27 additions & 1 deletion
@@ -272,6 +272,16 @@ def __init__(
         self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

         self.num_heads: int = num_heads
+        self.heads_need_padding = False
+        if self.num_heads < 64 and fd_config.parallel_config.tensor_parallel_size > 1:
+            self.padding_num_heads = 64 - self.num_heads
+            self.heads_need_padding = True
+            logger.warning(
+                "MLA num_attention_heads is less than 64, force to use 64 num_heads. "
+                "current num_heads=%d, tp_size=%d",
+                self.num_heads,
+                fd_config.parallel_config.tensor_parallel_size,
+            )
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
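The padding exists because under tensor parallelism each rank keeps only `num_attention_heads / tp_size` query heads, so a small-head model such as GLM4.7 Flash can fall below the 64 heads the fused MLA kernel is built around. A hypothetical helper showing the same decision (the 64-head constant and the condition mirror this hunk; the function itself is illustrative):

def mla_padding_plan(num_heads_per_rank: int, tp_size: int, kernel_heads: int = 64):
    """Hypothetical helper: decide how many dummy query heads to add so an
    MLA kernel written for `kernel_heads` heads can still run on this rank."""
    if num_heads_per_rank < kernel_heads and tp_size > 1:
        return True, kernel_heads - num_heads_per_rank
    return False, 0


# e.g. a TP shard left with only 8 query heads -> pad 56 zero heads
print(mla_padding_plan(8, tp_size=4))    # (True, 56)
print(mla_padding_plan(128, tp_size=2))  # (False, 0)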

@@ -280,7 +290,9 @@ def __init__(
         self.qk_rope_head_dim: int = fd_config.model_config.qk_rope_head_dim
         self.qk_head_dim: int = fd_config.model_config.qk_nope_head_dim + fd_config.model_config.qk_rope_head_dim
         self.attn_softmax_scale: float = self.qk_head_dim**-0.5
-        if fd_config.model_config.rope_scaling:
+        self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
+        if self.rope_scaling and "factor" in self.rope_scaling:
+            # if fd_config.model_config.rope_scaling:
             mscale_all_dim = fd_config.model_config.rope_scaling.get("mscale_all_dim", False)  # 1.0
             scaling_factor = fd_config.model_config.rope_scaling["factor"]  # 40
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
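For context, `yarn_get_mscale` is the standard YaRN attention-magnitude term used by DeepSeek-style MLA. Its exact FastDeploy body is not shown in this diff, but it is commonly the following, folded into the softmax scale as `scale * mscale * mscale` (a sketch under that assumption; the 40/1.0 values come from the inline comments above):

import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Common YaRN magnitude term; identity when there is no context scaling.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


qk_head_dim = 128 + 64                 # qk_nope_head_dim + qk_rope_head_dim (illustrative)
attn_softmax_scale = qk_head_dim ** -0.5
m = yarn_get_mscale(40, 1.0)           # factor=40, mscale_all_dim=1.0
attn_softmax_scale = attn_softmax_scale * m * m
print(attn_softmax_scale)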
@@ -604,6 +616,10 @@ def forward_mixed(

         if int(os.getenv("USE_FLASH_MLA", "0")) == 0:
             assert self.num_heads <= 64, "paddle mla attention support failed"
+            if self.heads_need_padding:
+                q = paddle.nn.functional.pad(
+                    q, [0, (self.padding_num_heads) * (self.kv_lora_rank + self.qk_rope_head_dim)], value=0.0
+                ).contiguous()
             # multi-head latent attention computation
             fmha_out = multi_head_latent_attention(
                 q,
@@ -646,6 +662,8 @@ def forward_mixed(
                 True,  # causal
                 speculate_decoder,
             )
+            if self.heads_need_padding:
+                fmha_out = fmha_out[:, : self.num_heads * self.kv_lora_rank].contiguous()

             return fmha_out
         else:
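Because the padded query columns are zeros and the padded heads' outputs are sliced away afterwards, the real heads' results are unchanged. A standalone sketch of the pad-then-slice round trip on the flattened layout (shapes are illustrative assumptions; a random tensor stands in for the `multi_head_latent_attention` output):

import paddle

# Illustrative sizes only: 8 real heads padded up to the kernel's 64.
num_heads, padding_num_heads = 8, 56
kv_lora_rank, qk_rope_head_dim = 512, 64
token_num = 2

# q arrives flattened: [token_num, num_heads * (kv_lora_rank + qk_rope_head_dim)].
q = paddle.randn([token_num, num_heads * (kv_lora_rank + qk_rope_head_dim)])

# Zero-pad the trailing columns so the kernel sees 64 heads' worth of features.
q = paddle.nn.functional.pad(
    q, [0, padding_num_heads * (kv_lora_rank + qk_rope_head_dim)], value=0.0
)

# Stand-in for the kernel: output is [token_num, 64 * kv_lora_rank].
fmha_out = paddle.randn([token_num, 64 * kv_lora_rank])

# Only the first num_heads * kv_lora_rank columns belong to real heads.
fmha_out = fmha_out[:, : num_heads * kv_lora_rank]
print(fmha_out.shape)  # [2, 4096]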
@@ -661,6 +679,12 @@ def forward_mixed(
             tile_scheduler_metadata, num_splits = flash_mla.get_mla_metadata()
             token_num = q.shape[0]
             decoder_q.reshape_([-1, 1, self.num_heads, 576])
+            if self.heads_need_padding:
+                padded_q = paddle.zeros(
+                    [decoder_q.shape[0], decoder_q.shape[1], 64, decoder_q.shape[3]], dtype=decoder_q.dtype
+                )
+                padded_q[:, :, : self.num_heads, :] = decoder_q
+                decoder_q = padded_q

             new_cache_shape = latent_cache.shape
             assert new_cache_shape[1] == 1
@@ -679,6 +703,8 @@ def forward_mixed(
                 softmax_scale=self.attn_softmax_scale,
                 causal=True,
             )
+            if self.heads_need_padding:
+                decoder_res = decoder_res[:, :, : self.num_heads, :].contiguous()

             final_res = insert_decoder_result_back(
                 decoder_res,
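The FlashMLA branch applies the same trick on an explicit head axis: scatter the real heads into a 64-slot zero tensor before the kernel, then slice them back out of the result. A toy version (the 576 width matches the `reshape_` above; a stand-in replaces the kernel call):

import paddle

# Illustrative: per-rank heads after TP sharding, padded on the head axis.
num_heads = 8
decoder_q = paddle.randn([2, 1, num_heads, 576])

padded_q = paddle.zeros([2, 1, 64, 576], dtype=decoder_q.dtype)
padded_q[:, :, :num_heads, :] = decoder_q       # real heads first, zero heads after

decoder_res = padded_q                          # stand-in for the FlashMLA kernel output
decoder_res = decoder_res[:, :, :num_heads, :]  # drop the padded heads again
print(decoder_res.shape)  # [2, 1, 8, 576]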

fastdeploy/model_executor/models/deepseek_v3.py
Lines changed: 36 additions & 24 deletions
@@ -289,7 +289,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
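The strengthened guard matters because `self.rope_scaling["factor"]` is indexed unconditionally inside the branch: a model that ships a `rope_scaling` dict without a YaRN `factor` key would previously raise `KeyError`. Hypothetical configs showing both sides of the new check:

# Hypothetical configs, not taken from the commit:
yarn_cfg = {"rope_type": "yarn", "factor": 40, "mscale_all_dim": 1.0}
plain_cfg = {"rope_type": "default"}  # no "factor": old code would KeyError here

for rope_scaling in (yarn_cfg, plain_cfg):
    if rope_scaling and "factor" in rope_scaling:
        print("apply YaRN mscale, factor =", rope_scaling["factor"])
    else:
        print("skip: leave attn_softmax_scale unchanged")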
@@ -344,8 +344,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """

@@ -363,7 +361,7 @@ def forward(
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

         key_pe.reshape_([-1, 1, self.qk_rope_head_dim])
-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)

         compressed_kv = self.kv_a_layernorm(compressed_kv)[0]

@@ -400,7 +398,7 @@ def forward(
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
             fmha_out_prefill = fmha_out_prefill[:, :, : self.v_head_dim]
             fmha_out_prefill.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
-            fmha_out_prefill = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
+            fmha_out_prefill = fmha_out_prefill * forward_meta.mask_encoder_batch.cast(fmha_out_prefill.dtype)
             fmha_out = fmha_out_prefill

         if need_do_decode:  # max_dec_len_this_time
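This diff does not show how `mask_encoder_batch` is built, but the elementwise multiply implies a per-token 0/1 column that keeps the prefill-path output only for prefill tokens in a mixed batch. A sketch under that assumption:

import paddle

# Assumption: 1.0 marks prefill tokens, 0.0 marks decode tokens.
mask_encoder_batch = paddle.to_tensor([[1.0], [1.0], [1.0], [0.0], [0.0]])
fmha_out_prefill = paddle.randn([5, 8])  # [token_num, heads_tp * v_head_dim]

masked = fmha_out_prefill * mask_encoder_batch.cast(fmha_out_prefill.dtype)
print(float(masked[3:].abs().sum()))  # 0.0 -- decode rows contribute nothing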
@@ -617,7 +615,7 @@ def __init__(
         # self.buffer = paddle.zeros([2048 * 2048], dtype=paddle.uint8)

     def forward(
-        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, positions, rotary_emb
+        self, forward_meta: ForwardMeta, hidden_states: paddle.Tensor, qr: paddle.Tensor, rotary_emb
     ) -> paddle.Tensor:
         self.indexer_cache = forward_meta.caches[2 * self.layer_id + 1]

@@ -629,7 +627,7 @@ def forward(
         k, _ = self.k_norm(k)
         k_pe, k_nope = paddle.split(k, [self.rope_dim, self.index_head_dim - self.rope_dim], axis=-1)

-        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
+        q_pe, k_pe = rotary_emb(forward_meta.position_ids, q_pe, k_pe.unsqueeze(1))
         q_pe = q_pe.reshape(-1, self.index_n_heads, self.rope_dim)
         k_pe = k_pe.reshape(-1, 1, self.rope_dim)

@@ -853,7 +851,7 @@ def __init__(self, fd_config: FDConfig, layer_id: int, prefix: str = "") -> None
             v_head_dim=self.v_head_dim,
         )
         self.rope_scaling = getattr(fd_config.model_config, "rope_scaling", None)
-        if self.rope_scaling:
+        if self.rope_scaling and "factor" in self.rope_scaling:
             mscale_all_dim = self.rope_scaling.get("mscale_all_dim", False)
             scaling_factor = self.rope_scaling["factor"]
             mscale = self.yarn_get_mscale(scaling_factor, float(mscale_all_dim))
@@ -926,8 +924,6 @@ def forward(
         self,
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         qkv_a_out = self.qkv_a_proj_with_mqa(hidden_states)
@@ -940,15 +936,13 @@ def forward(
         query = self.q_a_layernorm(query)[0]

         # DSA indexer
-        indexer_top_k = self.indexer(
-            forward_meta, hidden_states, query, position_ids, rotary_emb=self.indexer_rotary_emb
-        )
+        indexer_top_k = self.indexer(forward_meta, hidden_states, query, rotary_emb=self.indexer_rotary_emb)

         query = self.q_b_proj(query)
         query.reshape_([-1, self.num_attention_heads_tp, self.qk_head_dim])
         query_nope, query_pe = query.split([self.qk_nope_head_dim, self.qk_rope_head_dim], axis=-1)

-        query_pe, key_pe = self.rotary_emb(position_ids, query_pe, key_pe)
+        query_pe, key_pe = self.rotary_emb(forward_meta.position_ids, query_pe, key_pe)
         q_nope_out = self.kv_b_proj_bmm(query_nope.transpose([1, 0, 2]).contiguous(), proj_type="k")
         q_input = paddle.concat([q_nope_out.transpose([1, 0, 2]).contiguous(), query_pe], axis=-1)

@@ -1044,16 +1038,14 @@ def forward(
         forward_meta: ForwardMeta,
         hidden_states: paddle.Tensor,
         residual: paddle.Tensor,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         if hidden_states.shape[0] > 0:
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual_input=residual, forward_meta=forward_meta
             )

-            hidden_states = self.self_attn(forward_meta, hidden_states, position_ids, mask_encoder_batch)
+            hidden_states = self.self_attn(forward_meta, hidden_states)

             hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         else:
@@ -1108,8 +1100,6 @@ def forward(
         self,
         ids_remove_padding: paddle.Tensor,
         forward_meta: ForwardMeta,
-        position_ids: paddle.Tensor,
-        mask_encoder_batch: paddle.Tensor,
     ):
         """ """
         hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding, forward_meta=forward_meta)
@@ -1120,8 +1110,6 @@ def forward(
                 forward_meta,
                 hidden_states,
                 residual,
-                position_ids,
-                mask_encoder_batch,
             )
         out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]

@@ -1297,12 +1285,10 @@ def forward(
         forward_meta: ForwardMeta,
     ):
         ids_remove_padding = inputs["ids_remove_padding"]
-        forward_meta.position_ids, mask_encoder_batch = self.pre_process(forward_meta)
+        forward_meta.position_ids, forward_meta.mask_encoder_batch = self.pre_process(forward_meta)
         hidden_states = self.model(
            ids_remove_padding=ids_remove_padding,
            forward_meta=forward_meta,
         )
         return hidden_states

@@ -1353,3 +1339,29 @@ class DeepSeekV32PretrainedModel(DeepSeekV3PretrainedModel):
     @classmethod
     def arch_name(self):
         return "DeepseekV32ForCausalLM"
+
+
+@ModelRegistry.register_model_class(
+    architecture="Glm4MoeLiteForCausalLM",
+    module_name="deepseek_v3",
+    category=ModelCategory.TEXT_GENERATION,
+    primary_use=ModelCategory.TEXT_GENERATION,
+)
+class Glm4MoeLiteForCausalLM(DeepseekV3ForCausalLM):
+    """
+    Glm4MoeLiteForCausalLM
+    """
+
+    @classmethod
+    def name(cls):
+        return "Glm4MoeLiteForCausalLM"
+
+
+class Glm4MoeLitePretrainedModel(DeepSeekV3PretrainedModel):
+    """
+    Glm4MoeLite
+    """
+
+    @classmethod
+    def arch_name(self):
+        return "Glm4MoeLiteForCausalLM"
