diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu index b582c862c38..5da32e89359 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cu @@ -27,8 +27,8 @@ std::vector PrefillMLAWriteCache( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::Tensor& slot_mapping, const paddle::optional& kv_signal_data, - const int max_seq_len, cudaStream_t& stream, paddle::Tensor* kv_cache) { typedef PDTraits traits_; @@ -50,7 +50,9 @@ std::vector PrefillMLAWriteCache( int grid_size = 1; GetNumBlocks<128>(pack_num, &grid_size); - prefill_absorb_cache_kernel + using CT = DataType_; + + prefill_absorb_cache_kernel <<>>( reinterpret_cast( const_cast(kv_nope.data())), @@ -58,11 +60,11 @@ std::vector PrefillMLAWriteCache( const_cast(kv_pe.data())), reinterpret_cast(kv_cache->data()), block_tables.data(), + slot_mapping.data(), batch_id_per_token.data(), cu_seqlens_q.data(), seq_lens.data(), seq_lens_decoder.data(), - max_seq_len, max_blocks_per_seq, kv_num_heads, nope_size, @@ -108,9 +110,9 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::Tensor& slot_mapping, const paddle::optional& kv_signal_data, - const std::string& cache_quant_type_str, - const int max_seq_len) { + const std::string& cache_quant_type_str) { cudaStream_t stream = kv_pe.stream(); AppendAttnMetaData meta_data; const auto& kv_nope_dims = kv_nope.dims(); @@ -137,8 +139,8 @@ std::vector PrefillMLAWriteCacheKernel( batch_id_per_token, cu_seqlens_q, block_tables, + slot_mapping, kv_signal_data, - max_seq_len, stream, const_cast(&kv_cache)); } @@ -152,8 +154,8 @@ std::vector PrefillMLAWriteCacheKernel( batch_id_per_token, cu_seqlens_q, block_tables, + slot_mapping, kv_signal_data, - max_seq_len, stream, const_cast(&kv_cache)); } @@ -171,7 +173,6 @@ std::vector DecodeMLAWriteCache( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, - const int max_seq_len, const bool speculate_decoder, cudaStream_t& stream, paddle::Tensor* kv_cache) { @@ -207,7 +208,6 @@ std::vector DecodeMLAWriteCache( cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), - max_seq_len, max_blocks_per_seq, kv_num_heads, nope_size, @@ -229,7 +229,6 @@ std::vector DecodeMLAWriteCache( cu_seqlens_q.data(), seq_lens.data(), seq_lens_encoder.data(), - max_seq_len, max_blocks_per_seq, kv_num_heads, nope_size, @@ -250,7 +249,6 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, - const int max_seq_len, const bool speculate_decoder) { cudaStream_t stream = kv_pe.stream(); AppendAttnMetaData meta_data; @@ -278,7 +276,6 @@ std::vector DecodeMLAWriteCacheKernel( batch_id_per_token, cu_seqlens_q, block_tables, - max_seq_len, speculate_decoder, stream, const_cast(&kv_cache)); @@ -293,7 +290,6 @@ std::vector DecodeMLAWriteCacheKernel( batch_id_per_token, cu_seqlens_q, block_tables, - max_seq_len, speculate_decoder, stream, const_cast(&kv_cache)); @@ -311,10 +307,11 @@ PD_BUILD_STATIC_OP(prefill_mla_write_cache) "batch_id_per_token", "cu_seqlens_q", "block_tables", + "slot_mapping", paddle::Optional("kv_signal_data")}) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - .Attrs({"cache_quant_type_str: std::string", "max_seq_len: int"}) + .Attrs({"cache_quant_type_str: std::string"}) .SetKernelFn(PD_KERNEL(PrefillMLAWriteCacheKernel)); PD_BUILD_STATIC_OP(decode_mla_write_cache) @@ -328,7 +325,5 @@ PD_BUILD_STATIC_OP(decode_mla_write_cache) "block_tables"}) .Outputs({"kv_cache_out"}) .SetInplaceMap({{"kv_cache", "kv_cache_out"}}) - .Attrs({"cache_quant_type_str: std::string", - "max_seq_len: int", - "speculate_decoder: bool"}) + .Attrs({"cache_quant_type_str: std::string", "speculate_decoder: bool"}) .SetKernelFn(PD_KERNEL(DecodeMLAWriteCacheKernel)); diff --git a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh index ec5b428bda1..2f1c26236e0 100644 --- a/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh +++ b/custom_ops/gpu_ops/append_attn/mla_cache_kernel.cuh @@ -27,7 +27,6 @@ __global__ void decode_absorb_cache_kernel( const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] - const int max_seq_len, const int max_blocks_per_seq, const int kv_num_heads, const int nope_size, @@ -98,7 +97,6 @@ __global__ void speculate_decode_absorb_cache_kernel( const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_encoder, // [bsz] - const int max_seq_len, const int max_blocks_per_seq, const int kv_num_heads, const int nope_size, @@ -168,18 +166,18 @@ __global__ void speculate_decode_absorb_cache_kernel( } } -template +template __global__ void prefill_absorb_cache_kernel( const T* __restrict__ kv_nope, // [bsz, kv_num_heads, pe_size] 512 const T* __restrict__ kv_pe, // [bsz, kv_num_heads, nope_size] 64 - T* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, + CT* __restrict__ kv_cache, // [num_blocks, kv_num_heads, block_size, // nope_size] const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq] + const int64_t* __restrict__ slot_mapping, const int* __restrict__ batch_id_per_token, const int* __restrict__ cu_seqlens_q, const int* __restrict__ seq_lens, // [bsz] const int* __restrict__ seq_lens_decoder, // [bsz] - const int max_seq_len, const int max_blocks_per_seq, const int kv_num_heads, const int nope_size, @@ -201,7 +199,8 @@ __global__ void prefill_absorb_cache_kernel( linear_index += step) { const uint32_t token_idx = linear_index / hidden_size; const uint32_t bias = linear_index % hidden_size; - const uint32_t ori_bi = batch_id_per_token[token_idx]; + const int32_t ori_bi = batch_id_per_token[token_idx]; + if (ori_bi == -1) continue; if (seq_lens[ori_bi] == 0) continue; const uint32_t ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi]; @@ -211,6 +210,14 @@ __global__ void prefill_absorb_cache_kernel( const uint32_t block_idx = block_table_now[ori_seq_id / block_size]; const uint32_t block_offset = ori_seq_id % block_size; + const int32_t block_idx1 = slot_mapping[token_idx] / block_size; + if (block_idx1 != block_idx) { + printf("block_idx1 %d != block_idx %d\n", block_idx1, block_idx); + printf("token_idx %d\n", token_idx); + printf("slot_mapping %d\n", slot_mapping[token_idx]); + asm volatile("trap;"); + } + if (bias < nope_hidden_size) { // pe const uint32_t inner_bias = bias; const uint32_t hi = inner_bias / nope_size; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index ecea1eff051..bd26769db4b 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -556,7 +556,6 @@ std::vector DecodeMLAWriteCacheKernel( const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, const std::string& cache_quant_type_str, - const int max_seq_len, const bool speculate_decoder); std::vector PrefillMLAWriteCacheKernel( @@ -568,9 +567,9 @@ std::vector PrefillMLAWriteCacheKernel( const paddle::Tensor& batch_id_per_token, const paddle::Tensor& cu_seqlens_q, const paddle::Tensor& block_tables, + const paddle::Tensor& slot_mapping, const paddle::optional& kv_signal_data, - const std::string& cache_quant_type_str, - const int max_seq_len); + const std::string& cache_quant_type_str); void FusedRotaryPositionEncoding( paddle::Tensor& query, // [num_tokens, num_heads, head_size] or diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 413f81c9cd9..bad44ff5588 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -664,7 +664,6 @@ def forward_extend( metadata.block_tables, metadata.kv_signal_data_list[layer.layer_id], "none", - getattr(forward_meta, "max_input_length", -1), ) fmha_out = self.flash_attn_func( @@ -720,7 +719,6 @@ def forward_decode( forward_meta.cu_seqlens_q, metadata.block_tables, "none", - self.max_seq_len, speculate_decoder, ) @@ -799,21 +797,23 @@ def forward_mixed( latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None + assert k_pe.shape[0] == compressed_kv.shape[0] + prefill_mla_write_cache( + compressed_kv, + k_pe, + latent_cache, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + metadata.block_tables, + forward_meta.slot_mapping, + metadata.kv_signal_data_list[layer.layer_id], + "none", + ) + # Prefill branch: k is not None if k is not None: - prefill_mla_write_cache( - compressed_kv, - k_pe, - latent_cache, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - metadata.block_tables, - metadata.kv_signal_data_list[layer.layer_id], - "none", - self.max_seq_len, - ) if self.prop.major == 10: # TODO support FA4 @@ -845,20 +845,6 @@ def forward_mixed( # Decode branch: k is None if k is None: - decode_mla_write_cache( - compressed_kv, - k_pe, - latent_cache, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_encoder, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - metadata.block_tables, - "none", - self.max_seq_len, - speculate_decoder, - ) - if int(os.getenv("USE_FLASH_MLA", "0")) == 0 and self.prop.major == 9: assert self.num_heads <= 64, "paddle mla attention support failed" if self.heads_need_padding: @@ -961,6 +947,12 @@ def forward_mixed( @staticmethod def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale): + # decoder_q = decoder_q.cast(paddle.float8_e4m3fn) + # latent_cache = latent_cache.cast(paddle.float8_e4m3fn) + + assert decoder_q.dtype == latent_cache.dtype + q_dtype = decoder_q.dtype + page_size = latent_cache.shape[2] q_num_heads = decoder_q.shape[2] assert decoder_q.shape[1:] == [1, q_num_heads, 576] @@ -1008,6 +1000,8 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft from mla_decode_fp16 import BlackwellMultiHeadLatentAttentionForwardFP16 + # from mla_decode_fp8 import BlackwellMultiHeadLatentAttentionForwardFP8 + mla = BlackwellMultiHeadLatentAttentionForwardFP16( cutlass.Float32, cutlass.Float32, @@ -1063,10 +1057,18 @@ def mla_blackwell(decoder_q, latent_cache, block_table, cache_seqlens, attn_soft stream, ) + if q_dtype == paddle.float8_e4m3fn: + paddle_output = paddle_output.cast("bfloat16") return paddle_output @staticmethod def flashmla_baseline(decoder_q, latent_cache, block_table, cache_seqlens, attn_softmax_scale): + + assert decoder_q.dtype == latent_cache.dtype + + decoder_q = decoder_q.cast("bfloat16") + latent_cache = latent_cache.cast("bfloat16") + page_size = latent_cache.shape[2] q_num_heads = decoder_q.shape[2] assert decoder_q.shape[1:] == [1, q_num_heads, 576] diff --git a/tests/operators/test_deepgemm_precision.py b/tests/operators/test_deepgemm_precision.py index 40c177342a2..2bc67a8815c 100644 --- a/tests/operators/test_deepgemm_precision.py +++ b/tests/operators/test_deepgemm_precision.py @@ -16,6 +16,7 @@ import unittest +import numpy as np import paddle paddle.enable_compat(scope={"deep_gemm"}) @@ -42,8 +43,8 @@ def __init__(self): self.num_ab_stage = 4 self.num_acc_stage = 1 self.use_2cta_instrs = True - self.cluster_shape_mnk = (2, 1, 1) if self.use_2cta_instrs else (1, 1, 1) - self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) + self.cluster_shape_mnk = (2, 1, 1) + self.cluster_shape_mn = self.cluster_shape_mnk[:2] self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.mma_tiler = (128, 128, 64) @@ -125,7 +126,7 @@ def __call__( tma_atom_b, self.cluster_layout_vmnk, ).launch( - grid=[M // self.mma_tiler[0] * self.cluster_shape_mn[0], N // self.mma_tiler[1], 1], + grid=[M // self.mma_tiler[0] * self.atom_thr_size, N // self.mma_tiler[1], 1], block=[128, 1, 1], cluster=self.cluster_shape_mnk, ) @@ -443,42 +444,50 @@ def one_invoke(self, M, N, K): baseline_out = paddle.matmul(tmp0, tmp1, False, True) deepgemm_output = paddle.zeros_like(baseline_out) - for i in range(10): - a = paddle.zeros([1024, 1024, 1024]) + 1 + test_loops = 5 + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_loops)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_loops)] + + for i in range(test_loops): + # 这行代码放在这里是为了让event的计时更准确! + # 太棒啦! + for j in range(100): + a = paddle.zeros([1024, 1024, 1024]) + 1 del a a = raw_x_scale.transpose([1, 0]).contiguous().transpose([1, 0]) b = raw_w_scale.transpose([1, 0]).contiguous().transpose([1, 0]) + start_events[i].record() + deep_gemm.fp8_gemm_nt( (raw_x, a), (raw_w, b), deepgemm_output, ) - print(baseline_out - deepgemm_output) + end_events[i].record() + + total_time = np.array([round(s.elapsed_time(e), 10) for s, e in zip(start_events, end_events)])[-1:] + flops = 2.0 * M * N * K / (1024**4) / (total_time / 1000.0) + print(total_time[0], "ms") + print(flops[0], "TFLOPs/s") + + # print(baseline_out - deepgemm_output) # assert (baseline_out - deepgemm_output).abs().max().item() < 0.1 def test_main(self): prop = paddle.device.cuda.get_device_properties() if prop.major != 10: return - # import paddle.profiler as profiler - # p = profiler.Profiler( - # targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], - # on_trace_ready=profiler.export_chrome_tracing("./profile_log"), - # ) - # p.start() - # p.step() - - # self.one_invoke(128 * 20, 2048, 4096) + self.one_invoke(4096, 4096, 4096) + self.one_invoke(4096, 2048, 7168) + self.one_invoke(4096, 65536, 1536) # self.one_invoke(128 * 20, 2048, 2048) self.two_invoke(128 * 20, 128 * 20, 64 * 4) - # p.stop() - if __name__ == "__main__": unittest.main() diff --git a/tests/operators/test_flashmla_precision.py b/tests/operators/test_flashmla_precision.py index d94898b7d32..2e129b38054 100644 --- a/tests/operators/test_flashmla_precision.py +++ b/tests/operators/test_flashmla_precision.py @@ -16,6 +16,7 @@ import unittest +import numpy as np import paddle paddle.set_default_dtype("bfloat16") @@ -30,21 +31,17 @@ def setUp(self): pass def test_flashmla(self): - # import paddle.profiler as profiler - # p = profiler.Profiler( - # targets=[profiler.ProfilerTarget.CPU, profiler.ProfilerTarget.GPU], - # on_trace_ready=profiler.export_chrome_tracing("./profile_log"), - # ) - # p.start() - # p.step() + dtype = paddle.float8_e4m3fn + dtype = paddle.bfloat16 bsz = 128 - kv_len = 1024 * 128 + kv_len = 1024 * 8 page_size = 64 - decoder_q = paddle.randn([bsz, 1, 128, 576], dtype="bfloat16") + decoder_q = paddle.randn([bsz, 1, 128, 576], dtype="bfloat16").cast(dtype) + cache_seqlens = paddle.zeros([bsz], dtype="int32") + kv_len block_tables = paddle.arange((kv_len // page_size + 1) * bsz, dtype="int32").reshape([bsz, -1]) - latent_cache = paddle.randn([bsz * block_tables.shape[1], 1, page_size, 576], dtype="bfloat16") + latent_cache = paddle.randn([bsz * block_tables.shape[1], 1, page_size, 576], dtype="bfloat16").cast(dtype) # copy from dsv3 attn_softmax_scale = 0.1352337788608801 @@ -55,10 +52,29 @@ def test_flashmla(self): prop = paddle.device.cuda.get_device_properties() if prop.major == 10: - for i in range(10): + test_loops = 5 + start_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_loops)] + end_events = [paddle.device.cuda.Event(enable_timing=True) for _ in range(test_loops)] + + for i in range(test_loops): + # 这行代码放在这里是为了让event的计时更准确! + # 太棒啦! + for _ in range(10): + a = paddle.zeros([1024, 1024, 1024]) + 1 + a = a + 2 + del a + + start_events[i].record() decoder_res = MLAAttentionBackend.mla_blackwell( decoder_q, latent_cache, block_tables, cache_seqlens, attn_softmax_scale ) + end_events[i].record() + + total_time = np.array([round(s.elapsed_time(e), 10) for s, e in zip(start_events, end_events)])[-1:] + band_width = 2 * bsz * kv_len * latent_cache.shape[-1] / (1024**4) / (total_time / 1000.0) + print(total_time[0], "ms") + print(band_width[0], "TB/s") + elif prop.major == 9: paddle.enable_compat(scope={"flash_mla"}) # Enable paddle.enable_compat before importing flash_mla try: @@ -86,7 +102,7 @@ def test_flashmla(self): softmax_scale=attn_softmax_scale, causal=True, ) - # p.stop() + max_diff = (decoder_res - baseline_out).abs().max().item() print(decoder_res - baseline_out) self.assertLessEqual(max_diff, 0.1)