-
Notifications
You must be signed in to change notification settings - Fork 752
[OPTIMIZE] remove decode_mla_write_cache in mla attention backend #7834
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
88d944f
16f8f98
1f7549b
4a7d2c2
dcbcb46
f98f1a6
3f0849e
129c510
67468be
28f0471
bb8c233
72fed03
e468263
5e5d7f0
bb51fb0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -664,7 +664,6 @@ def forward_extend( | |
| metadata.block_tables, | ||
| metadata.kv_signal_data_list[layer.layer_id], | ||
| "none", | ||
| getattr(forward_meta, "max_input_length", -1), | ||
| ) | ||
|
|
||
| fmha_out = self.flash_attn_func( | ||
|
|
@@ -720,7 +719,6 @@ def forward_decode( | |
| forward_meta.cu_seqlens_q, | ||
| metadata.block_tables, | ||
| "none", | ||
| self.max_seq_len, | ||
| speculate_decoder, | ||
| ) | ||
|
|
||
|
|
@@ -799,21 +797,23 @@ def forward_mixed( | |
|
|
||
| latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None | ||
|
|
||
| assert k_pe.shape[0] == compressed_kv.shape[0] | ||
| prefill_mla_write_cache( | ||
| compressed_kv, | ||
| k_pe, | ||
| latent_cache, | ||
| forward_meta.seq_lens_this_time, | ||
| forward_meta.seq_lens_decoder, | ||
| forward_meta.batch_id_per_token, | ||
| forward_meta.cu_seqlens_q, | ||
| metadata.block_tables, | ||
| forward_meta.slot_mapping, | ||
| metadata.kv_signal_data_list[layer.layer_id], | ||
| "none", | ||
| ) | ||
|
|
||
| # Prefill branch: k is not None | ||
| if k is not None: | ||
| prefill_mla_write_cache( | ||
| compressed_kv, | ||
| k_pe, | ||
| latent_cache, | ||
| forward_meta.seq_lens_encoder, | ||
| forward_meta.seq_lens_decoder, | ||
| forward_meta.batch_id_per_token, | ||
| forward_meta.cu_seqlens_q, | ||
| metadata.block_tables, | ||
| metadata.kv_signal_data_list[layer.layer_id], | ||
| "none", | ||
| self.max_seq_len, | ||
| ) | ||
|
|
||
| if self.prop.major == 10: | ||
| # TODO support FA4 | ||
|
|
@@ -845,20 +845,6 @@ def forward_mixed( | |
|
|
||
| # Decode branch: k is None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ❓ 疑问 请确认: |
||
| if k is None: | ||
| decode_mla_write_cache( | ||
| compressed_kv, | ||
| k_pe, | ||
| latent_cache, | ||
| forward_meta.seq_lens_decoder, | ||
| forward_meta.seq_lens_encoder, | ||
| forward_meta.batch_id_per_token, | ||
| forward_meta.cu_seqlens_q, | ||
| metadata.block_tables, | ||
| "none", | ||
| self.max_seq_len, | ||
| speculate_decoder, | ||
| ) | ||
|
|
||
| if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: | ||
| assert self.num_heads <= 64, "paddle mla attention support failed" | ||
| if self.heads_need_padding: | ||
|
|
@@ -961,6 +947,12 @@ def forward_mixed( | |
| @staticmethod | ||
| def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale): | ||
|
|
||
| # decoder_q = decoder_q.cast(paddle.float8_e4m3fn) | ||
| # latent_cache = latent_cache.cast(paddle.float8_e4m3fn) | ||
|
|
||
| assert decoder_q.dtype == latent_cache.dtype | ||
This comment was marked as outdated.
Sorry, something went wrong.
This comment was marked as outdated.
Sorry, something went wrong. |
||
| q_dtype = decoder_q.dtype | ||
|
|
||
| page_size = latent_cache.shape[2] | ||
| q_num_heads = decoder_q.shape[2] | ||
| assert decoder_q.shape[1:] == [1, q_num_heads, 576] | ||
|
|
@@ -1008,6 +1000,8 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft | |
|
|
||
| from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 | ||
|
|
||
| # from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 | ||
|
|
||
| mla = BlackwellMultiHeadLatentAttentionForwardFP16( | ||
| cutlass.Float32, | ||
| cutlass.Float32, | ||
|
|
@@ -1063,10 +1057,18 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft | |
| stream, | ||
| ) | ||
|
|
||
| if q_dtype == paddle.float8_e4m3fn: | ||
| paddle_output = paddle_output.cast("bfloat16") | ||
| return paddle_output | ||
This comment was marked as outdated.
Sorry, something went wrong. |
||
|
|
||
| @staticmethod | ||
| def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale): | ||
|
|
||
| assert decoder_q.dtype == latent_cache.dtype | ||
|
|
||
| decoder_q = decoder_q.cast("bfloat16") | ||
| latent_cache = latent_cache.cast("bfloat16") | ||
|
|
||
| page_size = latent_cache.shape[2] | ||
| q_num_heads = decoder_q.shape[2] | ||
| assert decoder_q.shape[1:] == [1, q_num_heads, 576] | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🔴 Bug CUDA kernel 中残留调试代码:
`printf` + `asm volatile("trap;")` 在生产环境会导致 GPU 崩溃。`asm volatile("trap;")` 相当于 GPU 上的 `abort()`,一旦触发将终止整个 CUDA context,导致服务不可用。这段代码明显是用于对齐验证 `slot_mapping` 与 `block_tables` 两种寻址路径是否一致的临时调试代码,不应合入主干。建议修复:
`#ifdef DEBUG_MLA_CACHE ... #endif`