diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index de73a331a38..edc366e69d7 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -31,7 +31,7 @@ struct AppendAttnMetaData { }; __forceinline__ __host__ __device__ int div_up(int a, int b) { - return (a + b - 1) / b; + return a / b + (a % b != 0); /* Rounded-up quotient computed without the a + b - 1 intermediate sum used by the removed line. NOTE(review): the two expressions can differ when a is negative; presumably callers only pass non-negative a and positive b - confirm. */ } enum PosEncMode { kNonePos, kRoPE, kAliBi }; diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 7cca2d28d3b..a4a7941a7e6 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -189,6 +189,84 @@ std::vector AppendAttentionWithOutput( const int sliding_window, const int sink_size); +/* Forward declaration; the pybind module further down in this file binds it as decoder_write_cache_with_rope. */ std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder); + +/* Forward declaration; the pybind module further down in this file binds it as decode_unified_attention. fmha_out is the only non-const tensor reference. */ std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const 
paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window); + +/* Forward declaration; the pybind module further down in this file binds it as config_for_attention. The parameters tagged Inplace are taken by non-const reference. */ void ConfigForAttention(const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor& qkv, const paddle::Tensor& key_cache, @@ -1963,4 +2041,28 @@ PYBIND11_MODULE(fastdeploy_ops, m) { m.def("per_token_group_fp8_quant", &PerTokenGroupQuantFp8, "per_token_group_quant_fp8"); + + /** + * decoder_write_cache_with_rope.cu + * decoder_write_cache_with_rope + */ + 
m.def("decoder_write_cache_with_rope", + &DecoderWriteCacheWithRoPE, + "decoder write cache with RoPE function"); + + /** + * decode_unified_attention.cu + * decode_unified_attention + */ + m.def("decode_unified_attention", + &DecodeUnifiedAttention, + "decoder append attention function"); + + /** + * config_for_attention.cu + * config_for_attention + */ + m.def("config_for_attention", + &ConfigForAttention, + "config for attention function"); } diff --git a/custom_ops/gpu_ops/decode_unified_attention.cu b/custom_ops/gpu_ops/decode_unified_attention.cu new file mode 100644 index 00000000000..257134d1e95 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention.cu @@ -0,0 +1,428 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "decode_unified_attention/decode_unified_attention_c8_impl.cuh" +#include "decode_unified_attention/decode_unified_attention_c16_impl.cuh" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +/* Trait whose two specializations below each expose a paddle::DataType constant named value (BFLOAT16 and FLOAT16). */ template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +/* Host entry point of the decode_unified_attention op. Fills an AppendAttnMetaData from the qkv, key_cache, block_tables and seq_lens_this_time shapes, then dispatches on causal, group size, head dim, block size, q tile size and qkv dtype to DecodeUnifiedC16Attention when cache_quant_type equals none, and to DecodeUnifiedC8Attention otherwise. Nothing is launched unless element 4 of set_max_lengths is positive. Returns a vector holding fmha_out; unsupported qkv dtypes reach PD_THROW. */ std::vector DecodeUnifiedAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const paddle::Tensor& set_max_lengths, + const paddle::optional& attn_mask, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& mask_offset, + const paddle::optional& sinks, + paddle::Tensor& fmha_out, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_num = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + 
meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + if (mask_offset) { + meta_data.mask_offset = mask_offset.get().data(); + } + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + const int max_kv_len_this_time = set_max_lengths.data()[5]; /* elements 4 and 5 of set_max_lengths are dereferenced directly in host code here; presumably the tensor lives in CPU memory - confirm against the caller */ + + auto stream = qkv.stream(); + bool is_fp8 = + cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8"; + bool is_dynamic_cfp8 = cache_quant_type == "block_wise_fp8"; + bool is_c16 = cache_quant_type == "none"; + + if (max_just_dec_len_this_time > 0) { + if (is_c16) { + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, max_tokens_per_batch, Q_TILE_SIZE, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC16Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + 
attn_mask, + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})}) + } else { /* every cache_quant_type other than none is routed to the C8 kernel; is_dynamic_cfp8 and is_fp8 select its compile-time variants */ + DISPATCH_CAUSAL( + causal, + CAUSAL, + {DISPATCH_GQA_GROUP_SIZE( + group_size, + GROUP_SIZE, + {DISPATCH_HEAD_DIM( + meta_data.head_dims, + HEAD_DIM, + {DISPATCH_BLOCK_SIZE( + meta_data.block_size, + BLOCK_SIZE, + {DISPATCH_Q_TILE_SIZE( + group_size, + max_tokens_per_batch, + Q_TILE_SIZE, + {DISPATCH_DyCfp8( + is_dynamic_cfp8, + IsDynamicC8, + {DISPATCH_IS_FP8(is_fp8, IsFP8, { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), /* NOTE(review): the optional scale tensors are dereferenced with get() and no presence check is visible in this function; presumably callers always supply the pair matching cache_quant_type - confirm */ + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + case paddle::DataType::FLOAT16: { + DecodeUnifiedC8Attention( + meta_data, + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + attn_mask, + cache_quant_type == "block_wise_fp8" + ? cache_k_quant_scales.get() + : cache_k_dequant_scales.get(), + cache_quant_type == "block_wise_fp8" + ? 
cache_v_quant_scales.get() + : cache_v_dequant_scales.get(), + sinks, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + max_input_length, + max_kv_len_this_time, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + stream, + &fmha_out, + sliding_window); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are " + "supported. "); + } + })})})})})})}) + } + } + return {fmha_out}; +} + +/* Shape inference for the op: the single output takes the shape passed in for fmha_out; every other argument is ignored. */ std::vector> DecodeUnifiedAttentionInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& tmp_workspace_shape, + const std::vector& tmp_m_shape, + const std::vector& tmp_d_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& block_indices_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& attn_mask_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& mask_offset_shape, + const paddle::optional>& sinks_shape, + const std::vector& fmha_out_shape, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { + return {fmha_out_shape}; +} + +std::vector 
DecodeUnifiedAttentionInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& tmp_workspace_dtype, + const paddle::DataType& tmp_m_dtype, + const paddle::DataType& tmp_d_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& block_indices_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& attn_mask_dtype, + const paddle::optional& cache_k_quant_scales_dtype, + const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& mask_offset_dtype, + const paddle::optional& sinks_dtype, + const paddle::DataType& fmha_out_dtype, + const std::string& cache_quant_type, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + const bool causal, + const int sliding_window) { /* dtype inference: the single output takes the dtype passed in for fmha_out; every other argument is ignored */ + return {fmha_out_dtype}; +} + +/* Static-op registration: the fmha_out input is mapped in place onto the fmha_out_out output, and the attribute list mirrors the trailing scalar parameters of DecodeUnifiedAttention. */ PD_BUILD_STATIC_OP(decode_unified_attention) + .Inputs({"qkv", + "key_cache", + "value_cache", + "tmp_workspace", + "tmp_m", + "tmp_d", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "block_indices", + "num_blocks", + "chunk_size", + "set_max_lengths", + paddle::Optional("attn_mask"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + 
paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("mask_offset"), + paddle::Optional("sinks"), + "fmha_out"}) + .Outputs({"fmha_out_out"}) + .SetInplaceMap({{"fmha_out", "fmha_out_out"}}) + .Attrs({ + "cache_quant_type: std::string", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "max_tokens_per_batch: int", + "causal: bool", + "sliding_window: int", + }) + .SetKernelFn(PD_KERNEL(DecodeUnifiedAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(DecodeUnifiedAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecodeUnifiedAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh new file mode 100644 index 00000000000..ee74570e5d8 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/attention_func.cuh @@ -0,0 +1,1231 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include "mma_tensor_op.cuh" +#include "utils.cuh" + +/* Resets the per-thread attention state: zeroes every o_frag register, sets each running maximum m to a large negative sentinel (-5e4 or -3.0e+30, chosen by a compile-time type test) and each denominator d to 1. */ template +__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = 0.f; + } + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if constexpr (std::is_same::value) { + m[fx][j] = -5e4f; + } else if constexpr (std::is_same::value) { + m[fx][j] = -3.0e+30f; + } + d[fx][j] = 1.f; + } + } +} + +/* Copies chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE block-table entries from block_table_chunk_start into block_table_smem; each pass handles 128 entries, with every thread writing index wid * kWarpSize + tid + i * 128 when it is in range. No synchronization is issued here. */ template +__device__ __forceinline__ void load_block_table_per_chunk( + const int32_t* block_table_chunk_start, + int32_t* block_table_smem, + uint32_t chunk_start, + uint32_t chunk_end, + uint32_t tid, + uint32_t wid) { + uint32_t len = chunk_end / BLOCK_SIZE - chunk_start / BLOCK_SIZE; + for (uint32_t i = 0; i < div_up(len, 128); i++) { + uint32_t offset = wid * kWarpSize + tid + i * 128; + if (offset < len) { + block_table_smem[offset] = block_table_chunk_start[offset]; + } + } +} + +// load q from global memory to shared memory +template +__device__ __forceinline__ void load_q_global_smem_multi_warps( + T* q_ptr_base, + smem_t* q_smem, + uint32_t q_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] + smem_t::get_permuted_offset(ty * 4 + tx / 8, + tx % 8); // 4 * 64 + + const uint32_t tx_offset = tx / 8; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; +#pragma unroll + const int j = ty; /* NOTE(review): the unroll pragma just above is followed by this declaration rather than a loop; it looks like a leftover from a loop over j that was replaced by j = ty - confirm and drop it */ + const uint32_t offset_now = 
base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + q_smem->load_128b_async( + q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); + q_smem_offset_w = + q_smem->advance_offset_by_column<8>(q_smem_offset_w, fyo); + q_ptr += 8 * num_elems_per_128b(); + } + q_smem_offset_w = + q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - + 2 * num_frags_y; + } +} + +/* Multiplies the q tile addressed through q_smem->base by sm_scale in place; each pass covers 1024 elements, with every thread loading, scaling and storing eight values at offset i * 1024 + ty * 256 + tx * 8. */ template +__device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( + smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + const float sm_scale) { + constexpr int vec_size = 16 / sizeof(T); + using LoadT = AlignedVector; + LoadT tmp_vec; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + +#pragma unroll + for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { + const int offset = i * 1024 + ty * 256 + tx * 8; + Load(reinterpret_cast(q_smem->base) + offset, &tmp_vec); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + tmp_vec[reg_id] *= sm_scale; + } + Store(tmp_vec, reinterpret_cast(q_smem->base) + offset); + } +} + +/* Issues 128-bit async copies of K data from cache_k into smem. For each outer step the block id is read from block_table_now with __ldg (negative ids are clamped to 0) and the source pointer is cache_k + block_id * kv_n_stride + const_k_offset. The copies are unpredicated (the predicate argument is the literal true); *smem_offset is moved back by a fixed amount before returning. */ template +__device__ __forceinline__ void produce_k_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_k, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_k_offset) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = + head_dim / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = 
threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_k_now = cache_k + block_id * kv_n_stride + const_k_offset; +#pragma unroll + for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps; + ++i) { // m num_frags_z * 16 / (num_warps * 4) +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 8; ++j) { + smem.load_128b_async(*smem_offset, cache_k_now, true); + *smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>( + *smem_offset, j); + cache_k_now += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + num_frags_y; // num_frags_y / 4 * 4 + cache_k_now += num_warps * 4 * kv_b_stride - + num_frags_y * num_elems_per_128b(); + } + } + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +/* V-side counterpart of the K producer: issues unpredicated 128-bit async copies from cache_v into smem, reading the block id from block_table_now with __ldg (negative ids clamped to 0) and advancing kv_idx by block_size per outer step; *smem_offset is moved back by a fixed amount before returning. */ template +__device__ __forceinline__ void produce_v_blockwise_c8( + smem_t smem, + uint32_t* smem_offset, + CacheT* cache_v, + const int* block_table_now, + const uint32_t kv_head_idx, + const uint32_t kv_n_stride, + const uint32_t kv_h_stride, + const uint32_t kv_d_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const uint32_t const_v_offset) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); // 8 + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b(); + +#pragma unroll + for (uint32_t kv_i = 0; kv_i < NUM_WARP_KV / 2; ++kv_i) { + int block_id = __ldg(&block_table_now[kv_idx / block_size]); + if (block_id < 0) block_id = 0; + CacheT* cache_v_now = cache_v + block_id * kv_n_stride + const_v_offset; + +#pragma unroll + for (uint32_t i = 0; i < num_frags_y * 2 / num_warps; + ++i) { // m (num_frags_y * 
16 / (num_warps * 8)) +#pragma unroll + for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) { + smem.load_128b_async(*smem_offset, cache_v_now, true); + *smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>( + *smem_offset, j); + cache_v_now += 4 * num_elems_per_128b(); + kv_idx += 4 * num_elems_per_128b(); + } + kv_idx -= 2 * num_frags_z * num_elems_per_128b(); + *smem_offset = + smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_z; // num_frags_z / 4 * 4 + cache_v_now += num_warps * 8 * kv_d_stride - + 2 * num_frags_z * num_elems_per_128b(); + } + kv_idx += block_size; + } + *smem_offset -= NUM_WARP_KV / 2 * num_frags_y * 16 * num_vecs_per_blocksize; +} + +/* Threads with flat id tid = ty * 32 + tx below block_size / 8 * 2 each issue one 128-bit async copy of scales for head kv_head_idx into kv_scale_smem at slot tid; the block id comes from block_table_now (negative ids clamped to 0) and the copy is predicated on kv_idx + tid * 8 < chunk_end. */ template +__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async( + smem_t kv_scale_smem, + const int* block_table_now, + const T* cache_kv_scale, + const uint32_t kv_idx, + const uint32_t kv_num_heads, + const uint32_t kv_head_idx, + const uint32_t chunk_end) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + const uint32_t tid = ty * 32 + tx; + // 1 warp 32 tokens + if (tid < block_size / 8 * 2) { + const uint32_t kv_idx_now = kv_idx + block_size * tid / 8; + int block_id = __ldg(&block_table_now[kv_idx_now / block_size]); + if (block_id < 0) block_id = 0; + const int kv_idx_this_thread = kv_idx + tid * 8; + const T* cache_k_scale_now = cache_kv_scale + + block_id * kv_num_heads * block_size + + kv_head_idx * block_size + tid % 8 * 8; + kv_scale_smem.load_128b_async( + tid, cache_k_scale_now, kv_idx_this_thread < chunk_end); + } +} + +/* For each fragment fz, copies two K scales from k_smem_scale (indices ty * 32 + fz * 16 + tx / 4 and that index plus 8) into cache_k_reg[fz * 2] and cache_k_reg[fz * 2 + 1]. */ template +__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg( + T* k_smem_scale, T* cache_k_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // 1 warp 32 tokens + const uint32_t row_id = tx / 4; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_k_reg[fz * 2] = k_smem_scale[scale_idx]; + cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 
8]; + } +} + +/* For each fragment fz, copies four V scales from v_smem_scale (offsets 0, 1, 8 and 9 from ty * 32 + fz * 16 + tx % 4 * 2) into cache_v_reg[fz * 4] through cache_v_reg[fz * 4 + 3]. */ template +__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg( + T* v_smem_scale, T* cache_v_reg) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + + // 1 warp 32 tokens + const uint32_t row_id = tx % 4 * 2; + for (uint32_t fz = 0; fz < num_frags_z; fz++) { + const uint32_t scale_idx = ty * 32 + fz * 16 + row_id; + cache_v_reg[fz * 4] = v_smem_scale[scale_idx]; + cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1]; + cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8]; + cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9]; + } +} + +/* Accumulates q times k into s_frag. Q fragments are loaded from q_smem with ldmatrix; K fragments are loaded from k_smem, converted with convert_c8 and multiplied by a scale (a single value, a per-column value when is_scale_channel_wise, or the per-row values in cache_k_scale when IsDynamicC8) before the mma_sync call. Both shared-memory read offsets are adjusted back at the end. */ template +__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + const T* cache_k_scale, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + head_dim / num_elems_per_128b(); + + uint32_t a_frag[num_frags_x][2][4], b_frag[4], b_frag_dq[4]; + +#pragma unroll + for (uint32_t ky = 0; ky < num_frags_y / 2; ++ky) { // k + // load q +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx][fy]); + + *q_smem_offset_r = + q_smem->advance_offset_by_row<16, num_vecs_per_head_q>( + *q_smem_offset_r); + } + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, ky * 2 + fy) - + num_frags_x * 16 * num_vecs_per_head_q; + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + // load + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head_k>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fy = 0; fy < 2; ++fy) { + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fy * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fy * 2 + 
1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + const int scale_col = (ky * 2 + fy) * 4; + b_frag_dq_T[0] *= cache_k_scale[scale_col]; + b_frag_dq_T[1] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_k_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_k_scale[scale_col]; + b_frag_dq_T[5] *= cache_k_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_k_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_k_scale[scale_col + 3]; + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[0]; + } + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4]; + } + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + if (ky == 0 && fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } else { /* NOTE(review): both branches read identically in this view; presumably they differ only in template arguments that are not visible here (for example an initialise-versus-accumulate mode) - confirm against the real file */ + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx][fy], b_frag_dq); + } + } + } + } + *k_smem_offset_r = k_smem->advance_offset_by_column<2, num_vecs_per_head_k>( + *k_smem_offset_r, ky) - + num_frags_z * 16 * num_vecs_per_head_k; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y / 2 * 2; +} + +/* Overwrites masked-out entries of s_frag with a large negative value (-5e4 or -3.0e+30, chosen by a compile-time type test). The mask source is, in priority order: the [x, y) window read from mask_offset for the query row; the causal limit combined with sliding_window when it is positive; otherwise the causal limit, chunk_end and the optional boolean attn_mask. */ template +__device__ __forceinline__ void mask_s(const bool* attn_mask, + const uint32_t qo_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint32_t attn_mask_len, + float (*s_frag)[num_frags_z][8], + const int* mask_offset = nullptr, + const int sliding_window = 0) { + const uint32_t tx = threadIdx.x; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 + + 8 * ((reg_id % 4) / 2)) / + group_size, + 
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + bool out_of_boundary; + if (mask_offset) { + const int2 mo = reinterpret_cast(mask_offset)[q_idx]; + out_of_boundary = + q_idx < qo_len ? (kv_idx >= mo.y || kv_idx < mo.x) : true; + } else if (sliding_window > 0) { + bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - + (int)qo_len - sliding_window; + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + out_of_window || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + } else { + out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end); + if (attn_mask != nullptr && kv_idx > kv_len - qo_len && + kv_idx < chunk_end && q_idx < attn_mask_len) { /* NOTE(review): kv_len - qo_len is unsigned arithmetic and wraps around if qo_len exceeds kv_len; presumably callers guarantee kv_len >= qo_len - confirm */ + const int32_t mask_idx = + q_idx * attn_mask_len + kv_idx - kv_len + qo_len; + bool mask = attn_mask[mask_idx]; + out_of_boundary |= mask; + } + } + + if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id]; + } else if constexpr (std::is_same::value) { + s_frag[fx][fz][reg_id] = + out_of_boundary ? 
-3.0e+30f : s_frag[fx][fz][reg_id]; + } + } + } + } +} + +/* Running-maximum softmax update for one tile. For each row slot (fx, j): folds the tile maximum into m, reduces it across lanes with xor-shuffles of 2 and 1, rescales d and the matching o_frag registers by exp(m_prev - m_new), then replaces the s_frag entries with exp(s - m_new). */ template +__device__ __forceinline__ void update_mdo_states( + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*m)[2], + float (*d)[2]) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t j_id = j * 2; + float m_prev = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_tmp = s_frag[fx][fz] + j_id; + float m_local = max(max(s_frag_tmp[0], s_frag_tmp[1]), + max(s_frag_tmp[4], s_frag_tmp[5])); + m[fx][j] = max(m[fx][j], m_local); + } + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x2, 32)); + m[fx][j] = max(m[fx][j], __shfl_xor_sync(-1, m[fx][j], 0x1, 32)); + float o_scale = expf(m_prev - m[fx][j]); + d[fx][j] *= o_scale; + float2 fp2_scale = make_float2(o_scale, o_scale); +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_frag_ptr = reinterpret_cast(o_frag[fx][fy] + j_id); + o_frag_ptr[0] = fast_float2_mul(o_frag_ptr[0], fp2_scale); + o_frag_ptr[2] = fast_float2_mul(o_frag_ptr[2], fp2_scale); + } + float tmp_m = m[fx][j]; +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float* s_frag_ptr = s_frag[fx][fz] + j_id; + s_frag_ptr[0] = __expf(s_frag_ptr[0] - tmp_m); + s_frag_ptr[1] = __expf(s_frag_ptr[1] - tmp_m); + s_frag_ptr[4] = __expf(s_frag_ptr[4] - tmp_m); + s_frag_ptr[5] = __expf(s_frag_ptr[5] - tmp_m); + } + } + } +} + +/* Casts s_frag to T, adds its row sums into d via rowsum_f16f16f32, then accumulates s times v into o_frag. V fragments are loaded from v_smem with ldmatrix, converted with convert_c8 and multiplied by a scale (a single value, a per-fragment-row value when is_scale_channel_wise, or per-column values from cache_v_scale when IsDynamicC8) before the mma_sync call. */ template +__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2], + T* cache_v_scale) { + constexpr uint32_t num_vecs_per_blocksize = + block_size / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; + uint32_t b_frag[4], b_frag_dq[4]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz 
= 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t kz = 0; kz < num_frags_z / 2; ++kz) { // k +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + v_smem->ldmatrix_m8n8x4(*v_smem_offset_r, b_frag); + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_blocksize>( + *v_smem_offset_r); +#pragma unroll + for (uint32_t fz = 0; fz < 2; ++fz) { + // dequant b_frag -> b_frag_dq + T* b_frag_dq_T = reinterpret_cast(b_frag_dq); + convert_c8(b_frag_dq_T, b_frag[fz * 2]); + convert_c8(b_frag_dq_T + 4, b_frag[fz * 2 + 1]); + // scale zp + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2]; + } + } else { +#pragma unroll + for (uint32_t b_i = 0; b_i < 8; ++b_i) { + b_frag_dq_T[b_i] *= cache_v_scale[0]; + } + } + } else { + const int scale_col = (kz * 2 + fz) * 4; + b_frag_dq_T[0] *= cache_v_scale[scale_col]; + b_frag_dq_T[1] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[2] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[3] *= cache_v_scale[scale_col + 3]; + b_frag_dq_T[4] *= cache_v_scale[scale_col]; + b_frag_dq_T[5] *= cache_v_scale[scale_col + 1]; + b_frag_dq_T[6] *= cache_v_scale[scale_col + 2]; + b_frag_dq_T[7] *= cache_v_scale[scale_col + 3]; + } +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16 + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], + (uint32_t*)(s_frag_f16[fx][kz * 2 + fz]), + b_frag_dq); + } + } + } + *v_smem_offset_r -= num_frags_y * 16 * num_vecs_per_blocksize; + } +} + +/* Merges the partial (m, d, o_frag) states of four warps through md_smem in four phases separated by two __syncthreads calls: publish m and d, compute the merged m and d and pre-scale the local o_frag, publish the scaled o_frag, then sum the four copies. When normalize is true the pre-scale also divides by the merged d. On return every warp holds the merged m, d and o_frag. The warp count of 4 is hard-coded in the loop bounds and in kOMemFloats. */ template +__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8], + float* md_smem, + float (*m)[2], + 
float (*d)[2], + const uint32_t wid, + const uint32_t tid, + const bool normalize = false) { + // Padded row stride (33 instead of 32) to avoid cross-row bank conflicts. + constexpr uint32_t kRowStride = 33; + // o_smem row stride in floats: kRowStride * 8 = 264 + constexpr uint32_t kORowStride = kRowStride * 8; + // md_smem base offset: after all o_smem data + // NUM_WARPS(4) * num_frags_x * num_frags_y * kORowStride floats + constexpr uint32_t kOMemFloats = 4 * num_frags_x * num_frags_y * kORowStride; + float2* smem_md = reinterpret_cast(md_smem + kOMemFloats); + + // Phase 1: Write m/d to smem only (2KB, no o data yet) +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((wid * num_frags_x + fx) * 2 + j) * kRowStride + tid] = + make_float2(m[fx][j], d[fx][j]); + } + } + __syncthreads(); + + // Phase 2: Compute global m/d and scale own o_frag in registers + float scale_j[2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new; + float d_new = 1.f; + if constexpr (std::is_same::value) { + m_new = -5e4f; + } else { + m_new = -3.0e+30f; + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + float2 md = + smem_md[((i * num_frags_x + fx) * 2 + j) * kRowStride + tid]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = fmaf(d_prev, expf(m_prev - m_new), md.y * expf(md.x - m_new)); + } + float own_scale = expf(m[fx][j] - m_new); + m[fx][j] = m_new; + d[fx][j] = d_new; + float d_rcp = normalize ? 
(1.f / d_new) : 1.f; + scale_j[j] = own_scale * d_rcp; + } + // Apply scale to o_frag using WGMMA fragment layout: + // regs 0,1→j=0, 2,3→j=1, 4,5→j=0, 6,7→j=1 + // i.e., float2 index k → j = k % 2 +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t k = 0; k < 4; ++k) { + float s = scale_j[k % 2]; + o_frag[fx][fy][2 * k + 0] *= s; + o_frag[fx][fy][2 * k + 1] *= s; + } + } + } + + // Phase 3: Write pre-scaled o_frag to smem with padded stride +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_smem_start = + (float2*)(md_smem + + ((wid * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + o_smem_start[i * kRowStride] = ((float2*)(&o_frag[fx][fy][0]))[i]; + } + } + } + __syncthreads(); + + // Phase 4: Accumulate all warps' scaled o_frag +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + float2* o_new_fp2 = reinterpret_cast(&o_frag[fx][fy][0]); +#pragma unroll + for (uint32_t o_id = 0; o_id < 4; ++o_id) { + o_new_fp2[o_id] = make_float2(0.f, 0.f); + } +#pragma unroll + for (uint32_t i = 0; i < 4; ++i) { + AlignedVector oi_fp2; + float2* o_smem_start = + (float2*)(md_smem + + ((i * num_frags_x + fx) * num_frags_y + fy) * + kORowStride + + tid * 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + oi_fp2[reg_id] = o_smem_start[reg_id * kRowStride]; + } +#pragma unroll + for (uint32_t reg_fp2_id = 0; reg_fp2_id < 4; ++reg_fp2_id) { + o_new_fp2[reg_fp2_id].x += oi_fp2[reg_fp2_id].x; + o_new_fp2[reg_fp2_id].y += oi_fp2[reg_fp2_id].y; + } + } + } + } +} + +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + float d_rcp[num_frags_x][2]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma 
unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[fx][j] = 1.f / d[fx][j]; + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[fx][fy][reg_id] = + o_frag[fx][fy][reg_id] * d_rcp[fx][(reg_id % 4) / 2]; + } + } + } +} + +// Writes o_frag to global memory. The threads with ty == 0 convert their +// fragments to T and stage them in o_smem; after a block barrier every ty +// copies its own 4 rows of each 16-row tile to o_ptr_base with 128-bit stores. +// A row maps to token offset_now / group_size and head offset_now % group_size; +// rows whose token index is >= qo_upper_bound are not stored. +template +__device__ __forceinline__ void write_o_reg_gmem_multi_warps( + float (*o_frag)[num_frags_y][8], + smem_t* o_smem, + OutT* o_ptr_base, + uint32_t o_idx_base, + const uint32_t q_head_idx_base, + const uint32_t qo_upper_bound, + const uint32_t qo_n_stride, + const uint32_t qo_h_stride) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr int VEC_SIZE = 16 / sizeof(T); + // [num_warps * num_frags_x * 16, num_frags_y * 16] + if (ty == 0) { + // [num_frags_x * 16, num_frags_y * 16] +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t o_frag_f16[4]; + vec_cast((T*)o_frag_f16, o_frag[fx][fy]); + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(fx * 16 + tx / 4, + fy * 2); + ((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[tx % 4] = + o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * num_vecs_per_head))[tx % 4] = o_frag_f16[3]; + } + } + } + // Make the staged tile visible to every ty before it is copied out. + __syncthreads(); + + uint32_t o_smem_offset_w = + smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); + + const uint32_t tx_offset = tx / 8; +#pragma unroll 1 + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + const uint32_t base_offset = o_idx_base + fx * 16 + tx_offset; + // Each ty handles its own 4-row slice of the 16-row tile. This is a plain + // declaration, so it carries no unroll pragma. + const int j = ty; + 
const uint32_t offset_now = base_offset + j * 4; + const uint32_t n_offset = offset_now / group_size; + const uint32_t h_offset = offset_now % group_size; + + OutT* o_ptr = o_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; +#pragma unroll + for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + if (n_offset < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * num_elems_per_128b(); + o_smem_offset_w = + o_smem->advance_offset_by_column<8>(o_smem_offset_w, fyo); + } + o_smem_offset_w = + o_smem->advance_offset_by_row<16, num_vecs_per_head>(o_smem_offset_w) - + 2 * num_frags_y; + } +} + +// Running online-softmax state for one output vector: o is the un-normalized +// weighted sum, m the running maximum and d the running denominator. merge() +// folds in another partial state; normalize() divides o by d (optionally adding +// the sink term expf(current_sink - m) to the denominator). +template +struct prefill_softmax_state_t { + AlignedVector o; + float m; + float d; + + __device__ __forceinline__ void init() { + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&o) + i) = make_half2(0, 0); + } + } else if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&o) + i) = make_bfloat162(0, 0); + } + } + d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.38953e38f; + } + } + + __device__ __forceinline__ void merge( + const AlignedVector& other_o, float other_m, float other_d) { + float m_prev = m, d_prev = d; + m = m_prev > other_m ? 
m_prev : other_m; + const float scale1 = __expf(m_prev - m), scale2 = __expf(other_m - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d_prev * scale1 + other_d * scale2; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * scale1_T + other_o[i] * scale2_T; + } + } + + __device__ __forceinline__ void normalize() { + const T d_t = static_cast(d); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } + + __device__ __forceinline__ void normalize(float current_sink) { + const T d_t = static_cast(d + __expf(current_sink - m)); +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + o[i] /= d_t; + } + } +}; + +// C16 (fp16/bf16 KV cache) helper functions + +// Issues asynchronous 128-bit loads of one KV tile from *gptr into smem. Rows +// with kv_idx >= kv_len are issued with a false predicate. *gptr is moved back +// by the tile extent at the end, so it points at the tile start again on +// return; *smem_offset is rewound the same way (NOTE(review): the row step of +// advance_offset_by_row is a template argument not visible in this view). +template +__device__ __forceinline__ void produce_kv_blockwise(smem_t smem, + uint32_t* smem_offset, + T** gptr, + const uint32_t kv_b_stride, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + // Each thread starts at row ty * 4 + tx / 8 of the tile. + uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; +#pragma unroll + for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { +#pragma unroll + for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * num_elems_per_128b(); + } + kv_idx += num_warps * 4; + *smem_offset = smem.advance_offset_by_row( + *smem_offset) - + 2 * num_frags_y; + *gptr += + num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + } + *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; + *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; +} + +// Accumulates the Q x K score fragments into s_frag with m16n16k16 MMAs: for +// every 16-wide slice of the head dimension (fy) the Q fragments are loaded +// from q_smem and multiplied with each K fragment loaded from k_smem. Both smem +// read offsets are rewound before returning. +template +__device__ __forceinline__ void compute_qk(smem_t* q_smem, + uint32_t* q_smem_offset_r, + 
smem_t* k_smem, + uint32_t* k_smem_offset_r, + float (*s_frag)[num_frags_z][8]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + uint32_t a_frag[num_frags_x][4], b_frag[4]; +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + *q_smem_offset_r = q_smem->advance_offset_by_row<16, num_vecs_per_head>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * num_vecs_per_head; + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + *k_smem_offset_r = k_smem->advance_offset_by_row<16, num_vecs_per_head>( + *k_smem_offset_r); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + /* NOTE(review): both branches below read identically in this view; the template argument that presumably distinguishes the first (fy == 0) MMA from the accumulating ones is not visible here. */ + if (fy == 0) { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } else { + mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[fx][fz], a_frag[fx], b_frag); + } + } + } + *k_smem_offset_r = + k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * num_vecs_per_head; + } + *q_smem_offset_r -= num_frags_y * 2; + *k_smem_offset_r -= num_frags_y * 2; +} + +/* Casts the score fragments s_frag to T, adds their row sums into d, then accumulates o_frag with MMAs of the cast scores against V fragments loaded from v_smem. The v_smem read offset is rewound before returning. */ +template +__device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, + uint32_t* v_smem_offset_r, + float (*s_frag)[num_frags_z][8], + float (*o_frag)[num_frags_y][8], + float (*d)[2]) { + constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); + + T s_frag_f16[num_frags_x][num_frags_z][8]; +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + vec_cast(s_frag_f16[fx][fz], s_frag[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for 
(uint32_t fz = 0; fz < num_frags_z; ++fz) { + rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); + } + } + +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + uint32_t b_frag[4]; + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); + } + *v_smem_offset_r = + v_smem->advance_offset_by_column<2>(*v_smem_offset_r, fy); + } + *v_smem_offset_r = + v_smem->advance_offset_by_row<16, num_vecs_per_head>(*v_smem_offset_r) - + 2 * num_frags_y; + } + *v_smem_offset_r -= 16 * num_frags_z * num_vecs_per_head; +} + +template +__global__ void merge_chunks_kernel( + const T* __restrict__ multi_out, // [slot, num_chunks, num_heads, head_dim], + // slot = bid * max_tokens_per_batch + + // local_seq_id (see the offsets below) + const float* __restrict__ multi_m, // [slot, num_chunks, num_heads] + const float* __restrict__ multi_d, // [slot, num_chunks, num_heads] + const int* __restrict__ seq_lens_q, + const int* __restrict__ seq_lens_kv, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ batch_id_per_token, + const int* __restrict__ cu_seqlens_q, + const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM] + const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM] + const T* __restrict__ sinks, // [q_num_heads] + const int* __restrict__ chunk_size_ptr, + T* __restrict__ out, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int max_seq_len, + const int num_chunks, + const int num_heads, + const int head_dim, + const int token_num, + const int max_tokens_per_batch = 5) { + // Merges the per-chunk partial attention results (multi_out / multi_m / + // multi_d) of decode tokens into out. blockIdx.y selects the head, + // threadIdx.x the vec_size-wide slice of head_dim, and threadIdx.y either a + // token (Phase 1) or a subset of chunks (Phase 2). + const int vid = threadIdx.x, ty = threadIdx.y; + const int hid = blockIdx.y; + // After intra-warp reduction, only bdy/2 results need smem storage + __shared__ T smem[(bdy / 2) * HEAD_DIM]; + __shared__ float md_smem[(bdy / 2) * 2]; +#if (defined(__CUDA_ARCH__) 
&& (__CUDA_ARCH__ >= 900)) + cudaGridDependencySynchronize(); +#endif + // Phase 1: Fast path — all ty participate independently (no smem, no + // syncthreads) Each ty handles a different qid with stride gridDim.x * bdy + // NOTE(review): a single-chunk result is copied verbatim (no normalize / sink + // / shift-smooth step) — presumably the producer already finalized it; + // confirm. + using LoadT = AlignedVector; + for (int qid = blockIdx.x + ty * gridDim.x; qid < token_num; + qid += gridDim.x * bdy) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; + if (seq_lens_encoder[bid] > 0) continue; // skip prefill batches + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq != 1) continue; // handled in Phase 2 + + LoadT load_vec; + uint32_t offset = + ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * num_heads + + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + Store( + load_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + + // Phase 2: Slow path — merge multi-chunk results + // Optimization: use warp-shuffle reduction within each warp, then cross-warp + // via smem. This eliminates the large smem[bdy * HEAD_DIM] buffer and reduces + // syncthreads from 2 per qid to 1 per qid. 
+ // Block layout: (blockx=16, bdy=8) => 4 warps, each warp has 2 ty values + // Warp 0: ty=0,1 Warp 1: ty=2,3 Warp 2: ty=4,5 Warp 3: ty=6,7 + // Lane layout within warp: lanes 0-15 = (ty_low, vid), lanes 16-31 = + // (ty_high, vid) + // NOTE(review): the lane ^ 16 pairing used in Step 2 relies on + // blockDim.x == 16 — confirm at the launch site. + const int lane_id = (ty * blockDim.x + vid) % 32; + + // Every skip condition in this loop depends only on qid, so all threads of + // the block take the same path and reach the barriers below together. + for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) { + const uint32_t bid = batch_id_per_token[qid]; + if (bid == (uint32_t)-1) continue; // uniform skip — no syncthreads needed + if (seq_lens_encoder[bid] > 0) continue; + const uint32_t local_seq_id = qid - cu_seqlens_q[bid]; + const int seq_len_q = seq_lens_q[bid]; + if (seq_len_q == 0) continue; + int seq_len_kv = seq_lens_kv[bid]; + if (seq_len_kv == 0) continue; + seq_len_kv += seq_len_q; + const int num_chunks_this_seq = div_up(seq_len_kv, *chunk_size_ptr); + if (num_chunks_this_seq == 1) continue; // handled in Phase 1 + + LoadT load_vec; + LoadT res_vec; + if constexpr (std::is_same::value) { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((half2*)(&res_vec) + i) = make_half2(0, 0); + } + } else { +#pragma unroll + for (int i = 0; i < vec_size / 2; ++i) { + *((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0); + } + } + float m; + float d = 1.f; + if constexpr (std::is_same::value) { + m = -5e4f; + } else if constexpr (std::is_same::value) { + m = -3.0e+30f; + } + + // Step 1: Each ty iterates over its chunk subset and does local online + // softmax merge +#pragma unroll 2 + for (int i = ty; i < num_chunks_this_seq; i += bdy) { + uint32_t offset; + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks + i) * + num_heads + + hid; + float m_prev = m; + float d_prev = d; + const float m_now = multi_m[offset]; + const float d_now = multi_d[offset]; + m = max(m_prev, m_now); + + offset = ((bid * max_tokens_per_batch + local_seq_id) * num_chunks * + num_heads + + i * num_heads + hid) * + head_dim + + vid * vec_size; + Load(&multi_out[offset], &load_vec); + const float scale1 = 
__expf(m_prev - m), scale2 = __expf(m_now - m); + const T scale1_T = static_cast(scale1), + scale2_T = static_cast(scale2); + d = d * scale1 + d_now * scale2; +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + load_vec[j] * scale2_T; + } + } + + // Step 2: Intra-warp reduction via warp shuffle + // Each warp has 2 ty values: ty_low at lanes 0-15, ty_high at lanes 16-31 + // Merge ty_high into ty_low using shuffle + const int partner_lane = lane_id ^ 16; // flip bit 4 to swap low/high ty + const float m_partner = __shfl_sync(0xffffffff, m, partner_lane); + const float d_partner = __shfl_sync(0xffffffff, d, partner_lane); + // Pack adjacent 16-bit pairs into 32-bit for efficient shuffle. + // AlignedVector alignment >= 4 bytes, so uint32 reinterpret is safe + // — no OOB read, no type confusion. This halves shuffle count vs + // per-element memcpy for bf16/fp16. + constexpr int PACKED_SIZE = vec_size * sizeof(T) / sizeof(unsigned); + const unsigned* packed_res = reinterpret_cast(&res_vec); + unsigned packed_partner[PACKED_SIZE]; +#pragma unroll + for (int j = 0; j < PACKED_SIZE; j++) { + packed_partner[j] = __shfl_sync(0xffffffff, packed_res[j], partner_lane); + } + LoadT partner_vec; + memcpy(&partner_vec, packed_partner, sizeof(partner_vec)); + + // Merge partner into self (only the "low ty" keeps the result) + float m_new = max(m, m_partner); + const float scale1 = __expf(m - m_new); + const float scale2 = __expf(m_partner - m_new); + float d_new = d * scale1 + d_partner * scale2; + if ((ty & 1) == 0) { // low ty keeps merged result + m = m_new; + d = d_new; + const T scale1_T = static_cast(scale1); + const T scale2_T = static_cast(scale2); +#pragma unroll + for (int j = 0; j < vec_size; j++) { + res_vec[j] = res_vec[j] * scale1_T + partner_vec[j] * scale2_T; + } + } + + // Cross-warp: only even ty (0,2,4,6) write to smem + // (row ty / 2 of smem; md_smem[ty] holds m and md_smem[ty + 1] holds d) + if ((ty & 1) == 0) { + Store(res_vec, &smem[(ty / 2) * head_dim + vid * vec_size]); + md_smem[ty] = 
m; + md_smem[ty + 1] = d; + } + __syncthreads(); + + if (ty == 0) { + prefill_softmax_state_t st; + st.init(); +#pragma unroll + for (int i = 0; i < bdy / 2; i++) { + Load(&smem[i * head_dim + vid * vec_size], &load_vec); + const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; + st.merge(load_vec, m_tmp, d_tmp); + } + + if (sinks) { + float current_sink = static_cast(sinks[hid]); + st.normalize(current_sink); + } else { + st.normalize(); + } + + const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + AlignedVector shift_bias_vec; + AlignedVector smooth_weight_vec; + AlignedVector out_vec; + if (shift_bias) { + Load(shift_bias + shift_smooth_offset, &shift_bias_vec); + Load(smooth_weight + shift_smooth_offset, + &smooth_weight_vec); + } + +#pragma unroll + for (int i = 0; i < vec_size; ++i) { + StoreFunc()(st.o, + shift_bias_vec, + smooth_weight_vec, + out_vec, + quant_max_bound, + quant_min_bound, + in_scale, + i); + } + Store( + out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]); + } + /* Second barrier of the iteration: keeps smem / md_smem intact until ty == 0 has finished reading them, before the next qid overwrites them. */ + __syncthreads(); + } +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + cudaTriggerProgrammaticLaunchCompletion(); +#endif +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu new file mode 100644 index 00000000000..7033cbd10bf --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/config_for_attention.cu @@ -0,0 +1,409 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "cute/tensor.hpp" +#include "helper.h" +#include "paddle/extension.h" +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU +#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h" +#include "paddle/phi/core/memory/memcpy.h" +#endif +#include "utils.cuh" + +// Reduces six per-batch maxima into max_lens[0..5]. The batch loop strides by +// blockDim.x only, so the kernel is meant to run as a single block. +// [0] max seq_lens_this_time +// [1] max seq_lens_encoder +// [2] max seq_lens_decoder +// [3] max (decoder + this_time) over requests with this_time > 0 +// [4] max decoder length over requests with this_time > 0 and encoder == 0 +// [5] max (decoder + this_time) over requests with this_time > 0 and +// decoder != 0 +template +__global__ void GetMaxLenKernel(const int* seq_lens_decoder, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + int* max_lens, + const int batch_size) { + const int tid = threadIdx.x; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + int max_len_this_time_this_thread = 0; + int max_len_encoder_this_thread = 0; + int max_len_decoder_this_thread = 0; + int max_len_this_thread = 0; + int max_just_dec_len_this_thread = 0; + int max_len_kv_this_thread = 0; + for (int i = tid; i < batch_size; i += blockDim.x) { + const int seq_len_this_time = seq_lens_this_time[i]; + const int seq_len_decoder = seq_lens_decoder[i]; + max_len_this_time_this_thread = + max(seq_len_this_time, max_len_this_time_this_thread); + max_len_encoder_this_thread = + max(seq_lens_encoder[i], max_len_encoder_this_thread); + max_len_decoder_this_thread = + max(seq_len_decoder, max_len_decoder_this_thread); + if (seq_len_this_time <= 0) continue; + const int max_just_dec_len_now = + seq_lens_encoder[i] > 0 ? 
0 : seq_len_decoder; + max_len_this_thread = + max(seq_len_decoder + seq_len_this_time, max_len_this_thread); + max_just_dec_len_this_thread = + max(max_just_dec_len_this_thread, max_just_dec_len_now); + + if (seq_len_decoder == 0) continue; + max_len_kv_this_thread = + max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread); + } + // All six reductions share one TempStorage. CUB requires a block barrier + // before the storage of a finished collective is reused, otherwise a fast + // warp can overwrite partial results that thread 0 is still reading. + // Only thread 0 receives the valid aggregate of each Reduce call. + int total_max_len_this_time = + BlockReduce(temp_storage) + .Reduce(max_len_this_time_this_thread, MaxOp()); + __syncthreads(); + int total_max_len_encoder = + BlockReduce(temp_storage) + .Reduce(max_len_encoder_this_thread, MaxOp()); + __syncthreads(); + int total_max_len_decoder = + BlockReduce(temp_storage) + .Reduce(max_len_decoder_this_thread, MaxOp()); + __syncthreads(); + int total = + BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp()); + __syncthreads(); + int total_just_dec = BlockReduce(temp_storage) + .Reduce(max_just_dec_len_this_thread, MaxOp()); + __syncthreads(); + int total_max_len_kv = + BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp()); + if (tid == 0) { + max_lens[0] = total_max_len_this_time; + max_lens[1] = total_max_len_encoder; + max_lens[2] = total_max_len_decoder; + max_lens[3] = total; + max_lens[4] = total_just_dec; + max_lens[5] = total_max_len_kv; + } +} + +// Single-block scheduling kernel for decode attention. Every thread evaluates +// one candidate KV chunk size, thread (0, 0) picks the chunk size (or INT_MAX +// when the KV is not split) and writes it with the block count to +// chunk_size[0] / num_blocks[0], and finally the threads with wid == 0 emit one +// (batch, kv_head, kv_chunk, q_tile) entry per attention block into +// block_indices. +template +__global__ void config_decode_attn(const int* __restrict__ seq_lens_this_time, + const int* __restrict__ seq_lens_encoder, + const int* __restrict__ seq_lens_decoder, + int4* __restrict__ block_indices, + int* __restrict__ num_blocks, + int* __restrict__ chunk_size, + const int bsz, + const int group_size, + const int kv_num_heads, + const int q_tile_size, + const int max_tokens_per_batch, + const int config_gridx) { + const int tid = threadIdx.x, wid = threadIdx.y; + const uint32_t warp_size = blockDim.x; + __shared__ int num_block_all_shared[block_size]; + __shared__ int chunk_size_res[1]; + + const int lane_id = tid + wid * warp_size; + + // Merged Step 1+2: single bsz loop computing both Scheme E metrics and + // split-KV block counts per lane. 
Avoids redundant seq_lens reads and + // shared intermediate values (token_num, kv_len, q_tile_num). + const int target_blocks = config_gridx / 3; // = sm_count * 2 for the config_gridx = sm_count * 6 passed by ConfigForAttention + // Each thread evaluates one candidate chunk size: + // min_chunk_size + lane_id * chunk_step, clamped to max_chunk_size + // (e.g. 512, 640, 768, ... for a minimum of 512 and a step of 128). + // NOTE(review): num_block_all_shared has block_size entries, so the launch + // must not use more than block_size threads — confirm at the launch site. + + const int cur_chunk_size = + min(min_chunk_size + lane_id * chunk_step, max_chunk_size); + int num_block_no_chunk = 0; + int max_kv_len_no_chunk = 0; + int num_block_all = 0; + for (int bid = 0; bid < bsz; bid++) { + if (seq_lens_this_time[bid] <= 0 || seq_lens_encoder[bid] > 0) { + continue; + } + const int token_num_cur_batch = seq_lens_this_time[bid]; + const int kv_len_cur_batch = seq_lens_decoder[bid] + token_num_cur_batch; + const int q_tile_num = + div_up(token_num_cur_batch * group_size, q_tile_size); + num_block_no_chunk += q_tile_num * kv_num_heads; + max_kv_len_no_chunk = max(max_kv_len_no_chunk, kv_len_cur_batch); + const int kv_chunk_num = div_up(kv_len_cur_batch, cur_chunk_size); + num_block_all += q_tile_num * kv_chunk_num * kv_num_heads; + } + num_block_all_shared[lane_id] = num_block_all; + __syncthreads(); + + // Step 3: find best chunk_size, then decide Scheme E vs split-KV + if (tid == 0 && wid == 0) { + // Strategy: + // 1. Must fill target_blocks (2*sm_count) to maintain SM concurrency + // 2. Among valid choices, prefer minimum per-SM max KV traffic + // (= waves * chunk_size, since kernel time = slowest SM) + // 3. 
Within 5% of minimum KV traffic, prefer larger chunk_size + int chunk_size_best = min_chunk_size; + int num_block_all_best = num_block_all_shared[0]; + // Step 1: find minimum kv_traffic among chunk_sizes that fill SMs + // (kv_traffic = waves * chunk_size, waves = div_up(blocks, target_blocks)) + int64_t kv_traffic_min = INT64_MAX; + for (int i = 0; i < static_cast(block_size); i++) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic < kv_traffic_min) { + kv_traffic_min = kv_traffic; + } + } + // Step 2: if no chunk_size fills SMs, fall back to smallest + if (kv_traffic_min == INT64_MAX) { + chunk_size_best = min_chunk_size; + num_block_all_best = num_block_all_shared[0]; + } else { + // Step 3: scan from largest chunk_size downward; accept the first + // one that fills SMs AND has kv_traffic within 25% of minimum + // (kv_traffic <= kv_traffic_min + kv_traffic_min / 4). The "5%" quoted + // in the strategy summary above is stale; the code uses 25%. + for (int i = block_size - 1; i >= 0; i--) { + const int nb = num_block_all_shared[i]; + if (nb < target_blocks) continue; + const int cs = min(min_chunk_size + i * chunk_step, max_chunk_size); + const int w = div_up(nb, target_blocks); + const int64_t kv_traffic = static_cast(w) * cs; + if (kv_traffic <= kv_traffic_min + kv_traffic_min / 4) { + chunk_size_best = cs; + num_block_all_best = nb; + break; + } + } + } + + // Decide Scheme E: prefer when blocks fill SMs AND estimated latency + // is no worse than split-KV. + // Scheme E: waves_E * max_kv_len (few heavy blocks) + // Split-KV: waves_split * chunk_size_best (many light blocks) + // When no splitting is needed (num_block_all_best == num_block_no_chunk), + // Scheme E is strictly better (saves merge overhead). + bool use_scheme_e = false; + if (num_block_no_chunk >= target_blocks) { + if (num_block_all_best == num_block_no_chunk) { + use_scheme_e = true; + } else { + // target_blocks (= config_gridx / 3) ≈ CTAs per wave (sm_count × occupancy). 
+ // Using target_blocks as denominator correctly accounts for occupancy + // in wave count estimation. + const int waves_e = div_up(num_block_no_chunk, target_blocks); + const int waves_split = div_up(num_block_all_best, target_blocks); + use_scheme_e = (static_cast(waves_e) * max_kv_len_no_chunk <= + static_cast(waves_split) * chunk_size_best); + } + } + + // Publish the decision: Scheme E is encoded as chunk_size == INT_MAX, which + // makes every sequence a single KV chunk. chunk_size_res hands the value + // to the other threads of the block. + if (use_scheme_e) { + num_blocks[0] = num_block_no_chunk; + chunk_size[0] = INT_MAX; + chunk_size_res[0] = INT_MAX; + } else { + num_blocks[0] = num_block_all_best; + chunk_size[0] = chunk_size_best; + chunk_size_res[0] = chunk_size_best; + } + } + + __syncthreads(); + // Only the threads with wid == 0 build the work list; each handles one batch + // entry per iteration, warp_size entries at a time. + if (wid == 0) { + const int chunk_size_final = chunk_size_res[0]; + + int prev_offset = 0; + for (int base = 0; base < bsz; base += warp_size) { + const int bid = base + tid; + int num_block_cur = 0; + int q_tile_num = 0; + int kv_chunk_num = 0; + + // NOTE(review): unlike the counting loop above, a negative + // seq_lens_this_time is not filtered here — confirm it cannot occur. + if (bid < bsz) { + int token_num_cur_batch = seq_lens_this_time[bid]; + if (seq_lens_encoder && seq_lens_encoder[bid] > 0) { + token_num_cur_batch = 0; + } + q_tile_num = div_up(token_num_cur_batch * group_size, q_tile_size); + const int kv_len_cur_batch = + seq_lens_decoder[bid] + token_num_cur_batch; + kv_chunk_num = div_up(kv_len_cur_batch, chunk_size_final); + num_block_cur = q_tile_num * kv_chunk_num * kv_num_heads; + } + + // inclusive prefix sum + // (bid_offset below is the exclusive prefix; tile_sum is the total of + // this group of warp_size batch entries) + int x = num_block_cur; + for (int offset = 1; offset < warp_size; offset <<= 1) { + int y = __shfl_up_sync(0xffffffff, x, offset); + if (tid >= offset) x += y; + } + int bid_offset = x - num_block_cur; + int tile_sum = __shfl_sync(0xffffffff, x, warp_size - 1); + + // Write block_indices using int4 vectorized stores. + // Each entry is exactly 4 ints (bid, kv_head_id, kv_chunk_id, q_tile_id), + // matching int4 layout. This reduces 4 scalar stores to 1 vector store. 
+ if (bid < bsz && num_block_cur > 0) { + int4* write_ptr = block_indices + prev_offset + bid_offset; + int flat_idx = 0; + const int kv_chunk_num_x_q_tile_num = kv_chunk_num * q_tile_num; +#pragma unroll 2 + for (int kv_head_id = 0; kv_head_id < kv_num_heads; kv_head_id++) { + const int head_base = kv_head_id * kv_chunk_num_x_q_tile_num; +#pragma unroll 2 + for (int kv_chunk_id = 0; kv_chunk_id < kv_chunk_num; kv_chunk_id++) { + const int chunk_base = head_base + kv_chunk_id * q_tile_num; +#pragma unroll + for (int q_tile_id = 0; q_tile_id < q_tile_num; q_tile_id++) { + write_ptr[flat_idx] = + make_int4(bid, kv_head_id, kv_chunk_id, q_tile_id); + flat_idx++; + } + } + } + } + prev_offset += tile_sum; + } + } +} + +void ConfigForAttention( + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + paddle::Tensor& block_indices, // Inplace, shape:[block_num,4], block's + // indices with 4 dimension[batch_idx, + // kv_head_idx, kv_chunk_idx, q_tile_idx] + paddle::Tensor& num_blocks, // Inplace + paddle::Tensor& chunk_size, // Inplace + paddle::Tensor& max_len_tensor_cpu, // Inplace, CPU + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + auto stream = seq_lens_encoder.stream(); + int bsz = seq_lens_this_time.shape()[0]; + + paddle::Tensor max_len_tensor_gpu = + GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, + paddle::DataType::INT32, + seq_lens_this_time.place()); + + GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data(), + seq_lens_this_time.data(), + seq_lens_encoder.data(), + max_len_tensor_gpu.data(), + bsz); + // The device-to-host copy is skipped while a CUDA graph is being captured, so + // max_len_tensor_cpu then keeps whatever it held before this call. +#ifndef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU + if (!phi::backends::gpu::IsCUDAGraphCapturing()) +#endif + max_len_tensor_cpu.copy_( + max_len_tensor_gpu, max_len_tensor_cpu.place(), false); + auto max_len_cpu_ptr = max_len_tensor_cpu.data(); + // Index 4 is the value GetMaxLenKernel writes from total_just_dec. + int max_just_dec_len_this_time = max_len_cpu_ptr[4]; + + 
const uint32_t block_indices_ele_num = block_indices.size(); + + // decoder + if (max_just_dec_len_this_time > 0) { + CUDA_CHECK(cudaMemsetAsync(block_indices.data(), + 0, + block_indices_ele_num * sizeof(int32_t), + stream)); + CUDA_CHECK( + cudaMemsetAsync(num_blocks.data(), 0, sizeof(int32_t), stream)); + CUDA_CHECK( + cudaMemsetAsync(chunk_size.data(), 0, sizeof(int32_t), stream)); + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK(cudaDeviceGetAttribute( + &sm_cout, cudaDevAttrMultiProcessorCount, device)); + const int config_gridx = sm_cout * 6; + + const int q_tile_size = 16; + dim3 blocks(32, 4); + // Cast block_indices to int4* for vectorized stores. + // Each block_indices entry is 4 ints = 16 bytes = sizeof(int4), + // and block_num * 4 ints = block_num int4s, so the reinterpret is valid. + int4* block_indices_i4 = reinterpret_cast(block_indices.data()); + if (cache_quant_type == "cache_int4_zp") { + config_decode_attn<512, 256, 128, 32768> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } else { + config_decode_attn<512, 128, 128, 16384> + <<<1, blocks, 0, stream>>>(seq_lens_this_time.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + block_indices_i4, + num_blocks.data(), + chunk_size.data(), + bsz, + group_size, + kv_num_heads, + q_tile_size, + max_tokens_per_batch, + config_gridx); + } + } +} + +std::vector> ConfigForAttentionInferShape( + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& num_blocks_shape, + const std::vector& chunk_size_shape, + const std::vector& max_len_tensor_cpu_shape, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int 
max_tokens_per_batch) { + return {}; +} + +std::vector ConfigForAttentionInferDtype( + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& num_blocks_dtype, + const paddle::DataType& chunk_size_dtype, + const paddle::DataType& max_len_tensor_cpu_dtype, + const std::string cache_quant_type, + const int group_size, + const int kv_num_heads, + const int max_tokens_per_batch) { + return {}; +} + +PD_BUILD_STATIC_OP(config_for_attention) + .Inputs({ + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "block_indices", + "num_blocks", + "chunk_size", + "max_len_tensor_cpu", + }) + .Outputs({ + + }) + .Attrs({"cache_quant_type: std::string", + "group_size: int", + "kv_num_heads: int", + "max_tokens_per_batch: int"}) + .SetKernelFn(PD_KERNEL(ConfigForAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(ConfigForAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(ConfigForAttentionInferDtype)); diff --git a/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh new file mode 100644 index 00000000000..ff84e1cd3f6 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/cu_tensor_map.cuh @@ -0,0 +1,124 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once +#include +#include +#include +#include +#include +#include + +using barrier = cuda::barrier; +namespace cde = cuda::device::experimental; + +template +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template <> +struct cu_tensor_map_type_traits { + static const CUtensorMapDataType type = + CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8; +}; + +template +CUtensorMap makeTensorMapForKVCache(T const* addr, + uint32_t block_num, + uint32_t kv_num_head, + uint32_t second_size, + uint32_t last_size) { + CUtensorMap tensorMap{}; + + uint32_t elem_bytes = sizeof(T); + + uint32_t const last_size_bytes = elem_bytes * last_size; + // VLLM Layout + CUtensorMapDataType data_dtype = cu_tensor_map_type_traits::type; + constexpr uint32_t rank = 4; + uint64_t global_dims[] = {last_size, second_size, kv_num_head, block_num}; + uint64_t global_strides[] = {last_size_bytes, + second_size * last_size_bytes, + kv_num_head * second_size * last_size_bytes}; + + uint32_t box_dims[] = {last_size, second_size, 1, 1}; + uint32_t elem_strides[] = {1, 1, 1, 1}; + + auto const swizzle = [&] { + switch (last_size_bytes) { + case 128: + return CU_TENSOR_MAP_SWIZZLE_128B; + case 64: + return CU_TENSOR_MAP_SWIZZLE_64B; + default: + throw std::runtime_error("unsupported cache last_size"); + } + }(); + CUresult res = cuTensorMapEncodeTiled( + &tensorMap, + data_dtype, + rank, + reinterpret_cast(const_cast(addr)), + global_dims, + 
global_strides, + box_dims, + elem_strides, + CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B, + CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); + switch (res) { + case CUDA_SUCCESS: + printf("CUDA_SUCCESS!\n"); + break; + case CUDA_ERROR_INVALID_VALUE: + printf("CUDA_ERROR_INVALID_VALUE\n"); + break; + case CUDA_ERROR_OUT_OF_MEMORY: + printf("CUDA_ERROR_OUT_OF_MEMORY\n"); + break; + case CUDA_ERROR_NOT_INITIALIZED: + printf("CUDA_ERROR_NOT_INITIALIZED\n"); + break; + case CUDA_ERROR_DEINITIALIZED: + printf("CUDA_ERROR_DEINITIALIZED\n"); + break; + case CUDA_ERROR_PROFILER_DISABLED: + printf("CUDA_ERROR_PROFILER_DISABLED\n"); + break; + default: + throw std::runtime_error("unsupported res!"); + } + + return tensorMap; +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh new file mode 100644 index 00000000000..e30588a01ab --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c16_impl.cuh @@ -0,0 +1,492 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once +#include "utils.cuh" +#include "attention_func.cuh" + +template +__global__ void decode_unified_attention_c16_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). + const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_frags_x * 16 + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + 
batch_idx * params.max_blocks_per_seq; + + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = HEAD_DIM; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + T* o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? 
attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + const uint32_t num_iterations = + div_up(CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + BLOCK_SIZE); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero(kv_len - q_len, chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (BLOCK_SIZE); + + uint32_t k_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = smem_t::get_permuted_offset( + wid * num_frags_z * 16 + tid % 16, tid / 16); + uint32_t kv_smem_offset_w = smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + + uint32_t kv_idx = chunk_start; + int block_table_idx = kv_idx / BLOCK_SIZE; + int block_id = __ldg(&block_table_now[block_table_idx]); + int block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + const uint32_t const_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + T* cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + T* cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < 
num_iterations; ++iter) { + if (iter + 1 < num_iterations) { + block_id_next = __ldg(&block_table_now[block_table_idx + 1]); + if (block_id_next < 0) { + block_id_next = 0; + } + } + + wait_group<1>(); + __syncthreads(); + + compute_qk( + &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx += BLOCK_SIZE; + block_table_idx++; + + block_id = block_id_next; + cache_k_now = cache_k + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + wait_group<1>(); + __syncthreads(); + + compute_sfm_v( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag); + __syncthreads(); + + cache_v_now = cache_v + block_id * kv_n_stride + const_offset; + produce_kv_blockwise(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx, + chunk_end); + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + 
qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC16Attention( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; + constexpr uint32_t smem_size_0 = + (num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * + sizeof(NV_TYPE); + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + 
NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? smem_size_0 : smem_size_1; + + auto split_kv_kernel = + decode_unified_attention_c16_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = + reinterpret_cast(const_cast(cache_k.data())); + params.cache_v = + reinterpret_cast(const_cast(cache_v.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? 
attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + 0.f, + 0.f, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh new file mode 100644 index 00000000000..00a20165555 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/decode_unified_attention_c8_impl.cuh @@ -0,0 +1,706 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "utils.cuh" +// #include "cu_tensor_map.cuh" +#include "attention_func.cuh" + +template +void print_params(AttentionParams const params) { + printf("max_model_len: %d\n", params.max_model_len); + printf("max_kv_len: %d\n", params.max_kv_len); + printf("max_blocks_per_seq: %d\n", params.max_blocks_per_seq); + printf("softmax_scale: %f\n", params.softmax_scale); + printf("quant_max_bound: %f\n", params.quant_max_bound); + printf("quant_min_bound: %f\n", params.quant_min_bound); + printf("max_tokens_per_batch: %d\n", params.max_tokens_per_batch); + printf("attn_mask_len: %d\n", params.attn_mask_len); + printf("sliding_window: %d\n", params.sliding_window); + printf("q_num_heads: %d\n", params.q_num_heads); + printf("kv_num_heads: %d\n", params.kv_num_heads); + printf("max_num_chunks: %d\n", params.max_num_chunks); + printf("max_tile_q: %d\n", params.max_tile_q); + printf("batch_size: %d\n", params.batch_size); +} + +template +__global__ void decode_unified_attention_c8_kernel( + AttentionParams params) { + const uint32_t tid = threadIdx.x, wid = threadIdx.y; + + // Cache loop-invariant params fields into registers. + // Pass-by-value (no __grid_constant__) allows the compiler to cache + // struct fields, and explicit local variables guarantee no constant + // cache pressure in the grid-stride loop. + // Only cache frequently-used fields; rarely-used ones are accessed + // via params.xxx to reduce register pressure (Scheme I-A.2). 
+ const auto qkv = params.qkv; + const auto cache_k = params.cache_k; + const auto cache_v = params.cache_v; + const auto cache_k_scale = params.cache_k_scale; + const auto cache_v_scale = params.cache_v_scale; + const auto seq_lens_q = params.seq_lens_q; + const auto seq_lens_kv = params.seq_lens_kv; + const auto block_table = params.block_table; + const auto cu_seqlens_q = params.cu_seqlens_q; + const auto block_indices = params.block_indices; + const auto mask_offset = params.mask_offset; + const auto attn_mask = params.attn_mask; + const auto tmp_o = params.tmp_o; + const auto tmp_m = params.tmp_m; + const auto tmp_d = params.tmp_d; + const float softmax_scale = params.softmax_scale; + const int q_num_heads = params.q_num_heads; + const int kv_num_heads = params.kv_num_heads; + + extern __shared__ __align__(128) uint8_t smem[]; + smem_t qo_smem(smem); + smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), + v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); + smem_t k_scale_smem; + smem_t v_scale_smem; + T* k_smem_scale_ptr = nullptr; + T* v_smem_scale_ptr = nullptr; + + int total_block = params.num_blocks_ptr[0]; + int chunk_size = params.chunk_size_ptr[0]; + + for (int lane_idx = blockIdx.x; lane_idx < total_block; + lane_idx += gridDim.x) { + int4 indices = reinterpret_cast(block_indices)[lane_idx]; + int batch_idx = indices.x; + int kv_head_idx = indices.y; + int chunk_idx = indices.z; + int tile_idx = indices.w; + int q_head_idx = kv_head_idx * GROUP_SIZE; + + const uint32_t q_len = seq_lens_q[batch_idx]; + const int* block_table_now = + block_table + batch_idx * params.max_blocks_per_seq; + + T cache_k_scale_reg[IsDynamicC8 + ? num_frags_z * 2 + : (is_scale_channel_wise ? num_frags_y * 4 : 1)]; + T cache_v_scale_reg[IsDynamicC8 + ? num_frags_z * 4 + : (is_scale_channel_wise ? 
num_frags_y * 2 : 1)]; + if constexpr (!IsDynamicC8) { + if constexpr (is_scale_channel_wise) { + int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM; + const T* cache_k_scale_cur_head = cache_k_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_k_scale_reg[i * 4] = cache_k_scale_cur_head[scale_idx]; + cache_k_scale_reg[i * 4 + 1] = cache_k_scale_cur_head[scale_idx + 1]; + cache_k_scale_reg[i * 4 + 2] = cache_k_scale_cur_head[scale_idx + 8]; + cache_k_scale_reg[i * 4 + 3] = cache_k_scale_cur_head[scale_idx + 9]; + } + scale_col_base = threadIdx.x / 4 + kv_head_idx * HEAD_DIM; + const T* cache_v_scale_cur_head = cache_v_scale + scale_col_base; + for (int i = 0; i < num_frags_y; ++i) { + const int scale_idx = i * 16; + cache_v_scale_reg[i * 2] = cache_v_scale_cur_head[scale_idx]; + cache_v_scale_reg[i * 2 + 1] = cache_v_scale_cur_head[scale_idx + 8]; + } + } else { + cache_k_scale_reg[0] = cache_k_scale[kv_head_idx]; + cache_v_scale_reg[0] = cache_v_scale[kv_head_idx]; + } + } + constexpr uint32_t num_rows_per_block = num_frags_x * 16; + const uint32_t q_end = + min(q_len, div_up((tile_idx + 1) * num_rows_per_block, GROUP_SIZE)); + const uint32_t kv_len = seq_lens_kv[batch_idx] + q_len; + const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); + + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_head_k = + HEAD_DIM / num_elems_per_128b(); + constexpr uint32_t num_vecs_per_blocksize = + BLOCK_SIZE / num_elems_per_128b(); + constexpr uint32_t inv_k_stride = 8 / num_vecs_per_head_k; + constexpr uint32_t inv_v_stride = 8 / num_vecs_per_blocksize; + + const uint32_t q_n_stride = q_num_heads * HEAD_DIM; + const uint32_t q_ori_n_stride = (q_num_heads + kv_num_heads * 2) * HEAD_DIM; + const uint32_t kv_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_h_stride = BLOCK_SIZE * HEAD_DIM; + const uint32_t kv_b_stride = 
HEAD_DIM; + const uint32_t kv_d_stride = BLOCK_SIZE; + + float s_frag[num_frags_x][num_frags_z][8]; + float o_frag[num_frags_x][num_frags_y][8]; + float m_frag[num_frags_x][2]; + float d_frag[num_frags_x][2]; + + T* o_base_ptr_T = nullptr; + + const uint32_t chunk_start = chunk_idx * chunk_size; + const uint32_t chunk_end = min(kv_len, chunk_start + chunk_size); + const uint32_t chunk_len = chunk_end - chunk_start; + + init_states(o_frag, m_frag, d_frag); + + const uint32_t q_start_seq_id = cu_seqlens_q[batch_idx]; + const uint32_t q_base_seq_id_this_block = tile_idx * num_frags_x * 16; + const uint32_t q_offset = q_start_seq_id * q_ori_n_stride + + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + T* q_base_ptr = qkv + q_offset; + + o_base_ptr_T = tmp_o + + batch_idx * params.max_tokens_per_batch * + params.max_num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); + const int* mask_offset_this_seq = + mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; + const bool* attn_mask_this_seq = + attn_mask ? attn_mask + + batch_idx * params.attn_mask_len * params.attn_mask_len + : nullptr; + + uint32_t q_smem_offset_r = + smem_t::get_permuted_offset(tid % 16, tid / 16); + load_q_global_smem_multi_warps(q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); + commit_group(); + wait_group<0>(); + __syncthreads(); + + q_smem_inplace_multiply_sm_scale_multi_warps( + &qo_smem, softmax_scale); + + if constexpr (IsDynamicC8) { + k_smem_scale_ptr = reinterpret_cast( + smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2); + v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16; + k_scale_smem.base = reinterpret_cast(k_smem_scale_ptr); + v_scale_smem.base = reinterpret_cast(v_smem_scale_ptr); + } + + const uint32_t num_iterations = + div_up(CAUSAL ? 
(min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_idx + 1) * num_rows_per_block, + GROUP_SIZE), + chunk_start))) + : chunk_len, + NUM_WARP_KV * num_frags_z * 16); + const uint32_t mask_check_iteration = + (CAUSAL ? (min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + tile_idx * num_rows_per_block / GROUP_SIZE, + chunk_start))) + : mask_offset ? 0 + : chunk_len) / + (NUM_WARP_KV * num_frags_z * 16); + + uint32_t k_smem_offset_r = + smem_t::get_permuted_offset( + wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8); + + uint32_t v_smem_offset_r = + smem_t::get_permuted_offset( + (wid / 2) * num_frags_y * 16 + 8 * (tid / 16) + tid % 8, + (wid % 2) * num_frags_z + (tid % 16) / 8); + + uint32_t k_smem_offset_w = + smem_t::get_permuted_offset( + wid * 4 + tid / 8, tid % 8); + uint32_t v_smem_offset_w = + smem_t::get_permuted_offset( + wid * 8 + tid / 4, tid % 4); + + uint32_t kv_idx_base = chunk_start; + const uint32_t const_k_offset = kv_head_idx * kv_h_stride + + (wid * 4 + tid / 8) * kv_b_stride + + tid % 8 * num_elems_per_128b(); + const uint32_t const_v_offset = kv_head_idx * kv_h_stride + + (wid * 8 + tid / 4) * kv_d_stride + + tid % 4 * num_elems_per_128b(); + + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + 
chunk_end); + } + commit_group(); +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_k_dynamic_scale_smem2reg(k_smem_scale_ptr, + cache_k_scale_reg); + } + + compute_qk_c8(&qo_smem, + &q_smem_offset_r, + &k_smem, + &k_smem_offset_r, + cache_k_scale_reg, + s_frag); + + if (iter >= mask_check_iteration || params.sliding_window > 0) { + mask_s(attn_mask_this_seq, + q_base_seq_id_this_block, + kv_idx_base + wid * num_frags_z * 16, + q_len, + kv_len, + chunk_end, + params.attn_mask_len, + s_frag, + mask_offset_this_seq, + params.sliding_window); + } + + update_mdo_states( + s_frag, o_frag, m_frag, d_frag); + __syncthreads(); + + kv_idx_base += NUM_WARP_KV * num_frags_z * 16; + produce_k_blockwise_c8(k_smem, + &k_smem_offset_w, + cache_k, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_b_stride, + kv_idx_base, + chunk_end, + const_k_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(k_scale_smem, + block_table_now, + cache_k_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + wait_group<1>(); + __syncthreads(); + + if constexpr (IsDynamicC8) { + produce_v_dynamic_scale_smem2reg(v_smem_scale_ptr, + cache_v_scale_reg); + } + + compute_sfm_v_c8_iter_sq_bvec( + &v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg); + __syncthreads(); + + produce_v_blockwise_c8(v_smem, + &v_smem_offset_w, + cache_v, + block_table_now, + kv_head_idx, + kv_n_stride, + kv_h_stride, + kv_d_stride, + kv_idx_base, + chunk_end, + const_v_offset); + + if constexpr (IsDynamicC8) { + produce_kv_dynamic_scale_gmem2smem_async(v_scale_smem, + block_table_now, + cache_v_scale, + kv_idx_base, + kv_num_heads, + kv_head_idx, + chunk_end); + } + commit_group(); + } + wait_group<0>(); + __syncthreads(); + const bool do_normalize = (num_chunks_this_seq <= 1); + merge_block_res(o_frag, + 
reinterpret_cast(smem), + m_frag, + d_frag, + wid, + tid, + do_normalize); + + write_o_reg_gmem_multi_warps( + o_frag, + &qo_smem, + o_base_ptr_T, + q_base_seq_id_this_block, + q_head_idx, + q_len, + q_n_stride * params.max_num_chunks, + HEAD_DIM); + + if (num_chunks_this_seq > 1) { + if (wid == 0) { +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_idx_now = + q_base_seq_id_this_block + tid / 4 + j * 8 + fx * 16; + const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; + const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; + if (qo_idx - q_start_seq_id < q_len) { + uint32_t offset; + offset = ((batch_idx * params.max_tokens_per_batch + + qo_idx_now / GROUP_SIZE) * + params.max_num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; + tmp_m[offset] = m_frag[fx][j]; + tmp_d[offset] = d_frag[fx][j]; + } + } + } + } + } + } +} + +template +void DecodeUnifiedC8Attention(const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& cache_k, + const paddle::Tensor& cache_v, + const paddle::Tensor& tmp_workspace, + const paddle::Tensor& tmp_m, + const paddle::Tensor& tmp_d, + const paddle::optional& attn_mask, + const paddle::Tensor& cache_k_scale, + const paddle::Tensor& cache_v_scale, + const paddle::optional& sinks, + const paddle::Tensor& seq_lens_q, + const paddle::Tensor& seq_lens_kv, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_table, + const paddle::Tensor& block_indices, + const paddle::Tensor& num_blocks, + const paddle::Tensor& chunk_size, + const int max_seq_len, + const int max_dec_len, + const float quant_max_bound, + const float quant_min_bound, + const int max_tokens_per_batch, + cudaStream_t& stream, + paddle::Tensor* out, + const int sliding_window) { + using NV_TYPE = typename type_traits::nv_type; + + 
auto num_heads = meta_data.q_num_heads; + auto kv_num_heads = meta_data.kv_num_heads; + auto token_num = meta_data.token_num; + auto bsz = meta_data.batch_size; + auto max_blocks_per_seq = meta_data.max_blocks_per_seq; + + constexpr uint32_t NUM_WARP_Q = 1; + constexpr uint32_t NUM_WARP_KV = NUM_WARPS_PER_BLOCK / NUM_WARP_Q; + constexpr uint32_t num_frags_x = Q_TILE_SIZE / (16 * NUM_WARP_Q); + constexpr uint32_t num_frags_y = HEAD_DIM / 16; + + auto* allocator = paddle::GetAllocator(qkv.place()); + + bool is_scale_channel_wise = false; + if (cache_k_scale.dims()[0] == HEAD_DIM * kv_num_heads) { + is_scale_channel_wise = true; + } + + constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; + constexpr uint32_t smem_size_0 = + num_frags_x * 16 * HEAD_DIM * sizeof(T) + + NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 + + NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2; + constexpr uint32_t smem_size_1 = + NUM_WARPS_PER_BLOCK * num_frags_x * num_frags_y * 33 * 8 * sizeof(float) + + NUM_WARPS_PER_BLOCK * num_frags_x * 2 * 33 * 8; + constexpr uint32_t smem_size = + smem_size_0 > smem_size_1 ? 
smem_size_0 : smem_size_1; + + auto split_kv_kernel = decode_unified_attention_c8_kernel; + if (is_scale_channel_wise) { + split_kv_kernel = decode_unified_attention_c8_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(split_kv_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + } + const int dev_id = 0; + int sm_count; + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); + + const int max_num_chunks = div_up(max_seq_len, 512); + uint32_t attn_mask_len; + if (attn_mask) { + attn_mask_len = attn_mask.get().shape()[1]; + } else { + attn_mask_len = -1; + } + + AttentionParams params; + memset(¶ms, 0, sizeof(AttentionParams)); + + params.qkv = reinterpret_cast(const_cast(qkv.data())); + params.cache_k = const_cast(cache_k.data()); + params.cache_v = const_cast(cache_v.data()); + params.cache_k_scale = + reinterpret_cast(const_cast(cache_k_scale.data())); + params.cache_v_scale = + reinterpret_cast(const_cast(cache_v_scale.data())); + params.seq_lens_q = const_cast(seq_lens_q.data()); + params.seq_lens_kv = const_cast(seq_lens_kv.data()); + params.block_indices = const_cast(block_indices.data()); + params.num_blocks_ptr = const_cast(num_blocks.data()); + params.chunk_size_ptr = const_cast(chunk_size.data()); + params.cu_seqlens_q = const_cast(cu_seqlens_q.data()); + params.block_table = const_cast(block_table.data()); + params.mask_offset = const_cast(meta_data.mask_offset); + params.attn_mask = + attn_mask ? 
const_cast(attn_mask.get().data()) : nullptr; + params.max_model_len = max_dec_len; + params.max_kv_len = max_dec_len; + params.max_blocks_per_seq = max_blocks_per_seq; + params.softmax_scale = 1.f / sqrt(HEAD_DIM); + params.quant_max_bound = quant_max_bound; + params.quant_min_bound = quant_min_bound; + params.tmp_o = + reinterpret_cast(const_cast(tmp_workspace.data())); + params.tmp_m = const_cast(tmp_m.data()); + params.tmp_d = const_cast(tmp_d.data()); + params.max_tokens_per_batch = max_tokens_per_batch; + params.attn_mask_len = + attn_mask ? attn_mask_len = attn_mask.get().shape()[1] : -1; + params.sliding_window = sliding_window; + params.q_num_heads = num_heads; + params.kv_num_heads = kv_num_heads; + params.max_num_chunks = max_num_chunks; + params.batch_size = meta_data.batch_size; + + int device; + CUDA_CHECK(cudaGetDevice(&device)); + int sm_cout; + CUDA_CHECK( + cudaDeviceGetAttribute(&sm_cout, cudaDevAttrMultiProcessorCount, device)); + + dim3 grids(sm_cout * 6); + dim3 blocks(32, NUM_WARPS_PER_BLOCK); + + launchWithPdlWhenEnabled( + split_kv_kernel, grids, blocks, smem_size, stream, params); + + constexpr int vec_size = num_elems_per_128b(); + constexpr int blockx = HEAD_DIM / vec_size; + constexpr int blocky = (128 + blockx - 1) / blockx; + dim3 grids_merge(min(sm_count * 4, token_num), num_heads); + dim3 blocks_merge(blockx, blocky); + launchWithPdlWhenEnabled( + merge_chunks_kernel, + grids_merge, + blocks_merge, + 0, + stream, + params.tmp_o, + params.tmp_m, + params.tmp_d, + seq_lens_q.data(), + seq_lens_kv.data(), + seq_lens_encoder.data(), + batch_id_per_token.data(), + cu_seqlens_q.data(), + (NV_TYPE*)nullptr, + (NV_TYPE*)nullptr, + sinks ? 
reinterpret_cast(const_cast(sinks.get().data())) + : nullptr, + chunk_size.data(), + reinterpret_cast(out->data()), + quant_max_bound, + quant_min_bound, + -1, + max_seq_len, + max_num_chunks, + num_heads, + HEAD_DIM, + token_num, + max_tokens_per_batch); +} diff --git a/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh new file mode 100644 index 00000000000..18788858923 --- /dev/null +++ b/custom_ops/gpu_ops/decode_unified_attention/mem_util.cuh @@ -0,0 +1,389 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once
+
+#include
+#include
+#include
+
+// Shared-memory load helpers used by the decode attention kernels: thin
+// wrappers over ldmatrix / cp.async PTX plus the smem_t view that addresses
+// shared memory in 128-bit units.
+//
+// NOTE(review): the #include targets and the template parameter lists in
+// this file are missing from this copy of the patch; they are left as found.
+
+// What a predicated load does with the destination when the predicate is
+// false: kFillZero issues the copy with a source size of 0, kNoFill skips the
+// copy instruction altogether.
+enum class SharedMemFillMode { kFillZero, kNoFill };
+
+// Whether the 128-bit cp.async variants carry the `.L2::128B` qualifier.
+enum class PrefetchMode { kNoPrefetch, kPrefetch };
+
+// Issues `ldmatrix ... m8n8.x4 ... b16` from `smem_ptr` into R[0..3].
+// `smem_ptr` is converted to a 32-bit shared-space address first.
+template
+__device__ __forceinline__ void ldmatrix_m8n8x4_impl(uint32_t* R, T* smem_ptr) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+      : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
+      : "r"(smem_int_ptr));
+}
+
+// Same as ldmatrix_m8n8x4_impl but with the `.trans` qualifier.
+template
+__device__ __forceinline__ void ldmatrix_m8n8x4_trans_impl(uint32_t* R,
+                                                           T* smem_ptr) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
+      : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3])
+      : "r"(smem_int_ptr));
+}
+
+// Closes the current batch of cp.async copies (`cp.async.commit_group`).
+// Compiles to nothing on the MetaX build, where the loads below are plain
+// memcpy calls.
+__device__ __forceinline__ void commit_group() {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  {}
+#else
+  asm volatile("cp.async.commit_group;\n" ::);
+#endif
+}
+
+// Issues `cp.async.wait_group n` with the compile-time constant `n`.
+// On the MetaX build this calls cooperative_groups::wait on the whole thread
+// block instead and `n` is not used.
+template
+__device__ __forceinline__ void wait_group() {
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  cooperative_groups::wait(cooperative_groups::this_thread_block());
+#else
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+#endif
+}
+
+// Unconditional 16-byte copy from `gmem_ptr` to `smem_ptr` via cp.async.cg,
+// with the L2::128B qualifier when prefetch_mode is kPrefetch.
+// On the MetaX build both prefetch modes run the same memset + memcpy.
+template
+__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+    memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
+    memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16);
+  } else {
+    memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
+    memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16);
+  }
+#else
+  if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+    asm volatile(
+        "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
+            smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(16),
+        "r"(16));
+  } else {
+    asm volatile(
+        "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(16),
+        "r"(16));
+  }
+#endif
+}
+
+// Predicated 16-byte copy.
+//   kFillZero: the copy is always issued, with a source size of 16 bytes when
+//              `predicate` is true and 0 bytes when it is false.
+//   kNoFill:   the copy instruction is guarded by a PTX predicate, so nothing
+//              is issued when `predicate` is false.
+// The MetaX build mirrors this with memset/memcpy.
+template
+__device__ __forceinline__ void pred_load_128b(T* smem_ptr,
+                                               const T* gmem_ptr,
+                                               bool predicate) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 16 : 0;
+    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+      memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
+      memcpy(__cvta_shared_to_generic(smem_int_ptr),
+             (void*)gmem_ptr,
+             src_in_bytes);
+    } else {
+      memset(__cvta_shared_to_generic(smem_int_ptr), 0, 16);
+      memcpy(__cvta_shared_to_generic(smem_int_ptr),
+             (void*)gmem_ptr,
+             src_in_bytes);
+    }
+  } else {
+    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+      if (predicate) {
+        memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16);
+      }
+    } else {
+      if (predicate) {
+        memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 16);
+      }
+    }
+  }
+#else
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 16 : 0;
+    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+      asm volatile(
+          "cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(
+              smem_int_ptr),
+          "l"(gmem_ptr),
+          "n"(16),
+          "r"(src_in_bytes));
+    } else {
+      asm volatile(
+          "cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+          "l"(gmem_ptr),
+          "n"(16),
+          "r"(src_in_bytes));
+    }
+  } else {
+    if constexpr (prefetch_mode == PrefetchMode::kPrefetch) {
+      asm volatile(
+          "{\n"
+          " .reg .pred p;\n"
+          " setp.ne.b32 p, %0, 0;\n"
+          " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n"
+          "}\n" ::"r"((int)predicate),
+          "r"(smem_int_ptr),
+          "l"(gmem_ptr),
+          "n"(16));
+    } else {
+      asm volatile(
+          "{\n"
+          " .reg .pred p;\n"
+          " setp.ne.b32 p, %0, 0;\n"
+          " @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+          "}\n" ::"r"((int)predicate),
+          "r"(smem_int_ptr),
+          "l"(gmem_ptr),
+          "n"(16));
+    }
+  }
+#endif
+}
+
+// Predicated 8-byte copy. Same fill-mode behaviour as pred_load_128b, but it
+// uses cp.async.ca and has no prefetch variant.
+template
+__device__ __forceinline__ void pred_load_64b(T* smem_ptr,
+                                              const T* gmem_ptr,
+                                              bool predicate) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 8 : 0;
+    memset(__cvta_shared_to_generic(smem_int_ptr), 0, 8);
+    memcpy(
+        __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes);
+  } else {
+    if (predicate) {
+      memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 8);
+    }
+  }
+#else
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 8 : 0;
+    asm volatile(
+        "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(8),
+        "r"(src_in_bytes));
+  } else {
+    asm volatile(
+        "{\n"
+        " .reg .pred p;\n"
+        " setp.ne.b32 p, %0, 0;\n"
+        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
+        "}\n" ::"r"((int)predicate),
+        "r"(smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(8));
+  }
+#endif
+}
+
+// Predicated 4-byte copy; structure identical to pred_load_64b.
+template
+__device__ __forceinline__ void pred_load_32b(T* smem_ptr,
+                                              const T* gmem_ptr,
+                                              bool predicate) {
+  uint32_t smem_int_ptr =
+      static_cast(__cvta_generic_to_shared(smem_ptr));
+#ifdef PADDLE_WITH_CUSTOM_DEVICE_METAX_GPU
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 4 : 0;
+    memset(__cvta_shared_to_generic(smem_int_ptr), 0, 4);
+    memcpy(
+        __cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, src_in_bytes);
+  } else {
+    if (predicate) {
+      memcpy(__cvta_shared_to_generic(smem_int_ptr), (void*)gmem_ptr, 4);
+    }
+  }
+#else
+  if constexpr (fill_mode == SharedMemFillMode::kFillZero) {
+    int src_in_bytes = predicate ? 4 : 0;
+    asm volatile(
+        "cp.async.ca.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(4),
+        "r"(src_in_bytes));
+  } else {
+    asm volatile(
+        "{\n"
+        " .reg .pred p;\n"
+        " setp.ne.b32 p, %0, 0;\n"
+        " @p cp.async.ca.shared.global [%1], [%2], %3;\n"
+        "}\n" ::"r"((int)predicate),
+        "r"(smem_int_ptr),
+        "l"(gmem_ptr),
+        "n"(4));
+  }
+#endif
+}
+
+// Width-dispatching front end for the unconditional load; only 128 bits is
+// accepted (enforced at compile time).
+template
+__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) {
+  static_assert(num_bits == 128, "num_bits must be 128");
+  load_128b(smem_ptr, gmem_ptr);
+}
+
+// Width-dispatching front end for the predicated loads (128, 64 or 32 bits).
+template
+__device__ __forceinline__ void pred_load(T* smem_ptr,
+                                          const T* gmem_ptr,
+                                          bool predicate) {
+  static_assert(num_bits == 128 || num_bits == 64 || num_bits == 32,
+                "num_bits must be 128, 64 or 32.");
+  if constexpr (num_bits == 128) {
+    pred_load_128b(smem_ptr, gmem_ptr, predicate);
+  } else if constexpr (num_bits == 64) {
+    pred_load_64b(smem_ptr, gmem_ptr, predicate);
+  } else if constexpr (num_bits == 32) {
+    pred_load_32b(smem_ptr, gmem_ptr, predicate);
+  }
+}
+
+// Carrier types for 32-, 64- and 128-bit wide accesses.
+using b32_t = uint32_t;
+using b64_t = uint2;
+using b128_t = uint4;
+
+// Number of elements of type T that fit in one 128-bit word.
+template
+constexpr __host__ __device__ __forceinline__ uint32_t num_elems_per_128b() {
+  return sizeof(b128_t) / sizeof(T);
+}
+
+// View over a buffer addressed in 128-bit (b128_t) units. Every `offset`
+// taken or returned by the members below is an index into `base`, i.e. a
+// count of 16-byte words, not bytes.
+struct smem_t {
+  // The base pointer.
+  b128_t* base;
+  __device__ __forceinline__ smem_t() : base(nullptr) {}
+  template
+  __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) {}
+
+  // Maps logical (row i, 128-bit column j) to a swizzled offset by XOR-ing
+  // with a value derived from the row index.
+  // In the inv_stride > 1 branch note that `+` binds tighter than `^`, so the
+  // whole sum is XOR-ed with ((i / inv_stride) % 8).
+  template
+  static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i,
+                                                                 uint32_t j) {
+    if constexpr (inv_stride <= 1) {
+      return i * stride + (j ^ (i % 8));
+    } else {
+      return i / inv_stride * 8 + ((j + (i % inv_stride) * stride)) ^
+             ((i / inv_stride) % 8);
+    }
+  }
+
+  // Moves a swizzled offset `step_size` columns to the right. `step_idx` is
+  // the caller's running step counter; it decides when the XOR pattern wraps
+  // and an extra add is needed. Only the (row_stride, step_size) pairs
+  // accepted by the static_asserts are supported.
+  template
+  static __device__ __forceinline__ uint32_t
+  advance_offset_by_column(uint32_t offset, uint32_t step_idx) {
+    if constexpr (row_stride == 2) {
+      static_assert(step_size == 2, "Unsupported step size");
+      return offset + step_size;
+    } else if constexpr (row_stride == 4) {
+      static_assert(step_size == 2 || step_size == 4, "Unsupported step size");
+      if constexpr (step_size == 2) {
+        return (offset ^ 0x2) + (step_idx % 2 == 1) * 4;
+      } else {
+        return offset + step_size;
+      }
+    } else {
+      static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0,
+                    "Unsupported step size");
+      if constexpr (step_size == 2) {
+        return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) +
+               (step_idx % 4 == 3) * 8;
+      } else if constexpr (step_size == 4) {
+        return (offset ^ 0x4) + (step_idx % 2 == 1) * 8;
+      } else {
+        // step_size % 8 == 0
+        return offset + step_size;
+      }
+    }
+  }
+
+  // Moves a swizzled offset `step_size` rows down. The smallest supported
+  // step for each row_stride additionally flips bit 2 of the offset; larger
+  // steps are a plain add of step_size * row_stride.
+  template
+  static __device__ __forceinline__ uint32_t
+  advance_offset_by_row(uint32_t offset) {
+    if constexpr (row_stride == 2) {
+      static_assert(step_size == 16 || step_size % 32 == 0,
+                    "Unsupported step size");
+      if constexpr (step_size == 16) {
+        return (offset ^ 0x4) + step_size * row_stride;
+      } else {
+        // step_size % 32 == 0
+        return offset + step_size * row_stride;
+      }
+    } else if constexpr (row_stride == 4) {
+      static_assert(step_size == 8 || step_size % 16 == 0,
+                    "Unsupported step size");
+      if constexpr (step_size == 8) {
+        return (offset ^ 0x4) + step_size * row_stride;
+      } else {
+        // step_size % 16 == 0
+        return offset + step_size * row_stride;
+      }
+    } else {
+      static_assert(step_size == 4 || step_size % 8 == 0,
+                    "Unsupported step size");
+      if constexpr (step_size == 4) {
+        return (offset ^ 0x4) + step_size * row_stride;
+      } else {
+        // step_size % 8 == 0
+        return offset + step_size * row_stride;
+      }
+    }
+  }
+
+  // ldmatrix x4 from base + offset into R[0..3].
+  __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset,
+                                                  uint32_t* R) {
+    b128_t* smem_ptr = base + offset;
+    ldmatrix_m8n8x4_impl(R, smem_ptr);
+  }
+
+  // Transposing ldmatrix x4 from base + offset into R[0..3].
+  __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t offset,
+                                                        uint32_t* R) {
+    b128_t* smem_ptr = base + offset;
+    ldmatrix_m8n8x4_trans_impl(R, smem_ptr);
+  }
+
+  // Predicated 128-bit copy from `gptr` into base + offset.
+  template
+  __device__ __forceinline__ void load_128b_async(uint32_t offset,
+                                                  const T* gptr,
+                                                  bool predicate) {
+    b128_t* smem_ptr = base + offset;
+    pred_load_128b(
+        smem_ptr, reinterpret_cast(gptr), predicate);
+  }
+
+  // Unconditional 128-bit copy from `gptr` into base + offset.
+  template
+  __device__ __forceinline__ void load_128b_async(uint32_t offset,
+                                                  const T* gptr) {
+    b128_t* smem_ptr = base + offset;
+    load_128b(smem_ptr,
+              reinterpret_cast(gptr));
+  }
+
+  // Plain (synchronous) 128-bit store of base[offset] to `gptr`.
+  template
+  __device__ __forceinline__ void store_128b(uint32_t offset, T* gptr) {
+    *reinterpret_cast(gptr) = *(base + offset);
+  }
+};
diff --git a/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh
new file mode 100644
index 00000000000..8662ee298d2
--- /dev/null
+++ b/custom_ops/gpu_ops/decode_unified_attention/mma_tensor_op.cuh
@@ -0,0 +1,296 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include
+#include
+#include
+
+// Accumulator handling for the mma wrappers below:
+//   kInit          - accumulate onto zeros (C is overwritten).
+//   kInplaceUpdate - accumulate onto the current contents of C.
+enum class MMAMode {
+  kInit = 0U,
+  kInplaceUpdate = 1U,
+};
+
+// int8 x int8 -> int32 tile product built from two m16n8k32 mma instructions
+// that share the same A registers: the first uses B[0..1] and writes C[0..3],
+// the second uses B[2..3] and writes C[4..7].
+template
+__device__ __forceinline__ void mma_sync_m16n16k32_row_col_i8i8i32(
+    int* C,         // 8
+    uint32_t* A,    // 4
+    uint32_t* B) {  // 4
+  if constexpr (mma_mode == MMAMode::kInit) {
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        "{%0, %1, %2, %3},"
+        "{%4, %5, %6, %7},"
+        "{%8, %9},"
+        "{%10, %11, %12, %13};\n"
+        : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3])
+        : "r"(A[0]),
+          "r"(A[1]),
+          "r"(A[2]),
+          "r"(A[3]),
+          "r"(B[0]),
+          "r"(B[1]),
+          "r"(0),
+          "r"(0),
+          "r"(0),
+          "r"(0));
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        "{%0, %1, %2, %3},"
+        "{%4, %5, %6, %7},"
+        "{%8, %9},"
+        "{%10, %11, %12, %13};\n"
+        : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7])
+        : "r"(A[0]),
+          "r"(A[1]),
+          "r"(A[2]),
+          "r"(A[3]),
+          "r"(B[2]),
+          "r"(B[3]),
+          "r"(0),
+          "r"(0),
+          "r"(0),
+          "r"(0));
+  } else {
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        "{%0, %1, %2, %3},"
+        "{%4, %5, %6, %7},"
+        "{%8, %9},"
+        "{%10, %11, %12, %13};\n"
+        : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3])
+        : "r"(A[0]),
+          "r"(A[1]),
+          "r"(A[2]),
+          "r"(A[3]),
+          "r"(B[0]),
+          "r"(B[1]),
+          "r"(C[0]),
+          "r"(C[1]),
+          "r"(C[2]),
+          "r"(C[3]));
+    asm volatile(
+        "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 "
+        "{%0, %1, %2, %3},"
+        "{%4, %5, %6, %7},"
+        "{%8, %9},"
+        "{%10, %11, %12, %13};\n"
+        : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7])
+        : "r"(A[0]),
+          "r"(A[1]),
+          "r"(A[2]),
+          "r"(A[3]),
+          "r"(B[2]),
+          "r"(B[3]),
+          "r"(C[4]),
+          "r"(C[5]),
+          "r"(C[6]),
+          "r"(C[7]));
+  }
+}
+
+// 16-bit float x 16-bit float -> fp32 tile product built from two m16n8k16
+// mma instructions (B[0..1] -> C[0..3], B[2..3] -> C[4..7]). The first
+// std::is_same branch emits the f16 instruction, the other the bf16 one.
+// NOTE(review): the operands of std::is_same are missing from this copy.
+template
+__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(
+    float* C, uint32_t* A, uint32_t* B) {
+  if constexpr (mma_mode == MMAMode::kInit) {
+    if constexpr (std::is_same::value) {  // fp16
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[0]),
+            "r"(B[1]),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f));
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[2]),
+            "r"(B[3]),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f));
+    } else {  // bf16
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[0]),
+            "r"(B[1]),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f));
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[2]),
+            "r"(B[3]),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f),
+            "f"(0.f));
+    }
+  } else {
+    if constexpr (std::is_same::value) {
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[0]),
+            "r"(B[1]),
+            "f"(C[0]),
+            "f"(C[1]),
+            "f"(C[2]),
+            "f"(C[3]));
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[2]),
+            "r"(B[3]),
+            "f"(C[4]),
+            "f"(C[5]),
+            "f"(C[6]),
+            "f"(C[7]));
+    } else {
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[0]),
+            "r"(B[1]),
+            "f"(C[0]),
+            "f"(C[1]),
+            "f"(C[2]),
+            "f"(C[3]));
+      asm volatile(
+          "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+          "{%0, %1, %2, %3},"
+          "{%4, %5, %6, %7},"
+          "{%8, %9},"
+          "{%10, %11, %12, %13};\n"
+          : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7])
+          : "r"(A[0]),
+            "r"(A[1]),
+            "r"(A[2]),
+            "r"(A[3]),
+            "r"(B[2]),
+            "r"(B[3]),
+            "f"(C[4]),
+            "f"(C[5]),
+            "f"(C[6]),
+            "f"(C[7]));
+    }
+  }
+}
+
+// Adds row sums of the 16-bit fragment `s` into d[0] and d[1] by multiplying
+// it with a B operand whose packed 16-bit lanes are all 1.0:
+//   1006648320 == 0x3C003C00 (two fp16 values of 1.0)
+//   1065369472 == 0x3F803F80 (two bf16 values of 1.0)
+// Only accumulator slots 0 and 2 are kept; slots 1 and 3 go to the scratch
+// register `ph` and are fed 0 on input.
+template
+__device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) {
+  static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type");
+  uint32_t* s_u32 = (uint32_t*)(s);
+  if constexpr (std::is_same::value) {
+    asm volatile(
+        "{\n"
+        ".reg .f32 ph;\n"
+        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+        "{%0, ph, %1, ph},"
+        "{%2, %3, %4, %5},"
+        "{%6, %7},"
+        "{%8, 0., %9, 0.};\n"
+        "}\n"
+        : "=f"(d[0]), "=f"(d[1])
+        : "r"(s_u32[0]),
+          "r"(s_u32[1]),
+          "r"(s_u32[2]),
+          "r"(s_u32[3]),
+          "r"(1006648320),
+          "r"(1006648320),
+          "f"(d[0]),
+          "f"(d[1]));
+  } else {
+    asm volatile(
+        "{\n"
+        ".reg .f32 ph;\n"
+        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
+        "{%0, ph, %1, ph},"
+        "{%2, %3, %4, %5},"
+        "{%6, %7},"
+        "{%8, 0., %9, 0.};\n"
+        "}\n"
+        : "=f"(d[0]), "=f"(d[1])
+        : "r"(s_u32[0]),
+          "r"(s_u32[1]),
+          "r"(s_u32[2]),
+          "r"(s_u32[3]),
+          "r"(1065369472),
+          "r"(1065369472),
+          "f"(d[0]),
+          "f"(d[1]));
+  }
+}
diff --git a/custom_ops/gpu_ops/decode_unified_attention/template_config.json b/custom_ops/gpu_ops/decode_unified_attention/template_config.json
new file mode 100644
index 00000000000..d768c93a1ad
--- /dev/null
+++ b/custom_ops/gpu_ops/decode_unified_attention/template_config.json
@@ -0,0 +1,78 @@
+{
+  "multiquery_attention_c8": {
+    "name": "decode_unified_attention_c8_kernel",
+    "function_name": "decode_unified_attention_c8_kernel",
+    "impl_file": "decode_unified_attention_c8_impl.cuh",
+    "template_params": [
+      "T",
+      "CacheT",
+      "GROUP_SIZE",
+      "CAUSAL",
+      "NUM_WARPS",
+      "NUM_WARP_Q",
+      "NUM_WARP_KV",
+      "HEAD_DIM",
+      "BLOCK_SIZE",
+      "num_frags_x",
+      "num_frags_y",
+      "num_frags_z",
+      "is_scale_channel_wise",
+      "IsFP8",
+      "IsDynamicC8"
+    ],
+    "dispatch_params": {
+      "T": ["half", "__nv_bfloat16"],
+      "CacheT": ["uint8_t"],
+      "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
+      "CAUSAL": [0, 1],
+      "NUM_WARPS": [4],
+      "NUM_WARP_Q": [1],
+      "NUM_WARP_KV": [4],
+      "HEAD_DIM": [128],
+      "BLOCK_SIZE": [64],
+      "num_frags_x": [1, 2],
+      "num_frags_y": [8],
+      "num_frags_z": [1],
+      "is_scale_channel_wise": [0, 1],
+      "IsFP8": [0, 1],
+      "IsDynamicC8": [0, 1]
+    },
+    "max_instances_per_file": 80,
+    "file_prefix": "decode_unified_attention_c8",
+    "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n"
+  },
+  "multiquery_attention_c16": {
+    "name": "decode_unified_attention_c16_kernel",
+    "function_name": "decode_unified_attention_c16_kernel",
+    "impl_file": "decode_unified_attention_c16_impl.cuh",
+    "template_params": [
+      "T",
+      "GROUP_SIZE",
+      "CAUSAL",
+      "NUM_WARPS",
+      "NUM_WARP_Q",
+      "NUM_WARP_KV",
+      "HEAD_DIM",
+      "BLOCK_SIZE",
+      "num_frags_x",
+      "num_frags_z",
+      "num_frags_y"
+    ],
+    "dispatch_params": {
+      "T": ["half", "__nv_bfloat16"],
+      "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16],
+      "CAUSAL": [0, 1],
+      "NUM_WARPS": [4],
+      "NUM_WARP_Q": [1],
+      "NUM_WARP_KV": [4],
+      "HEAD_DIM": [128],
+      "BLOCK_SIZE": [64],
+      "num_frags_x": [1, 2],
+      "num_frags_z": [1],
+      "num_frags_y": [8]
+    },
+    "max_instances_per_file": 80,
+    "file_prefix": "decode_unified_attention_c16",
+    "function_signature": "template __global__ void {function_name}{template_args}(AttentionParams{params_template_args} params);\n\n"
+  }
+}
diff --git a/custom_ops/gpu_ops/decode_unified_attention/utils.cuh b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh
new file mode 100644
index 00000000000..c8c6a06ba86
--- /dev/null
+++ b/custom_ops/gpu_ops/decode_unified_attention/utils.cuh
@@ -0,0 +1,690 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include
+#include
+#include
+#include
+#include "helper.h"
+#include "mem_util.cuh"
+
+// Launch geometry shared by the decode attention kernels:
+// 4 warps of 32 threads = 128 threads per block.
+#define NUM_WARPS_PER_BLOCK 4
+#define NUM_THREADS_PER_BLOCK 128
+#define kWarpSize 32
+
+#define HOSTDEVICE __host__ __device__
+
+/*-------------------------------------traits-----------------------------------------*/
+// type_traits maps a scalar type to its phi dtype (phi_type), its CUDA scalar
+// type (nv_type) and its packed two-element CUDA type (nv2_type). The primary
+// template maps every alias to T itself.
+//
+// NOTE(review): the arguments of the specialisations below are missing from
+// this copy of the patch. Judging only by the aliases they define, the three
+// groups cover the fp16, bf16 and fp8-e4m3 families -- confirm against the
+// real file.
+// NOTE(review): the commented-out paddle_type lines inside the bf16 group
+// still say FLOAT16; that looks like a copy-paste leftover.
+template
+struct type_traits {
+  using paddle_type = T;
+  using phi_type = T;
+  using nv_type = T;
+  using nv2_type = T;
+};
+
+// template <>
+// struct type_traits {
+//   using paddle_type = paddle::DataType::FLOAT16;
+//   using phi_type = phi::dtype::float16;
+//   using nv_type = half;
+//   using nv2_type = half2;
+// };
+
+template <>
+struct type_traits {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::float16;
+  using nv_type = half;
+  using nv2_type = half2;
+};
+
+template <>
+struct type_traits {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::float16;
+  using nv_type = half;
+  using nv2_type = half2;
+};
+
+template <>
+struct type_traits {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::float16;
+  using nv_type = half;
+  using nv2_type = half2;
+};
+
+// template <>
+// struct type_traits {
+//   using paddle_type = paddle::DataType::FLOAT16;
+//   using phi_type = phi::dtype::bfloat16;
+//   using nv_type = __nv_bfloat16;
+//   using nv2_type = __nv_bfloat162;
+// };
+
+template <>
+struct type_traits {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::bfloat16;
+  using nv_type = __nv_bfloat16;
+  using nv2_type = __nv_bfloat162;
+};
+
+template <>
+struct type_traits<__nv_bfloat16> {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::bfloat16;
+  using nv_type = __nv_bfloat16;
+  using nv2_type = __nv_bfloat162;
+};
+
+template <>
+struct type_traits<__nv_bfloat162> {
+  // using paddle_type = paddle::DataType::FLOAT16;
+  using phi_type = phi::dtype::bfloat16;
+  using nv_type = __nv_bfloat16;
+  using nv2_type = __nv_bfloat162;
+};
+
+// template <>
+// struct type_traits {
+//   using paddle_type = paddle::DataType::FLOAT8_E4M3FN;
+//   using phi_type = phi::dtype::float8_e4m3fn;
+//   using nv_type = __nv_fp8_e4m3;
+//   using nv2_type = __nv_fp8x2_e4m3;
+// };
+
+template <>
+struct type_traits {
+  // using paddle_type = paddle::DataType::FLOAT8_E4M3FN;
+  using phi_type = phi::dtype::float8_e4m3fn;
+  using nv_type = __nv_fp8_e4m3;
+  using nv2_type = __nv_fp8x2_e4m3;
+};
+
+template <>
+struct type_traits<__nv_fp8_e4m3> {
+  // using paddle_type = paddle::DataType::FLOAT8_E4M3FN;
+  using phi_type = phi::dtype::float8_e4m3fn;
+  using nv_type = __nv_fp8_e4m3;
+  using nv2_type = __nv_fp8x2_e4m3;
+};
+
+template <>
+struct type_traits<__nv_fp8x2_e4m3> {
+  // using paddle_type = paddle::DataType::FLOAT8_E4M3FN;
+  using phi_type = phi::dtype::float8_e4m3fn;
+  using nv_type = __nv_fp8_e4m3;
+  using nv2_type = __nv_fp8x2_e4m3;
+};
+/*---------------------------------1. 
type + * traits--------------------------------------*/ + +/*---------------------------------2. fast + * convert--------------------------------------*/ +inline __device__ static void convert_fp8(half* result, + const uint32_t& source) { + printf("Do not support fp8 to half although it's very easy.\n"); + asm("trap;"); +} + +inline __device__ static void convert_fp8(__nv_bfloat16* result, + const uint32_t& source) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + uint32_t dest0; + uint32_t dest1; + asm volatile( + "{\n" + ".reg .b16 lo, hi;\n" + "mov.b32 {lo, hi}, %2;\n" + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" + "}\n" + : "=r"(dest0), "=r"(dest1) + : "r"(source)); + + ((nv_bfloat162*)(result))[0] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest0))[0])); + ((nv_bfloat162*)(result))[1] = + __float22bfloat162_rn(__half22float2(((half2*)(&dest1))[0])); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif +} + +inline __device__ static void convert_int8( + half* result, const uint32_t& source) { // 4 int8 each time + uint32_t* fp16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + static constexpr uint32_t mask_for_elt_01 = 0x5150; + static constexpr uint32_t mask_for_elt_23 = 0x5352; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_01)); + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(i8s), "n"(start_byte_for_fp16), "n"(mask_for_elt_23)); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[0]) + : "r"(fp16_result_ptr[0]), "r"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\n" + : "=r"(fp16_result_ptr[1]) + : "r"(fp16_result_ptr[1]), "r"(I8s_TO_F16s_MAGIC_NUM)); +} + +inline __device__ static void 
convert_int8( + __nv_bfloat16* result, const uint32_t& source) { // 4 int8 each time + uint32_t* bf16_result_ptr = reinterpret_cast(result); + uint32_t const i8s = reinterpret_cast(source); + + static constexpr uint32_t fp32_base = 0x4B000000; + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + +#pragma unroll + for (int ii = 0; ii < 4; ++ii) { + fp32_intermediates[ii] -= 8388736.f; // (8388608.f + 128.f); + } + +#pragma unroll + for (int ii = 0; ii < 2; ++ii) { + bf16_result_ptr[ii] = __byte_perm(fp32_intermediates_casted[2 * ii + 0], + fp32_intermediates_casted[2 * ii + 1], + 0x7632); + } +} +/*---------------------------------2. fast + * convert--------------------------------------*/ + +/*---------------------------------3. 
vector + * cast--------------------------------------*/ +template +__forceinline__ HOSTDEVICE void vec_cast(dst_t* dst, const src_t* src) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = src[i]; + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(float* dst, + const half* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(half* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast( + float* dst, const nv_bfloat16* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((nv_bfloat162*)src)[i]); + } +} + +template +__forceinline__ HOSTDEVICE void vec_cast(nv_bfloat16* dst, + const float* src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } +} +/*---------------------------------3. vector + * cast--------------------------------------*/ + +/*-------------------------------------4. + * func-----------------------------------------*/ +__forceinline__ HOSTDEVICE int div_up(int a, int b) { + return a / b + (a % b != 0); +} + +template +__inline__ __device__ T Rsqrt(T x); + +template <> +__inline__ __device__ float Rsqrt(float x) { + return rsqrt(x); +} + +template <> +__inline__ __device__ double Rsqrt(double x) { + return rsqrt(x); +} + +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) { + return (x > y) ? x - y : 0U; +} + +template +inline HOSTDEVICE T roundWithTiesToEven(T x) { + T xLower = floor(x); + T xUpper = ceil(x); + // x is in interval [xl,xu]. Choose closest of two bounds, breaking ties to + // even. 
+ T dLower = x - xLower; + T dUpper = xUpper - x; + return static_cast( + (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper) + ? xLower + : xUpper); +} + +template +HOSTDEVICE __forceinline__ uint8_t QuantToC8(const T scale, + const T value, + const float max_bound, + const float min_bound) { + uint8_t eight_bits; + float quant_value; + if constexpr (is_need_kv_quant) { + quant_value = static_cast(scale * value); + } else { + quant_value = static_cast(value); + } + if constexpr (RoundType == 0) { + quant_value = roundWithTiesToEven(quant_value); + } else { + quant_value = round(quant_value); + } + + if constexpr (IsFP8) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) + quant_value = quant_value > 448.0f ? 448.0f : quant_value; + quant_value = quant_value < -448.0f ? -448.0f : quant_value; + auto tmp = static_cast<__nv_fp8_e4m3>(quant_value); + eight_bits = *(reinterpret_cast(&tmp)); +#else + printf("Do not support fp8 in arch < 890\n"); + asm("trap;"); +#endif + } else { + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + eight_bits = static_cast(quant_value + 128.0f); + } + return eight_bits; +} + +template +inline __device__ static void convert_c8(T* result, const uint32_t& source) { + if constexpr (IsFP8) { + convert_fp8(result, source); + } else { + convert_int8(result, source); + } +} + +template +inline __device__ void WelfordCombine1(T b_m2, T* m2) { + *m2 += b_m2; +} + +template +__inline__ __device__ void WelfordWarpReduce(T thread_m2, T* m2) { + *m2 = thread_m2; + for (int mask = thread_group_width / 2; mask > 0; mask >>= 1) { + T b_m2 = __shfl_xor_sync(0xffffffff, *m2, mask); + WelfordCombine1(b_m2, m2); + } +} + +template +__inline__ __device__ void WelfordWarpAllReduce(T thread_m2, T* m2) { + WelfordWarpReduce(thread_m2, m2); +} + +#define CHECK_CUDA_CALL(func, ...) 
\ + { \ + cudaError_t e = (func); \ + if (e != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(e) << " (" << e \ + << ") " << __FILE__ << ": line " << __LINE__ \ + << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } + +__device__ __forceinline__ float2 fast_float2_mul(const float2& a, + const float2& b) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, 0.0;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, 0.0;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + return res; +} + +__device__ __forceinline__ float2 fast_float2_fma(float2& a, + const float2& b, + const float2& c) { + float2 res; + // 使用向量化PTX指令同时处理x/y分量 + asm volatile( + "{\n" + " fma.rn.f32 %0, %2, %4, %6;\n" // res.x = a.x * b.x + " fma.rn.f32 %1, %3, %5, %7;\n" // res.y = a.y * b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), + "f"(a.y), + "f"(b.x), + "f"(b.y), + "f"(c.x), + "f"(c.y) // 输入操作数 + ); + return res; +} + +// __device__ __forceinline__ float2 fast_bfloat162_fma(__nv_bfloat162& a_bf162, +// const __nv_bfloat162& b_bf162, const __nv_bfloat162& c_bf162) { +// // 使用向量化PTX指令同时处理x/y分量 +// asm volatile ( +// "{\n" +// " fma.rn.b16 %0, %2, %4, %0;\n" // res.x = a.x * b.x +// " fma.rn.b16 %1, %3, %5, %1;\n" // res.y = a.y * b.y +// "}" +// : "=r"(a_bf162.x), "=r"(a_bf162.y) // 输出操作数 +// : "r"(b_bf162.x), "r"(b_bf162.y), +// "r"(c_bf162.x), "r"(c_bf162.y) // 输入操作数 +// ); +// float2 res = __bfloat1622float2_rn(a_bf162); +// return res; +// } + +__device__ __forceinline__ float2 fast_float2_sub_expf(const float2& a, + const float2& b) { + float2 res; + // 使用向量化减法指令(PTX sub.rn.f32) + asm volatile( + "{\n" + " sub.f32 %0, %2, %4;\n" // res.x = a.x - b.x + " sub.f32 %1, %3, %5;\n" // res.y = a.y - b.y + "}" + : "=f"(res.x), "=f"(res.y) // 输出操作数 + : "f"(a.x), "f"(a.y), "f"(b.x), "f"(b.y) // 输入操作数 + ); + res.x = expf(res.x); + res.y = 
expf(res.y); + return res; +} + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = static_cast(ori_out_vec[i]); + printf("Fatal! Unimplemented StoreFunc for cascade append attention\n"); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + 127.0f * + static_cast((ori_out_vec[i] + shift_bias_vec[i]) * + smooth_weight_vec[i]) * + in_scale; + quant_value = rintf(quant_value); + quant_value = quant_value > 127.0f ? 127.0f : quant_value; + quant_value = quant_value < -127.0f ? -127.0f : quant_value; + out_vec[i] = static_cast(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector<__nv_fp8_e4m3, VEC_SIZE>& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + float quant_value = + quant_max_bound * static_cast(ori_out_vec[i]) * in_scale; + quant_value = quant_value > quant_max_bound ? quant_max_bound : quant_value; + quant_value = quant_value < quant_min_bound ? 
quant_min_bound : quant_value; + out_vec[i] = static_cast<__nv_fp8_e4m3>(quant_value); + } +}; + +template +struct StoreFunc { + __device__ __forceinline__ void operator()( + const AlignedVector& ori_out_vec, + const AlignedVector& shift_bias_vec, + const AlignedVector& smooth_weight_vec, + AlignedVector& out_vec, + const float quant_max_bound, + const float quant_min_bound, + const float in_scale, + const int i) { + out_vec[i] = ori_out_vec[i]; + } +}; +/*-------------------------------------4. + * func-----------------------------------------*/ + +/*-----------------------------------5. + * dispatch---------------------------------------*/ +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + PD_THROW("not support the head_dim"); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ + } else { \ + PD_THROW("not support the group_size", group_size); \ + } + +#define 
DISPATCH_BLOCKSHAPE_Q(block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_Q_TILE_SIZE( \ + group_size, max_tokens_per_batch, Q_TILE_SIZE, ...) \ + { \ + constexpr size_t Q_TILE_SIZE = 16; \ + __VA_ARGS__ \ + } + +#define DISPATCH_CAUSAL(causal, CAUSAL, ...) \ + if (causal) { \ + constexpr bool CAUSAL = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool CAUSAL = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCKSHAPE_Q_SYSTEM( \ + block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, ...) \ + if (block_shape_q <= 16) { \ + constexpr size_t BLOCK_SHAPE_Q = 16; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } else if (block_shape_q <= 32) { \ + constexpr size_t BLOCK_SHAPE_Q = 32; \ + constexpr size_t NUM_WARP_Q = 1; \ + __VA_ARGS__ \ + } + +#define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ + if (block_size == 64) { \ + constexpr size_t BLOCK_SIZE = 64; \ + __VA_ARGS__ \ + } + +#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \ + if (is_dynamic_cfp8) { \ + constexpr bool IsDynamicC8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IsDynamicC8 = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_IS_FP8(is_fp8, IS_FP8, ...) 
\ + if (is_fp8) { \ + constexpr bool IS_FP8 = true; \ + __VA_ARGS__ \ + } else { \ + constexpr bool IS_FP8 = false; \ + __VA_ARGS__ \ + } + +struct AppendAttnMetaData { + int batch_size; + int block_size; + int q_num_heads; + int kv_num_heads; + int token_num; + int head_dims; + int head_dims_v; + int max_blocks_per_seq; + const int* mask_offset = nullptr; +}; + +template +struct AttentionParams { + T* __restrict__ qkv; + CacheT* __restrict__ cache_k; + CacheT* __restrict__ cache_v; + T* __restrict__ cache_k_scale; + T* __restrict__ cache_v_scale; + int* __restrict__ seq_lens_q; + int* __restrict__ seq_lens_kv; + int* __restrict__ block_indices; + int* __restrict__ num_blocks_ptr; + int* __restrict__ chunk_size_ptr; + int* __restrict__ cu_seqlens_q; + int* __restrict__ block_table; + int* __restrict__ mask_offset; + bool* __restrict__ attn_mask; + T* __restrict__ tmp_o; + float* __restrict__ tmp_m; + float* __restrict__ tmp_d; + int max_model_len; + int max_kv_len; + int max_blocks_per_seq; + float softmax_scale; + float quant_max_bound; + float quant_min_bound; + int num_blocks_x; + int attn_mask_len; + bool sliding_window; + int q_num_heads; + int kv_num_heads; + int max_num_chunks; + int max_tile_q; + int batch_size; + int token_num; + int head_dims; + int max_tokens_per_batch; +}; diff --git a/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu new file mode 100644 index 00000000000..7878e9926c5 --- /dev/null +++ b/custom_ops/gpu_ops/decoder_write_cache_with_rope.cu @@ -0,0 +1,326 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +template +class type2value; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::BFLOAT16; +}; + +template <> +class type2value { + public: + static constexpr paddle::DataType value = paddle::DataType::FLOAT16; +}; + +std::vector DecoderWriteCacheWithRoPE( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& batch_id_per_token, + const paddle::Tensor& cu_seqlens_q, + const paddle::Tensor& block_tables, + const paddle::Tensor& set_max_lengths, + const paddle::optional& rotary_embs, + const paddle::optional& qkv_bias, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& kv_signal_data, + const paddle::optional& q_norm_weight, + const paddle::optional& k_norm_weight, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const 
float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + auto stream = qkv.stream(); + + AppendAttnMetaData meta_data; + + const auto& qkv_dims = qkv.dims(); + const auto& key_cache_dims = key_cache.dims(); + meta_data.token_nums = qkv_dims[0]; + meta_data.kv_num_heads = key_cache_dims[1]; + meta_data.head_dims = key_cache_dims[3]; + // TODO: trick method support c4, add attr head_dims in the future + if (cache_quant_type_str == "cache_int4_zp") { + meta_data.head_dims *= 2; + } + const int total_num_head = + qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; + meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.max_blocks_per_seq = block_tables.dims()[1]; + meta_data.block_size = key_cache.dims()[2]; + meta_data.batch_size = seq_lens_this_time.dims()[0]; + + const int max_just_dec_len_this_time = set_max_lengths.data()[4]; + + if (max_just_dec_len_this_time > 0) { + if (speculate_decoder) { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + 
const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. "); + } + } else { + switch (qkv.dtype()) { + case paddle::DataType::BFLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + case paddle::DataType::FLOAT16: { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + NULL, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + stream, + const_cast(&qkv), + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + break; + } + default: + PD_THROW( + "NOT supported data type. " + "Only bfloat16 and float16 are supported. 
"); + } + } + } + return {qkv}; +} + +std::vector> DecoderWriteCacheWithRoPEInferShape( + const std::vector& qkv_shape, + const std::vector& key_cache_shape, + const std::vector& value_cache_shape, + const std::vector& seq_lens_encoder_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_this_time_shape, + const std::vector& batch_id_per_token_shape, + const std::vector& cu_seqlens_q_shape, + const std::vector& block_tables_shape, + const std::vector& set_max_lengths_shape, + const paddle::optional>& rotary_embs_shape, + const paddle::optional>& qkv_bias_shape, + const paddle::optional>& cache_k_quant_scales_shape, + const paddle::optional>& cache_v_quant_scales_shape, + const paddle::optional>& cache_k_dequant_scales_shape, + const paddle::optional>& cache_v_dequant_scales_shape, + const paddle::optional>& cache_k_zp_shape, + const paddle::optional>& cache_v_zp_shape, + const paddle::optional>& kv_signal_data_shape, + const paddle::optional>& q_norm_weight_shape, + const paddle::optional>& k_norm_weight_shape, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_shape}; +} + +std::vector DecoderWriteCacheWithRoPEInferDtype( + const paddle::DataType& qkv_dtype, + const paddle::DataType& key_cache_dtype, + const paddle::DataType& value_cache_dtype, + const paddle::DataType& seq_lens_encoder_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_this_time_dtype, + const paddle::DataType& batch_id_per_token_dtype, + const paddle::DataType& cu_seqlens_q_dtype, + const paddle::DataType& block_tables_dtype, + const paddle::DataType& set_max_lengths_dtype, + const paddle::optional& rotary_embs_dtype, + const paddle::optional& qkv_bias_dtype, + const paddle::optional& cache_k_quant_scales_dtype, 
+ const paddle::optional& cache_v_quant_scales_dtype, + const paddle::optional& cache_k_dequant_scales_dtype, + const paddle::optional& cache_v_dequant_scales_dtype, + const paddle::optional& cache_k_zp_dtype, + const paddle::optional& cache_v_zp_dtype, + const paddle::optional& kv_signal_data_dtype, + const paddle::optional& q_norm_weight_dtype, + const paddle::optional& k_norm_weight_dtype, + const float rms_norm_eps, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const bool speculate_decoder) { + return {qkv_dtype}; +} + +PD_BUILD_STATIC_OP(decoder_write_cache_with_rope) + .Inputs({"qkv", + "key_cache", + "value_cache", + "seq_lens_encoder", + "seq_lens_decoder", + "seq_lens_this_time", + "batch_id_per_token", + "cu_seqlens_q", + "block_tables", + "set_max_lengths", + paddle::Optional("rotary_embs"), + paddle::Optional("qkv_bias"), + paddle::Optional("cache_k_quant_scales"), + paddle::Optional("cache_v_quant_scales"), + paddle::Optional("cache_k_dequant_scales"), + paddle::Optional("cache_v_dequant_scales"), + paddle::Optional("cache_k_zp"), + paddle::Optional("cache_v_zp"), + paddle::Optional("kv_signal_data"), + paddle::Optional("q_norm_weight"), + paddle::Optional("k_norm_weight")}) + .Outputs({"qkv_out"}) + .SetInplaceMap({{"qkv", "qkv_out"}}) + .Attrs({ + "rms_norm_eps: float", + "cache_quant_type: std::string", + "use_neox_rotary_style: bool", + "rope_3d: bool", + "max_input_length: int", + "quant_max_bound: float", + "quant_min_bound: float", + "speculate_decoder: bool", + }) + .SetKernelFn(PD_KERNEL(DecoderWriteCacheWithRoPE)) + .SetInferShapeFn(PD_INFER_SHAPE(DecoderWriteCacheWithRoPEInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(DecoderWriteCacheWithRoPEInferDtype)); diff --git a/custom_ops/setup_ops.py b/custom_ops/setup_ops.py index db3a0037364..7072338bff8 100644 --- a/custom_ops/setup_ops.py +++ 
b/custom_ops/setup_ops.py @@ -554,6 +554,13 @@ def find_end_files(directory, end_str): sources += find_end_files(fp8_auto_gen_directory, ".cu") if cc >= 90 and nvcc_version >= 12.0: + # decode unified attention + os.system( + "python utils/auto_gen_template_attention.py --config gpu_ops/decode_unified_attention/template_config.json --output gpu_ops/decode_unified_attention/template_instantiation/autogen" + ) + sources += ["gpu_ops/decode_unified_attention.cu"] + sources += ["gpu_ops/decoder_write_cache_with_rope.cu"] + sources += find_end_files("gpu_ops/decode_unified_attention", ".cu") # Hopper optimized mla sources += find_end_files("gpu_ops/mla_attn", ".cu") sources += ["gpu_ops/flash_mask_attn/flash_mask_attn.cu"] diff --git a/custom_ops/utils/auto_gen_template_attention.py b/custom_ops/utils/auto_gen_template_attention.py new file mode 100644 index 00000000000..5658f6645e7 --- /dev/null +++ b/custom_ops/utils/auto_gen_template_attention.py @@ -0,0 +1,227 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Universal template instantiation generator - fully based on configuration file template instantiation generation.""" + +import argparse +import json +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + + +@dataclass +class TemplateConfig: + """Template configuration class.""" + + name: str # Function name + function_name: str # Actual function name + impl_file: str # Implementation file path + template_params: List[str] # Template parameter list (in order) + dispatch_params: Dict[str, List[Any]] # Dispatch parameters + data_types: Optional[List[Tuple[str, str, str]]] = None # Data type combinations (input_type, output_type, suffix) + max_instances_per_file: int = 60 # Maximum instances per file + file_prefix: str = "" # File prefix + function_signature: str = "" # Function signature template + + +class UniversalTemplateInstantiator: + """Universal template instantiator - fully based on configuration file.""" + + def __init__(self, config_file: str): + """Initialize the instantiator.""" + self.config_file = config_file + self.configs = self._load_configs() + + def _load_configs(self) -> Dict[str, TemplateConfig]: + """Load configuration file.""" + with open(self.config_file, "r", encoding="utf-8") as f: + config_data = json.load(f) + + configs = {} + for name, config_dict in config_data.items(): + config = TemplateConfig(**config_dict) + self._validate_config(config) + configs[name] = config + return configs + + def _validate_config(self, config: TemplateConfig): + """Validate configuration completeness.""" + for param_name in config.template_params: + if param_name not in config.dispatch_params: + raise ValueError(f"Template parameter '{param_name}' in '{config.name}' not found in dispatch_params") + + def _build_template_args(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Build template arguments.""" + template_args_parts = [] + + for param_name in 
config.template_params: + if param_name in params: + template_args_parts.append(str(params[param_name])) + + else: + raise ValueError(f"Template parameter '{param_name}' not found in dispatch_params") + + return f"<{', '.join(template_args_parts)}>" + + def _build_params_template_args(self, params: Dict[str, Any]) -> str: + """Build template arguments for AttentionParams.""" + params_template_args = [] + if "T" in params: + params_template_args.append(str(params["T"])) + else: + raise ValueError("Template parameter 'T' not found in dispatch_params") + + if "CacheT" in params: + params_template_args.append(str(params["CacheT"])) + else: + # C16 kernels use AttentionParams - T is repeated for both args + params_template_args.append(str(params["T"])) + + return f"<{', '.join(params_template_args)}>" + + def _generate_function_signature( + self, config: TemplateConfig, template_args: str, params_template_args: str + ) -> str: + """Generate function signature.""" + if config.function_signature: + signature = config.function_signature.format( + function_name=config.function_name, + template_args=template_args, + params_template_args=params_template_args, + ) + + return signature + else: + raise ValueError(f"Function signature not found for {config.name}") + + def _generate_file_header(self, config: TemplateConfig) -> str: + """Generate file header.""" + return f"""// Generated by autogen_template_instantiation.py - Do not edit. 
+ +#pragma once + +#include "../../{config.impl_file}" +""" + + def _generate_template_instantiation(self, config: TemplateConfig, params: Dict[str, Any]) -> str: + """Generate template instantiation.""" + template_args = self._build_template_args(config, params) + params_template_args = self._build_params_template_args(params) + return self._generate_function_signature(config, template_args, params_template_args) + + def _clean_output_directory(self, output_dir: str): + """Clean output directory before generating new files.""" + output_path = Path(output_dir) + if output_path.exists(): + shutil.rmtree(output_path) + output_path.mkdir(parents=True, exist_ok=True) + + def generate_combinations_for_type(self, config: TemplateConfig) -> List[Dict[str, Any]]: + """Generate parameter combinations for specific type.""" + combinations = [] + + def _generate_recursive( + params_dict: Dict[str, List[Any]], current_params: Dict[str, Any], param_names: List[str] + ): + if not param_names: + combinations.append(current_params.copy()) + return + + param_name = param_names[0] + for value in params_dict[param_name]: + current_params[param_name] = value + _generate_recursive(params_dict, current_params, param_names[1:]) + + _generate_recursive(config.dispatch_params, {}, list(config.dispatch_params.keys())) + + return combinations + + def split_combinations(self, combinations: List[Dict[str, Any]], max_per_file: int) -> List[List[Dict[str, Any]]]: + """Split combinations into multiple files.""" + chunks = [] + for i in range(0, len(combinations), max_per_file): + chunk = combinations[i : i + max_per_file] + chunks.append(chunk) + return chunks + + def generate_file_content( + self, + config: TemplateConfig, + file_index: int, + combinations: List[Dict[str, Any]], + ) -> str: + """Generate file content.""" + content = self._generate_file_header(config) + + for params in combinations: + content += self._generate_template_instantiation(config, params) + + return content + + def 
generate_for_function_type(self, function_name: str, output_dir: str): + """Generate template instantiation files for specific function type.""" + if function_name not in self.configs: + raise ValueError(f"Function type '{function_name}' not found in config") + + config = self.configs[function_name] + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + combinations = self.generate_combinations_for_type(config) + if combinations: + chunks = self.split_combinations(combinations, config.max_instances_per_file) + for i, chunk in enumerate(chunks): + filename = f"{config.file_prefix}_part_{i:02d}.cu" + filepath = output_path / filename + content = self.generate_file_content(config, i, chunk) + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + def generate_all(self, output_dir: str): + """Generate all configured function types.""" + self._clean_output_directory(output_dir) + for function_name in self.configs.keys(): + print(f"Generating template instantiations for {function_name}...") + self.generate_for_function_type(function_name, output_dir) + print(f"Completed generating {function_name} template instantiations.") + + +def main(): + """Parse the CLI arguments and generate every configured instantiation; exit with status 1 on failure.""" + parser = argparse.ArgumentParser(description="Universal template instantiation generator") + parser.add_argument( + "--config", + "-c", + type=str, + help="Configuration file path (JSON format)", + ) + parser.add_argument( + "--output", + "-o", + type=str, + help="Output directory", + ) + + args = parser.parse_args() + + try: + instantiator = UniversalTemplateInstantiator(args.config) + instantiator.generate_all(args.output) + except Exception as e: + raise SystemExit(f"Error: {e}") from e # message goes to stderr, exit status 1 so the build notices + + +if __name__ == "__main__": + main() diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 99c5ab776f7..5efcfbf6592 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -63,8 +63,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_ZMQ_SNDHWM": lambda:
os.getenv("FD_ZMQ_SNDHWM", 0), # cache kv quant params directory "FD_CACHE_PARAMS": lambda: os.getenv("FD_CACHE_PARAMS", "none"), - # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" - # and "MLA_ATTN" can be set currently. + # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN", "DECODE_UNIFIED_ATTN", + # "FLASH_ATTN" and "MLA_ATTN" can be set currently. "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), # Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), @@ -244,6 +244,8 @@ def _validate_split_kv_size(value: int) -> int: "FD_DETERMINISTIC_SPLIT_KV_SIZE": lambda: _validate_split_kv_size( int(os.getenv("FD_DETERMINISTIC_SPLIT_KV_SIZE", "16")) ), + # Whether to use unified attention kernel in mix + "USE_DECODE_UNIFIED_ATTENTION": lambda: bool(int(os.getenv("USE_DECODE_UNIFIED_ATTENTION", "0"))), # Enable determinism logging (print MD5 hashes and debug info) "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 516344a17f4..e32fb1209e1 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -124,6 +124,14 @@ class ForwardMeta: decoder_chunk_size_device: Optional[paddle.Tensor] = None + # Buffer for decode unified attention stage + decode_block_indices: Optional[paddle.Tensor] = None + decode_num_blocks: Optional[paddle.Tensor] = None + decode_chunk_size: Optional[paddle.Tensor] = None + decode_tmp_workspace: Optional[paddle.Tensor] = None + decode_tmp_m: Optional[paddle.Tensor] = None + decode_tmp_d: Optional[paddle.Tensor] = None + # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None # Sequence length of Encoder for ever batch diff --git 
a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 7efc3259fbc..dee4de4df76 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -17,6 +17,7 @@ from .attention_selecter import get_attention_backend from .base_attention_backend import AttentionBackend from .block_multihead_attn_backend import BlockAttentionBackend +from .decode_unified_attention_backend import DecodeUnifiedAttentionBackend from .dsa_attention_backend import DSAAttentionBackend from .flash_attn_backend import FlashAttentionBackend from .flash_mask_attn_backend import FlashMaskAttentionBackend @@ -36,4 +37,5 @@ "Attention", "PlasAttentionBackend", "FlashMaskAttentionBackend", + "DecodeUnifiedAttentionBackend", ] diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index eba781faae0..76de638bce6 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -73,6 +73,8 @@ def allocate_launch_related_buffer( num_heads, kv_num_heads, block_size, + head_dim=128, + dtype="bfloat16", ): # Initialize AttentionBackend buffers assert num_heads % kv_num_heads == 0 @@ -107,6 +109,28 @@ def allocate_launch_related_buffer( res["kv_batch_ids"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_tile_ids_per_batch"] = paddle.full([kv_max_tile_size], 0, dtype="int32") res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + + # Decode unified attention split ops buffers + if envs.USE_DECODE_UNIFIED_ATTENTION: + min_chunk_size = 512 + max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 + q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // q_tile_size + res["decode_block_indices"] = paddle.full( + [max_batch_size 
* kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + res["decode_num_blocks"] = paddle.full([1], 0, dtype="int32") + res["decode_chunk_size"] = paddle.full([1], 0, dtype="int32") + res["decode_tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], 0, dtype=dtype + ) + res["decode_tmp_m"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + res["decode_tmp_d"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + return res diff --git a/fastdeploy/model_executor/layers/attention/decode_unified_attention_backend.py b/fastdeploy/model_executor/layers/attention/decode_unified_attention_backend.py new file mode 100644 index 00000000000..71ae5ec76e2 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/decode_unified_attention_backend.py @@ -0,0 +1,363 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, List, Optional + +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + init_kv_signal_per_query, + init_signal_layerwise, + open_shm_and_get_meta_signal, +) +from fastdeploy.spec_decode import SpecMethod + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta + +from fastdeploy.config import FDConfig +from fastdeploy.model_executor.layers.attention.attention import Attention +from fastdeploy.model_executor.layers.attention.base_attention_backend import ( + AttentionBackend, + AttentionMetadata, +) +from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id + + +@dataclass +class DecodeUnifiedAttentionMetadata(AttentionMetadata): + """ + DecodeUnifiedAttentionMetadata + """ + + _dtype: paddle.dtype = paddle.bfloat16 + # pd_disaggregation + kv_signal_metadata: Optional[paddle.Tensor] = None + kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list) + + _fuse_kernel_compute_dtype: str = "bf16" + + +def allocate_decode_unified_related_buffer( + max_batch_size, + max_model_len, + encoder_block_shape_q, + decoder_block_shape_q, + decoder_step_token_num, + num_heads, + kv_num_heads, + block_size, + head_dim=128, + dtype="bfloat16", +): + # Initialize AttentionBackend buffers + assert num_heads % kv_num_heads == 0 + assert max_model_len % block_size == 0 + assert max_model_len % encoder_block_shape_q == 0 + group_size = num_heads // kv_num_heads + + res = {} + + # Decode unified attention split ops buffers + res["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + min_chunk_size = 512 + max_num_chunk = (max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 + q_tile_num = (decoder_step_token_num * group_size + q_tile_size - 1) // 
q_tile_size + res["decode_block_indices"] = paddle.full( + [max_batch_size * kv_num_heads * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + res["decode_num_blocks"] = paddle.full([1], 0, dtype="int32") + res["decode_chunk_size"] = paddle.full([1], 0, dtype="int32") + res["decode_tmp_workspace"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads * head_dim], 0, dtype=dtype + ) + res["decode_tmp_m"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + res["decode_tmp_d"] = paddle.full( + [max_batch_size * decoder_step_token_num, max_num_chunk, num_heads], 0, dtype="float32" + ) + + return res + + +class DecodeUnifiedAttentionBackend(AttentionBackend): + """ + DecodeUnifiedAttention backend implementation. + """ + + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: DecodeUnifiedAttentionMetadata + + def __init__( + self, + fd_config: FDConfig, + kv_num_heads: int, + num_heads: int, + head_dim: int, + encoder_block_shape_q: int = -1, + decoder_block_shape_q: int = -1, + ) -> None: + """ + DecodeUnifiedAttentionBackend __init__: cache model, cache and speculative-decoding settings from fd_config. + """ + super().__init__() + self.max_seq_len = fd_config.model_config.max_model_len + self.causal = getattr(fd_config.model_config, "causal", True) + + self.kv_num_heads = kv_num_heads + self.num_heads = num_heads + self.group_size: int = self.num_heads // self.kv_num_heads + self.head_dim = fd_config.model_config.head_dim # NOTE(review): the head_dim argument is not read here; confirm this is intended + self.attn_outputsize_tp = self.num_heads * self.head_dim + self.block_size = fd_config.cache_config.block_size + self.num_layers: int = fd_config.model_config.num_hidden_layers + + self.speculative_method = fd_config.speculative_config.method + self.use_speculate = self.speculative_method is not None + self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens + if not self.use_speculate: + self.speculate_max_draft_token_num = 0 + self.keep_pd_step_flag: bool =
fd_config.speculative_config.model_type == "mtp" + self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) + + self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode + + self.start_layer_index: int = fd_config.model_config.start_layer_index + + self.rank, self.device_id = init_rank_and_device_id(fd_config) + + self.rope_3d: bool = fd_config.enable_rope_3d_runtime + if fd_config.speculative_config.model_type != "main": + self.rope_3d = False + # Note(ZKK): here must be consistent with append_attn_backend.py + self.max_tokens_per_batch: int = self.speculate_max_draft_token_num + 1 + + def init_attention_metadata(self, forward_meta: ForwardMeta): + """Initialize attention metadata so that all layers in the forward pass can reuse it.""" + metadata = DecodeUnifiedAttentionMetadata() + metadata._dtype = paddle.get_default_dtype() + if metadata._dtype == "bfloat16": + metadata._fuse_kernel_compute_dtype = "bf16" + elif metadata._dtype == "float16": + metadata._fuse_kernel_compute_dtype = "fp16" + elif metadata._dtype == "float32": + metadata._fuse_kernel_compute_dtype = "fp32" + + # pd_disaggregation + metadata.kv_signal_data_list = [None] * self.num_layers + if self.pd_disaggregation_mode == "per_chunk": + if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run: + init_kv_signal_per_query( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_this_time, + forward_meta.seq_lens_decoder, + self.rank, + self.num_layers + self.num_layers_draft_model, + ) + elif self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_metadata = open_shm_and_get_meta_signal( + self.rank, int(self.device_id), self.keep_pd_step_flag + ) + + self.attention_metadata = metadata + + def get_attention_meta(self) -> AttentionMetadata: + """Return the metadata object stored by init_attention_metadata.""" + return self.attention_metadata + + def _get_identity_rotary_embs(self, original_rotary_embs: paddle.Tensor) -> paddle.Tensor: + """ + Create
identity rotary embeddings (cos=1, sin=0) that make RoPE a no-op. + + This is used when RoPE has already been applied externally (e.g., by PaddleFormers). + The identity transformation ensures: x * cos(0) + y * sin(0) = x, preserving the input. + + NOTE: Shape can change between prefill/decode, so we check if cached shape matches. + """ + # Check if we need to recreate (shape mismatch or not cached) + need_recreate = ( + not hasattr(self, "_identity_rotary_embs") + or self._identity_rotary_embs is None + or self._identity_rotary_embs.shape != original_rotary_embs.shape + ) + + if need_recreate: + # Create identity RoPE: cos=1, sin=0 + identity = paddle.zeros_like(original_rotary_embs) + identity[0] = 1.0 # cos = 1 + identity[1] = 0.0 # sin = 0 + self._identity_rotary_embs = identity + + return self._identity_rotary_embs + + def get_kv_cache_shape( + self, + max_num_blocks: int, + kv_cache_quant_type: Optional[str] = None, + ): + """ + Return (key_cache_shape, value_cache_shape) as two independent lists; the last dim is halved for "int4_zp". + """ + key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim] + if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp": + key_cache_shape[-1] = self.head_dim // 2 + value_cache_shape = list(key_cache_shape) # copy, so mutating one returned shape cannot change the other + return key_cache_shape, value_cache_shape + + def forward_mixed( + self, + q: paddle.Tensor, + k: paddle.Tensor, + v: paddle.Tensor, + qkv: paddle.Tensor, + compressed_kv: paddle.Tensor, + k_pe: paddle.Tensor, + layer: Attention, + forward_meta: ForwardMeta, + ) -> paddle.Tensor: + """ + forward_mixed + """ + metadata = self.attention_metadata + + rope_already_applied = getattr(forward_meta, "rope_already_applied", False) + if rope_already_applied and forward_meta.rotary_embs is not None: + forward_meta.rotary_embs = self._get_identity_rotary_embs(forward_meta.rotary_embs) + + norm_after_rope_in_kernel = not getattr(layer, "qk_norm_before_rope", False) + q_norm_weight = getattr(layer, "q_norm_weight", None) if norm_after_rope_in_kernel else None + k_norm_weight =
getattr(layer, "k_norm_weight", None) if norm_after_rope_in_kernel else None + + if self.rope_3d: + assert len(forward_meta.rotary_embs.shape) == 6 + else: + assert len(forward_meta.rotary_embs.shape) == 5 + if layer.use_neox_rotary_style: + assert forward_meta.rotary_embs.shape[0:4] == [2, 1, self.max_seq_len, 1] + # 128 is qwen3 + # 32 is glm + # 64 is gpt-oss + assert forward_meta.rotary_embs.shape[4] in [128, 32, 64] + + if self.pd_disaggregation_mode == "per_query": + metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( + metadata.kv_signal_metadata, + layer.layer_id + self.start_layer_index, + ) + cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none") + if cache_quant_type_str == "block_wise_fp8": + cache_k = forward_meta.caches[4 * layer.layer_id] + cache_v = forward_meta.caches[4 * layer.layer_id + 1] + cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2] + cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3] + else: + cache_k = forward_meta.caches[2 * layer.layer_id] + cache_v = forward_meta.caches[2 * layer.layer_id + 1] + cache_k_scales = getattr(layer, "cache_k_scale", None) + cache_v_scales = getattr(layer, "cache_v_scale", None) + + if layer.layer_id == 0: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + layer.qkv_bias, + cache_k_scales, + 
cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + res_decoder = paddle.empty( + [qkv.shape[0], self.num_heads * self.head_dim], + dtype=qkv.dtype, + ) + decode_unified_attention( + qkv_out, + cache_k, + cache_v, + forward_meta.decode_tmp_workspace, + forward_meta.decode_tmp_m, + forward_meta.decode_tmp_d, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + res_decoder, + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculate_max_draft_token_num + 1, + self.causal, + ) + return res_decoder diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 727a27d0f48..fe0216745ce 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ 
b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -43,6 +43,9 @@ ) from fastdeploy.model_executor.layers.attention.ops import ( append_attention, + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, get_attn_mask_q, get_block_shape_and_split_kv_block, gqa_rope_write_cache, @@ -274,6 +277,7 @@ def __init__( self.rope_3d = False # Note(ZKK): here must be consistent with append_attn_backend.py self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 1024)) + self.max_tokens_per_batch: int = self.speculate_max_draft_token_num + 1 if FLASH_ATTN_VERSION is None: init_flash_attn_version() @@ -416,6 +420,20 @@ def forward_mixed( ) else: forward_meta.attn_mask_q = None + if envs.USE_DECODE_UNIFIED_ATTENTION: + config_for_attention( + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + getattr(layer, "cache_quant_type_str", "none"), + self.group_size, + self.kv_num_heads, + self.max_tokens_per_batch, + ) use_fa_do_prefill = forward_meta.max_len_tensor_cpu[1].item() > 0 @@ -470,73 +488,148 @@ def forward_mixed( head_dim=self.head_dim, )[0].reshape([-1, self.attn_outputsize_tp]) - res_decoder = append_attention( - qkv, - cache_k, - cache_v, - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.batch_id_per_token, - forward_meta.cu_seqlens_q, - forward_meta.block_tables, - forward_meta.encoder_batch_ids, - forward_meta.encoder_tile_ids_per_batch, - forward_meta.encoder_num_blocks_x_cpu, - forward_meta.kv_batch_ids, - forward_meta.kv_tile_ids_per_batch, - forward_meta.kv_num_blocks_x_cpu, - forward_meta.decoder_batch_ids, - forward_meta.decoder_tile_ids_per_batch, - forward_meta.decoder_num_blocks_cpu, - forward_meta.max_len_tensor_cpu_decoder if use_fa_do_prefill 
else forward_meta.max_len_tensor_cpu, - forward_meta.rotary_embs, - forward_meta.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - cache_k_scales, - cache_v_scales, - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - forward_meta.attn_mask_offsets, - metadata.kv_signal_data_list[layer.layer_id], - q_norm_weight, - k_norm_weight, - getattr(layer, "sinks", None), - getattr(layer, "rms_norm_eps", 1e-6), - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - self.encoder_block_shape_q, - self.decoder_block_shape_q, - self.max_partition_size, - self.max_seq_len, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - ) - - if use_fa_do_prefill: - merge_prefill_decode_output( - res_encoder, - res_decoder, + if envs.USE_DECODE_UNIFIED_ATTENTION: + qkv_out = decoder_write_cache_with_rope( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + layer.qkv_bias, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "rms_norm_eps", 1e-6), + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 
0.0), + getattr(layer, "quant_min_bound", 0.0), + self.speculative_method is not None, + ) + if use_fa_do_prefill: + res_decoder = res_encoder + else: + res_decoder = paddle.empty( + [qkv.shape[0], self.num_heads * self.head_dim], + dtype=qkv.dtype, + ) + decode_unified_attention( + qkv_out, + cache_k, + cache_v, + forward_meta.decode_tmp_workspace, + forward_meta.decode_tmp_m, + forward_meta.decode_tmp_d, forward_meta.seq_lens_encoder, forward_meta.seq_lens_decoder, forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, forward_meta.cu_seqlens_q, - self.num_heads, - self.head_dim, + forward_meta.block_tables, + forward_meta.decode_block_indices, + forward_meta.decode_num_blocks, + forward_meta.decode_chunk_size, + forward_meta.max_len_tensor_cpu, + forward_meta.attn_mask, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + forward_meta.attn_mask_offsets, + getattr(layer, "sinks", None), + res_decoder, + getattr(layer, "cache_quant_type_str", "none"), + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), self.speculate_max_draft_token_num + 1, + self.causal, ) - return res_encoder - else: return res_decoder + else: + res_decoder = append_attention( + qkv, + cache_k, + cache_v, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.batch_id_per_token, + forward_meta.cu_seqlens_q, + forward_meta.block_tables, + forward_meta.encoder_batch_ids, + forward_meta.encoder_tile_ids_per_batch, + forward_meta.encoder_num_blocks_x_cpu, + forward_meta.kv_batch_ids, + forward_meta.kv_tile_ids_per_batch, + forward_meta.kv_num_blocks_x_cpu, + forward_meta.decoder_batch_ids, + forward_meta.decoder_tile_ids_per_batch, + forward_meta.decoder_num_blocks_cpu, + forward_meta.max_len_tensor_cpu_decoder if 
use_fa_do_prefill else forward_meta.max_len_tensor_cpu, + forward_meta.rotary_embs, + forward_meta.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + cache_k_scales, + cache_v_scales, + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + forward_meta.attn_mask_offsets, + metadata.kv_signal_data_list[layer.layer_id], + q_norm_weight, + k_norm_weight, + getattr(layer, "sinks", None), + getattr(layer, "rms_norm_eps", 1e-6), + metadata._fuse_kernel_compute_dtype, + getattr(layer, "cache_quant_type_str", "none"), + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + getattr(layer, "quant_max_bound", 0.0), + getattr(layer, "quant_min_bound", 0.0), + getattr(layer, "out_scale", -1.0), + self.encoder_block_shape_q, + self.decoder_block_shape_q, + self.max_partition_size, + self.max_seq_len, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None, + ) + + if use_fa_do_prefill: + merge_prefill_decode_output( + res_encoder, + res_decoder, + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.cu_seqlens_q, + self.num_heads, + self.head_dim, + self.speculate_max_draft_token_num + 1, + ) + return res_encoder + else: + return res_decoder diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py index e0175573fa3..d5d6c45afa7 100644 --- a/fastdeploy/model_executor/layers/attention/ops/__init__.py +++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py @@ -15,6 +15,9 @@ """ from .append_attention import append_attention, append_attention_with_output +from .config_for_attention import config_for_attention +from .decode_unified_attention import decode_unified_attention +from .decoder_write_cache_with_rope import 
decoder_write_cache_with_rope from .flash_attn_v4 import flash_attn_v4 from .flash_mask_attention import flash_mask_attention from .get_attn_mask_q import get_attn_mask_q @@ -37,4 +40,7 @@ "flash_attn_v4", "flash_mask_attention", "get_attn_mask_q", + "config_for_attention", + "decoder_write_cache_with_rope", + "decode_unified_attention", ] diff --git a/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py new file mode 100644 index 00000000000..d8226aad4b1 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/config_for_attention.py @@ -0,0 +1,58 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + config_for_attention as config_for_attention_cuda, + ) + + +def config_for_attention( + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + max_len_tensor_cpu: paddle.Tensor, + cache_quant_type: str = "none", + group_size: int = 1, + kv_num_heads: int = 1, + max_tokens_per_batch: int = 1, +): + """ + Forward all arguments to the CUDA config_for_attention op (its C++ declaration marks block_indices, num_blocks, chunk_size and max_len_tensor_cpu as in-place); returns nothing and raises NotImplementedError on non-CUDA platforms. + """ + if current_platform.is_cuda(): + config_for_attention_cuda( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + block_indices, + num_blocks, + chunk_size, + max_len_tensor_cpu, + cache_quant_type, + group_size, + kv_num_heads, + max_tokens_per_batch, + ) + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py new file mode 100644 index 00000000000..fedfc33dc7c --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decode_unified_attention.py @@ -0,0 +1,105 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decode_unified_attention as decode_unified_attention_cuda, + ) + + +def decode_unified_attention( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + tmp_workspace: paddle.Tensor, + tmp_m: paddle.Tensor, + tmp_d: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + block_indices: paddle.Tensor, + num_blocks: paddle.Tensor, + chunk_size: paddle.Tensor, + set_max_lengths: paddle.Tensor, + attn_mask: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + mask_offset: Optional[paddle.Tensor] = None, + sinks: Optional[paddle.Tensor] = None, + fmha_out: Optional[paddle.Tensor] = None, # NOTE(review): the C++ declaration takes a non-optional paddle::Tensor&; confirm None is accepted + cache_quant_type: str = "none", + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + max_tokens_per_batch: int = 1, + causal: bool = True, + sliding_window: int = 0, +) -> paddle.Tensor: + """ + Forward all arguments to the CUDA decode_unified_attention op and return its result; raises NotImplementedError on non-CUDA platforms. + """ + if current_platform.is_cuda(): + out = decode_unified_attention_cuda( + qkv, + key_cache, + value_cache, + tmp_workspace, + tmp_m, + tmp_d, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + block_indices, + num_blocks, + chunk_size, + set_max_lengths, + attn_mask, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + mask_offset, + sinks, + fmha_out, +
cache_quant_type, + max_input_length, + quant_max_bound, + quant_min_bound, + max_tokens_per_batch, + causal, + sliding_window, + ) + return out + else: + raise NotImplementedError diff --git a/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py new file mode 100644 index 00000000000..b10f6cd1bf6 --- /dev/null +++ b/fastdeploy/model_executor/layers/attention/ops/decoder_write_cache_with_rope.py @@ -0,0 +1,97 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +from typing import Optional + +import paddle + +from fastdeploy.platforms import current_platform + +if current_platform.is_cuda(): + from fastdeploy.model_executor.ops.gpu import ( + decoder_write_cache_with_rope as decoder_write_cache_with_rope_cuda, + ) + + +def decoder_write_cache_with_rope( + qkv: paddle.Tensor, + key_cache: paddle.Tensor, + value_cache: paddle.Tensor, + seq_lens_encoder: paddle.Tensor, + seq_lens_decoder: paddle.Tensor, + seq_lens_this_time: paddle.Tensor, + batch_id_per_token: paddle.Tensor, + cu_seqlens_q: paddle.Tensor, + block_tables: paddle.Tensor, + set_max_lengths: paddle.Tensor, + rotary_embs: Optional[paddle.Tensor] = None, + qkv_bias: Optional[paddle.Tensor] = None, + k_quant_scale: Optional[paddle.Tensor] = None, + v_quant_scale: Optional[paddle.Tensor] = None, + k_dequant_scale: Optional[paddle.Tensor] = None, + v_dequant_scale: Optional[paddle.Tensor] = None, + cache_k_zp: Optional[paddle.Tensor] = None, + cache_v_zp: Optional[paddle.Tensor] = None, + kv_signal_data: Optional[paddle.Tensor] = None, + q_norm_weight: Optional[paddle.Tensor] = None, + k_norm_weight: Optional[paddle.Tensor] = None, + rms_norm_eps: float = 1e-6, + cache_quant_type: str = "none", + use_neox_rotary_style: bool = False, + rope_3d: bool = False, + max_input_length: int = 0, + quant_max_bound: float = 0.0, + quant_min_bound: float = 0.0, + speculate_decoder: bool = False, +) -> paddle.Tensor: + """ + Run the CUDA decoder_write_cache_with_rope op and return its result; raises NotImplementedError on non-CUDA platforms. + """ + if current_platform.is_cuda(): + qkv_out = decoder_write_cache_with_rope_cuda( + qkv, + key_cache, + value_cache, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + block_tables, + set_max_lengths, + rotary_embs, + qkv_bias, + k_quant_scale, + v_quant_scale, + k_dequant_scale, + v_dequant_scale, + cache_k_zp, + cache_v_zp, + kv_signal_data, + q_norm_weight, + k_norm_weight, + rms_norm_eps, + cache_quant_type, + use_neox_rotary_style, + rope_3d, + max_input_length, + 
quant_max_bound, + quant_min_bound, + speculate_decoder, + ) + return qkv_out + else: + raise NotImplementedError diff --git a/fastdeploy/platforms/base.py b/fastdeploy/platforms/base.py index bb30663492a..b2eceb0aeb8 100644 --- a/fastdeploy/platforms/base.py +++ b/fastdeploy/platforms/base.py @@ -30,6 +30,7 @@ class _Backend(enum.Enum): PLAS_ATTN = enum.auto() HPU_ATTN = enum.auto() FLASH_MASK_ATTN = enum.auto() + DECODE_UNIFIED_ATTN = enum.auto() class Platform: diff --git a/fastdeploy/platforms/cuda.py b/fastdeploy/platforms/cuda.py index acdf40d8fdb..e9a3cb61574 100644 --- a/fastdeploy/platforms/cuda.py +++ b/fastdeploy/platforms/cuda.py @@ -73,8 +73,11 @@ def get_attention_backend_cls(cls, selected_backend: _Backend): elif selected_backend == _Backend.FLASH_MASK_ATTN: logger.info("Using FLASH MASK ATTN backend.") return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend" + elif selected_backend == _Backend.DECODE_UNIFIED_ATTN: + logger.info("Using DECODE UNIFIED ATTN backend.") + return "fastdeploy.model_executor.layers.attention.DecodeUnifiedAttentionBackend" else: raise ValueError( "Invalid attention backend you specified.\n" - "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN] in cuda place." + "Now only support [NATIVE_ATTN, MLA_ATTN, APPEND_ATTN, DECODE_UNIFIED_ATTN, FLASH_ATTN] in cuda place." 
) diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index cf614ad399e..9cf9ee1785a 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -23,6 +23,7 @@ import paddle from paddleformers.utils.log import logger +from fastdeploy import envs from fastdeploy.engine.request import Request, RequestType from fastdeploy.inter_communicator import IPCSignal from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -364,42 +365,68 @@ def _initialize_attn_backend( encoder_block_shape_q = 64 decoder_block_shape_q = 16 - self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"]) - self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( - self.target_model_inputs["decoder_tile_ids_per_batch"] - ) - if current_platform.is_xpu() or current_platform.is_maca(): - self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( - self.target_model_inputs["decoder_num_blocks_cpu"] - ).cpu() - else: - self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( - self.target_model_inputs["decoder_num_blocks_cpu"] - ).pin_memory() - self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like( - self.target_model_inputs["decoder_num_blocks_device"] - ) - self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like( - self.target_model_inputs["decoder_chunk_size_device"] - ) self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like( self.target_model_inputs["max_len_tensor_cpu"] ).cpu() + if envs.FD_ATTENTION_BACKEND == "DECODE_UNIFIED_ATTN": + self.model_inputs["decoder_batch_ids"] = None + self.model_inputs["decoder_tile_ids_per_batch"] = None + self.model_inputs["decoder_num_blocks_cpu"] = None + self.model_inputs["decoder_num_blocks_device"] = None + self.model_inputs["decoder_chunk_size_device"] = None + self.model_inputs["encoder_batch_ids"] = None + self.model_inputs["encoder_tile_ids_per_batch"] = None + 
self.model_inputs["encoder_num_blocks_x_cpu"] = None + self.model_inputs["kv_batch_ids"] = None + self.model_inputs["kv_tile_ids_per_batch"] = None + self.model_inputs["kv_num_blocks_x_cpu"] = None + else: + self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["decoder_batch_ids"]) + self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like( + self.target_model_inputs["decoder_tile_ids_per_batch"] + ) + if current_platform.is_xpu() or current_platform.is_maca(): + self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_cpu"] + ).cpu() + else: + self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_cpu"] + ).pin_memory() + self.model_inputs["decoder_num_blocks_device"] = paddle.zeros_like( + self.target_model_inputs["decoder_num_blocks_device"] + ) + self.model_inputs["decoder_chunk_size_device"] = paddle.zeros_like( + self.target_model_inputs["decoder_chunk_size_device"] + ) + self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"]) + self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like( + self.target_model_inputs["encoder_tile_ids_per_batch"] + ) + self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like( + self.target_model_inputs["encoder_num_blocks_x_cpu"] + ).cpu() + self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"]) + self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like( + self.target_model_inputs["kv_tile_ids_per_batch"] + ) + self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like( + self.target_model_inputs["kv_num_blocks_x_cpu"] + ).cpu() + # Decode attention split ops buffers - self.model_inputs["encoder_batch_ids"] = paddle.zeros_like(self.target_model_inputs["encoder_batch_ids"]) - self.model_inputs["encoder_tile_ids_per_batch"] = paddle.zeros_like( - 
self.target_model_inputs["encoder_tile_ids_per_batch"] - ) - self.model_inputs["encoder_num_blocks_x_cpu"] = paddle.zeros_like( - self.target_model_inputs["encoder_num_blocks_x_cpu"] - ).cpu() - self.model_inputs["kv_batch_ids"] = paddle.zeros_like(self.target_model_inputs["kv_batch_ids"]) - self.model_inputs["kv_tile_ids_per_batch"] = paddle.zeros_like( - self.target_model_inputs["kv_tile_ids_per_batch"] - ) - self.model_inputs["kv_num_blocks_x_cpu"] = paddle.zeros_like( - self.target_model_inputs["kv_num_blocks_x_cpu"] - ).cpu() + if ( + "decode_block_indices" in self.target_model_inputs + and self.target_model_inputs["decode_block_indices"] is not None + ): + self.model_inputs["decode_block_indices"] = paddle.zeros_like( + self.target_model_inputs["decode_block_indices"] + ) + self.model_inputs["decode_num_blocks"] = paddle.zeros_like(self.target_model_inputs["decode_num_blocks"]) + self.model_inputs["decode_chunk_size"] = paddle.zeros_like(self.target_model_inputs["decode_chunk_size"]) + self.model_inputs["decode_tmp_workspace"] = self.target_model_inputs["decode_tmp_workspace"] + self.model_inputs["decode_tmp_m"] = self.target_model_inputs["decode_tmp_m"] + self.model_inputs["decode_tmp_d"] = self.target_model_inputs["decode_tmp_d"] # Get the attention backend attn_cls = get_attention_backend() diff --git a/fastdeploy/spec_decode/mtp_cuda.py b/fastdeploy/spec_decode/mtp_cuda.py index eb74b371069..d40b9f9229e 100644 --- a/fastdeploy/spec_decode/mtp_cuda.py +++ b/fastdeploy/spec_decode/mtp_cuda.py @@ -122,6 +122,12 @@ def _initialize_forward_meta(self, step_use_cudagraph: bool = False, is_dummy_ru kv_tile_ids_per_batch=self.model_inputs["kv_tile_ids_per_batch"], kv_num_blocks_x_cpu=self.model_inputs["kv_num_blocks_x_cpu"], attn_mask_offsets=self.model_inputs["attn_mask_offsets"] if self.use_attn_mask_offset else None, + decode_block_indices=self.model_inputs["decode_block_indices"], + decode_num_blocks=self.model_inputs["decode_num_blocks"], + 
decode_chunk_size=self.model_inputs["decode_chunk_size"], + decode_tmp_workspace=self.model_inputs["decode_tmp_workspace"], + decode_tmp_m=self.model_inputs["decode_tmp_m"], + decode_tmp_d=self.model_inputs["decode_tmp_d"], ) # Initialzie attention meta data diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 3e06c13927f..65c1e43dc13 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -45,6 +45,9 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, ) +from fastdeploy.model_executor.layers.attention.decode_unified_attention_backend import ( + allocate_decode_unified_related_buffer, +) from fastdeploy.model_executor.layers.attention.dsa_attention_backend import ( DSAAttentionBackend, ) @@ -1481,6 +1484,15 @@ def initialize_forward_meta(self, is_dummy_or_profile_run=False): routing_replay_table=routing_replay_table, ) + # Decode attention split ops buffers (assigned after construction due to ForwardMeta __getattr__) + if "decode_block_indices" in self.share_inputs: + self.forward_meta.decode_block_indices = self.share_inputs["decode_block_indices"] + self.forward_meta.decode_num_blocks = self.share_inputs["decode_num_blocks"] + self.forward_meta.decode_chunk_size = self.share_inputs["decode_chunk_size"] + self.forward_meta.decode_tmp_workspace = self.share_inputs["decode_tmp_workspace"] + self.forward_meta.decode_tmp_m = self.share_inputs["decode_tmp_m"] + self.forward_meta.decode_tmp_d = self.share_inputs["decode_tmp_d"] + dist_status = self.collect_distributed_status() if_only_decode = dist_status.if_only_decode @@ -1727,8 +1739,15 @@ def _initialize_attn_backend(self) -> None: num_heads=num_heads, kv_num_heads=self.model_config.kv_num_heads, block_size=self.fd_config.cache_config.block_size, + head_dim=head_dim, + dtype=self.model_config.dtype, ) - res_buffer = allocate_launch_related_buffer(**buffer_kwargs) + + if 
envs.FD_ATTENTION_BACKEND == "DECODE_UNIFIED_ATTN": + res_buffer = allocate_decode_unified_related_buffer(**buffer_kwargs) + else: + res_buffer = allocate_launch_related_buffer(**buffer_kwargs) + self.share_inputs.update(res_buffer) if int(os.getenv("USE_TBO", "0")) == 1: diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 56dc4ae9aa8..5be2b229291 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -208,6 +208,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode unified attention split ops buffers (initialized by _initialize_attn_backend) + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Initialize thinking related buffers self.enable_thinking = paddle.full(shape=[max_num_seqs, 1], fill_value=True, dtype="bool") @@ -857,6 +864,13 @@ def init_share_inputs(self): self.kv_batch_ids = None self.kv_tile_ids_per_batch = None self.kv_num_blocks_x_cpu = None # CPU + # Decode attention split ops buffers + self.decode_block_indices = None + self.decode_num_blocks = None + self.decode_chunk_size = None + self.decode_tmp_workspace = None + self.decode_tmp_m = None + self.decode_tmp_d = None # Input tokens self.draft_tokens = paddle.full( diff --git a/tests/e2e/test_ernie_03b_pd_decode_unified_attention.py b/tests/e2e/test_ernie_03b_pd_decode_unified_attention.py new file mode 100644 index 00000000000..b689ca28b35 --- /dev/null +++ b/tests/e2e/test_ernie_03b_pd_decode_unified_attention.py @@ -0,0 +1,422 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Test splitwise deployment: use local_scheduler + router, +# with ENABLE_V1_KVCACHE_SCHEDULER set to 1, and use ipc to transfer cache. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + get_registered_number, +) + +# Read ports from environment variables; use default values if not set +FD_CONNECTOR_PORT = int(os.getenv("FD_CONNECTOR_PORT", 8433)) +FD_ROUTER_PORT = int(os.getenv("FD_ROUTER_PORT", 8533)) + +# List of ports to clean before and after tests +PORTS_TO_CLEAN = [ + FD_API_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + FD_CACHE_QUEUE_PORT, + FD_CONNECTOR_PORT, + FD_API_PORT + 1, + FD_ENGINE_QUEUE_PORT + 1, + FD_METRICS_PORT + 1, + FD_CACHE_QUEUE_PORT + 1, + FD_CONNECTOR_PORT + 1, + FD_ROUTER_PORT, +] + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the router, prefill and decode servers as subprocesses + - Waits for prefill and decode to register with the router (up to 300 seconds) + - Tears down the servers after all tests finish + """ + print("Pre-test port cleanup...") + clean(PORTS_TO_CLEAN) + + print("log dir clean ") + if os.path.exists("log_router") and os.path.isdir("log_router"): + shutil.rmtree("log_router") + if os.path.exists("log_prefill") and os.path.isdir("log_prefill"): + shutil.rmtree("log_prefill") + if os.path.exists("log_decode") and 
os.path.isdir("log_decode"): + shutil.rmtree("log_decode") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ERNIE-4.5-0.3B-Paddle") + else: + model_path = "baidu/ERNIE-4.5-0.3B-Paddle" + print(f"model_path: {model_path}") + + base_log_dir = os.getenv("FD_LOG_DIR", "log") + + # router + print("start router...") + env_router = os.environ.copy() + env_router["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_router") + router_log_path = "router.log" + + router_cmd = [ + sys.executable, + "-m", + "fastdeploy.router.launch", + "--port", + str(FD_ROUTER_PORT), + "--splitwise", + ] + + with open(router_log_path, "w") as logfile: + process_router = subprocess.Popen( + router_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_router, + ) + + # prefill实例 + print("start prefill...") + env_prefill = os.environ.copy() + env_prefill["CUDA_VISIBLE_DEVICES"] = "0" + env_prefill["FD_ATTENTION_BACKEND"] = "FLASH_ATTN" + env_prefill["FLAGS_flash_attn_version"] = "3" + env_prefill["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_prefill") + prefill_log_path = "prefill.log" + prefill_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "prefill", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(prefill_log_path, "w") as logfile: + process_prefill = subprocess.Popen( + prefill_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + 
start_new_session=True, # Enables killing full group via os.killpg + env=env_prefill, + ) + time.sleep(1) + + # decode实例 + print("start decode...") + env_decode = os.environ.copy() + env_decode["CUDA_VISIBLE_DEVICES"] = "1" + env_decode["FD_ATTENTION_BACKEND"] = "DECODE_UNIFIED_ATTN" + env_decode["FD_LOG_DIR"] = os.path.join(base_log_dir, "log_decode") + decode_log_path = "decode.log" + decode_cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT + 1), + "--tensor-parallel-size", + "1", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT + 1), + "--metrics-port", + str(FD_METRICS_PORT + 1), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT + 1), + "--max-model-len", + "8192", + "--max-num-seqs", + "20", + "--quantization", + "wint8", + "--splitwise-role", + "decode", + "--cache-transfer-protocol", + "ipc", + "--pd-comm-port", + str(FD_CONNECTOR_PORT + 1), + "--router", + f"0.0.0.0:{FD_ROUTER_PORT}", + ] + + # Start subprocess in new process group + with open(decode_log_path, "w") as logfile: + process_decode = subprocess.Popen( + decode_cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + env=env_decode, + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(60): + registered_numbers = get_registered_number(f"0.0.0.0:{FD_ROUTER_PORT}") + if registered_numbers["prefill"] >= 1 and registered_numbers["decode"] >= 1: + print("Prefill and decode servers are both online") + break + time.sleep(5) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. 
Cleaning up...") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process_router.pid, signal.SIGTERM) + os.killpg(process_prefill.pid, signal.SIGTERM) + os.killpg(process_decode.pid, signal.SIGTERM) + clean(PORTS_TO_CLEAN) + print(f"Prefill server (pid={process_prefill.pid}) terminated") + print(f"Decode server (pid={process_decode.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_ROUTER_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. 
+ """ + return {"Content-Type": "application/json"} + + +def test_metrics_config(metrics_url): + timeout = 600 + url = metrics_url.replace("metrics", "config-info") + res = requests.get(url, timeout=timeout) + assert res.status_code == 200 + + +def send_request(url, payload, timeout=60): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] 
>= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + print("Decode Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], 
"completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + print("Decode Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" diff --git a/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py new file mode 100644 index 00000000000..0083d70e769 --- /dev/null +++ b/tests/e2e/test_ernie_21b_mtp_decode_unified_attention.py @@ -0,0 +1,381 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import signal +import subprocess +import sys +import time + +import pytest +import requests +from utils.serving_utils import ( + FD_API_PORT, + FD_CACHE_QUEUE_PORT, + FD_ENGINE_QUEUE_PORT, + FD_METRICS_PORT, + clean, + is_port_open, +) + +os.environ["FD_ATTENTION_BACKEND"] = "FLASH_ATTN" +os.environ["FLAGS_flash_attn_version"] = "3" +os.environ["USE_DECODE_UNIFIED_ATTENTION"] = "1" + + +@pytest.fixture(scope="session", autouse=True) +def setup_and_run_server(): + """ + Pytest fixture that runs once per test session: + - Cleans ports before tests + - Starts the API server as a subprocess + - Waits for server port to open (up to 300 seconds) + - Tears down server after all tests finish + """ + print("Pre-test port cleanup...") + clean() + + print("log dir clean ") + if os.path.exists("log") and os.path.isdir("log"): + shutil.rmtree("log") + + base_path = os.getenv("MODEL_PATH") + if base_path: + model_path = os.path.join(base_path, "ernie-4_5-21b-a3b-bf16-paddle") + else: + model_path = "./ernie-4_5-21b-a3b-bf16-paddle" + mtp_model_path = os.path.join(model_path, "mtp") + speculative_config = {"method": "mtp", "num_speculative_tokens": 1, "model": mtp_model_path} + + log_path = "server.log" + cmd = [ + sys.executable, + "-m", + "fastdeploy.entrypoints.openai.api_server", + "--model", + model_path, + "--port", + str(FD_API_PORT), + "--tensor-parallel-size", + "2", + "--engine-worker-queue-port", + str(FD_ENGINE_QUEUE_PORT), + "--metrics-port", + str(FD_METRICS_PORT), + "--cache-queue-port", + str(FD_CACHE_QUEUE_PORT), + "--max-model-len", + "32768", + "--max-num-seqs", + "128", + "--quantization", + "wint4", + "--speculative-config", + json.dumps(speculative_config), + "--graph-optimization-config", + '{"use_cudagraph":true, "use_unique_memory_pool":true, "draft_model_use_cudagraph":true}', + ] + + # Start subprocess in new 
process group + # 清除log目录 + if os.path.exists("log"): + shutil.rmtree("log") + with open(log_path, "w") as logfile: + process = subprocess.Popen( + cmd, + stdout=logfile, + stderr=subprocess.STDOUT, + start_new_session=True, # Enables killing full group via os.killpg + ) + + # Wait up to 300 seconds for API server to be ready + for _ in range(300): + if is_port_open("127.0.0.1", FD_API_PORT): + print(f"Server is up on port {FD_API_PORT}") + break + time.sleep(1) + else: + print("[TIMEOUT] API server failed to start in 5 minutes. Cleaning up...") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + except Exception as e: + print(f"Failed to kill process group: {e}") + raise RuntimeError(f"API server did not start on port {FD_API_PORT}") + + yield # Run tests + + print("\n===== Post-test server cleanup... =====") + try: + os.killpg(process.pid, signal.SIGTERM) + clean() + print(f"server (pid={process.pid}) terminated") + except Exception as e: + print(f"Failed to terminate API server: {e}") + + +@pytest.fixture(scope="session") +def api_url(request): + """ + Returns the API endpoint URL for chat completions. + """ + return f"http://0.0.0.0:{FD_API_PORT}/v1/chat/completions" + + +@pytest.fixture(scope="session") +def metrics_url(request): + """ + Returns the metrics endpoint URL. + """ + return f"http://0.0.0.0:{FD_METRICS_PORT}/metrics" + + +@pytest.fixture +def headers(): + """ + Returns common HTTP request headers. 
+ """ + return {"Content-Type": "application/json"} + + +def send_request(url, payload, timeout=60): + """ + 发送请求到指定的URL,并返回响应结果。 + """ + headers = { + "Content-Type": "application/json", + } + + try: + res = requests.post(url, headers=headers, json=payload, timeout=timeout) + print("🟢 接收响应中...\n") + return res + except requests.exceptions.Timeout: + print(f"❌ 请求超时(超过 {timeout} 秒)") + return None + except requests.exceptions.RequestException as e: + print(f"❌ 请求失败:{e}") + return None + + +def get_stream_chunks(response): + """解析流式返回,生成chunk List[dict]""" + chunks = [] + + if response.status_code == 200: + for line in response.iter_lines(decode_unicode=True): + if line: + if line.startswith("data: "): + line = line[len("data: ") :] + + if line.strip() == "[DONE]": + break + + try: + chunk = json.loads(line) + chunks.append(chunk) + except Exception as e: + print(f"解析失败: {e}, 行内容: {line}") + else: + print(f"请求失败,状态码: {response.status_code}") + print("返回内容:", response.text) + + return chunks + + +def test_chat_usage_stream(api_url): + """测试流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert 
usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_chat_usage_non_stream(api_url): + """测试非流式chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "牛顿的三大运动定律是什么?"}, + ], + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["message"]["content"] + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_stream(api_url): + """测试流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result = "".join([x["choices"][0]["text"] for x in chunks[:-1]]) + # print("Prefill Response:", result) + assert result != "", "结果为空" + usage = chunks[-1]["usage"] + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, 
"total_tokens不等于prompt_tokens + completion_tokens" + + +def test_non_chat_usage_non_stream(api_url): + """测试非流式非chat usage""" + payload = { + "model": "default", + "temperature": 0, + "top_p": 0, + "seed": 33, + "prompt": "牛顿的三大运动定律是什么?", + "max_tokens": 50, + "stream": False, + "metadata": {"min_tokens": 10}, + } + api_url = api_url.replace("chat/completions", "completions") + + response = send_request(url=api_url, payload=payload).json() + usage = response["usage"] + result = response["choices"][0]["text"] + # print("Prefill Response:", result) + assert result != "", "结果为空" + total_tokens = usage["completion_tokens"] + usage["prompt_tokens"] + assert payload["max_tokens"] >= usage["completion_tokens"], "completion_tokens大于max_tokens" + assert payload["metadata"]["min_tokens"] <= usage["completion_tokens"], "completion_tokens小于min_tokens" + assert usage["total_tokens"] == total_tokens, "total_tokens不等于prompt_tokens + completion_tokens" + + +def test_mtp_accept_ratio(api_url): + """测试mtp接受率""" + payload = { + "model": "default", + "messages": [ + { + "role": "user", + "content": "国外项目风险管理研究起步较早,理论体系成熟。早期研究集中于保险与金融领域,后逐步扩展至工程项目、" + "公共管理等多领域。在理论层面,COSO《企业风险管理——整合框架》和ISO31000标准为风险管理提供了系统性" + "指导,强调风险识别、评估、应对与监控的全流程管理。风险识别方法包括故障树分析、事件树分析等;风险评估" + "则广泛应用VaR模型、蒙特卡洛模拟等量化工具。应对策略涵盖规避、转移、减轻和接受等,并衍生出风险共享、" + "升级等复杂策略。此外,组织文化、管理层支持等因素对风险管理有效性影响显著。近年来,随着科技发展," + "人工智能、大数据等技术被引入风险管理,推动其向智能化、自动化方向发展。请介绍一下国外关于项目风险管理" + "的文献研究综述,300字以内", + }, + ], + "stream": True, + "stream_options": {"include_usage": True, "continuous_usage_stats": True}, + "temperature": 0, + "seed": 23, + "top_p": 0, + } + + print("fastdeploy answer is :") + + try: + # TODO: 第一次和第二次存在diff,后面正常,暂时多请求一次 + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + for idx, chunk in enumerate(chunks): + print(f"\nchunk[{idx}]:\n{json.dumps(chunk, ensure_ascii=False)}") + result = 
"".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics = chunks[-2]["choices"][0]["speculate_metrics"] + except Exception as e: + print(f"解析失败: {e}") + print("\nresult:\n", result) + + baseline = ( + "国外项目风险管理研究起步早、体系成熟。" + "早期聚焦保险与金融领域,后拓展至多领域。" + "理论层面,COSO《企业风险管理——整合框架》及ISO31000标准提供系统性指导," + "强调全流程管理。" + "风险识别方法多样,如故障树、事件树分析;" + "评估常用VaR模型、蒙特卡洛模拟等量化工具。" + "应对策略丰富,涵盖规避、转移等基本策略及风险共享、升级等复杂策略。" + "组织文化与管理层支持对风险管理有效性影响大。" + "近年来,科技发展促使人工智能、大数据等融入," + "推动风险管理向智能化、自动化迈进 。" + ) + + baseline_ratio = { + "accepted_tokens": 130, + "rejected_tokens": 20, + "accept_ratio": 0.42307692307692313, + "average_accept_length": 1.7333333333333334, + "accepted_tokens_per_head": [75, 55], + "accept_ratio_per_head": [0.7333333333333333], + } + + response = send_request(url=api_url, payload=payload) + chunks = get_stream_chunks(response) + result_2 = "".join([x["choices"][0]["delta"]["content"] for x in chunks[:-1]]) + speculate_metrics_2 = chunks[-2]["choices"][0]["speculate_metrics"] + print("chunks:", chunks[-2]) + print("baseline", speculate_metrics) + print("speculate_metrics_2", speculate_metrics_2) + assert result_2 == baseline, f"与baseline存在diff,result_2: {result}\n baseline: {baseline}" + assert speculate_metrics_2 == baseline_ratio, ( + f"speculate_metrics存在diff," f"speculate_metrics_2: {speculate_metrics_2}\n " f"baseline: {baseline_ratio}" + ) + assert speculate_metrics_2["accept_ratio"] > 0, "accept_ratio异常" + prompt_tokens = chunks[-1]["usage"]["prompt_tokens"] + cached_tokens = chunks[-1]["usage"]["prompt_tokens_details"]["cached_tokens"] + assert cached_tokens == prompt_tokens // 64 * 64, "cached_tokens数量有问题" diff --git a/tests/operators/attention/test_decode_unified_attention_c16.py b/tests/operators/attention/test_decode_unified_attention_c16.py new file mode 100644 index 00000000000..7582439f7cf --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c16.py @@ -0,0 +1,814 @@ +# Copyright (c) 2025 PaddlePaddle 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import random +import unittest + +import numpy as np +import paddle + +from fastdeploy.model_executor.layers.attention.ops import ( + append_attention as append_attention_op, +) +from fastdeploy.model_executor.layers.attention.ops import ( + config_for_attention, + decode_unified_attention, + decoder_write_cache_with_rope, + get_block_shape_and_split_kv_block, +) + +seed = 1000 + +random.seed(seed) +np.random.seed(seed) +paddle.seed(seed) + + +class RopeEmbedding: + def __init__(self, use_neox_rotary_style=False): + self.use_neox_rotary_style = use_neox_rotary_style + self.base = 10000 + + def get_rotary_position_embedding(self, position_ids, head_dim): + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32") + inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2)) + emb = paddle.unsqueeze(emb, 2) + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _apply_rope(self, rotary_emb, q, k, start_pos=0): + seq, head_dim = q.shape[2], q.shape[3] + cos, sin = paddle.chunk(rotary_emb, 2, axis=0) + cos = cos[:, :, start_pos : start_pos + seq, ...] + sin = sin[:, :, start_pos : start_pos + seq, ...] 
+ cos = paddle.squeeze(cos, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + sin = paddle.squeeze(sin, axis=0).transpose([0, 2, 1, 3])[:, :, :seq, :] + + sin_pos = paddle.reshape(paddle.stack([sin, sin], axis=-1), [1, 1, seq, head_dim]) + cos_pos = paddle.reshape(paddle.stack([cos, cos], axis=-1), [1, 1, seq, head_dim]) + rotate_half_q = paddle.reshape( + paddle.stack([-q[:, :, :, 1::2], q[:, :, :, 0::2]], axis=-1), + paddle.shape(q), + ) + rotate_half_k = paddle.reshape( + paddle.stack([-k[:, :, :, 1::2], k[:, :, :, 0::2]], axis=-1), + paddle.shape(k), + ) + + query = paddle.add(paddle.multiply(q, cos_pos), paddle.multiply(rotate_half_q, sin_pos)) + key = paddle.add(paddle.multiply(k, cos_pos), paddle.multiply(rotate_half_k, sin_pos)) + return paddle.cast(query, q.dtype), paddle.cast(key, k.dtype) + + +def naive_attention_impl(query, key, value, cache_k=None, cache_v=None, mask=None, scale=1.0): + batch = query.shape[0] + heads = query.shape[1] + seq_len = query.shape[2] + head_dim = query.shape[3] + kv_head = key.shape[1] + + key = key.reshape([batch, kv_head, 1, seq_len, head_dim]) + key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1]) + key = key.reshape([batch, heads, seq_len, head_dim]) + + if cache_k is not None: + cache_k = cache_k.reshape([batch, kv_head, 1, -1, head_dim]) + cache_k = paddle.tile(cache_k, [1, 1, heads // kv_head, 1, 1]) + cache_k = cache_k.reshape([batch, heads, -1, head_dim]) + key = paddle.concat([cache_k, key], axis=2) + + value = value.reshape([batch, kv_head, 1, seq_len, head_dim]) + value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1]) + value = value.reshape([batch, heads, seq_len, head_dim]) + + if cache_v is not None: + cache_v = cache_v.reshape([batch, kv_head, 1, -1, head_dim]) + cache_v = paddle.tile(cache_v, [1, 1, heads // kv_head, 1, 1]) + cache_v = cache_v.reshape([batch, heads, -1, head_dim]) + value = paddle.concat([cache_v, value], axis=2) + + qk_res = paddle.matmul(query, key, transpose_y=True) + attention = 
qk_res * scale + if mask is not None: + attention = attention + mask + softmax_result = paddle.nn.functional.softmax(attention, -1) + result = paddle.matmul(paddle.cast(softmax_result, dtype=value.dtype), value) + return result + + +def block_cache_to_naive_cache(cache_k, cache_v, bsz, block_tables, cache_seq_len): + """Read K/V from paged cache and return as [batch, num_head, seq_len, dim_head].""" + _, num_head, blocksize, dim_head = cache_k.shape + out_cache_k = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_k.dtype) + out_cache_v = paddle.zeros(shape=[bsz, num_head, cache_seq_len, dim_head], dtype=cache_v.dtype) + for i in range(bsz): + for j in range(cache_seq_len): + out_cache_k[i, :, j, :] = cache_k[block_tables[i, j // blocksize], :, j % blocksize, :] + out_cache_v[i, :, j, :] = cache_v[block_tables[i, j // blocksize], :, j % blocksize, :] + return out_cache_k, out_cache_v + + +def get_padding_offset(bsz, seq_lens_this_time): + token_num = paddle.sum(seq_lens_this_time) + batch_id_per_token = paddle.zeros(shape=(token_num), dtype="int32") + cu_seqlens_q = paddle.zeros(shape=(bsz + 1), dtype="int32") + cu_seqlens_k = paddle.zeros(shape=(bsz + 1), dtype="int32") + index = 0 + for i in range(bsz): + seq_len_now = seq_lens_this_time[i].item() + for j in range(seq_len_now): + batch_id_per_token[index] = i + index += 1 + cu_seqlens_q[i + 1] = index + cu_seqlens_k[i + 1] = index + return batch_id_per_token, cu_seqlens_q, cu_seqlens_k + + +def remove_padding(seq_lens, cu_seq_lens, inputs, token_num): + bsz, num_head, seq_len, head_dim = inputs.shape + output = paddle.zeros(shape=[token_num, num_head * head_dim], dtype=inputs.dtype) + inputs = inputs.transpose([0, 2, 1, 3]).reshape([bsz, seq_len, -1]) + for i in range(bsz): + seq_len_now = seq_lens[i] + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + output[start_idx:end_idx, :] = inputs[i, :seq_len_now, :] + return output + + +def get_qkv_and_qkv_concat_tensor(bs, q_num_head, 
kv_num_head, seq_len, head_dim, place, dtype): + query = np.random.random([bs, q_num_head, seq_len, head_dim]) + q = paddle.to_tensor(query, place=place, dtype=dtype, stop_gradient=False) - 0.5 + key = np.random.random([bs, kv_num_head, seq_len, head_dim]) + k = paddle.to_tensor(key, place=place, dtype=dtype, stop_gradient=False) - 0.5 + value = np.random.random([bs, kv_num_head, seq_len, head_dim]) + v = paddle.to_tensor(value, place=place, dtype=dtype, stop_gradient=False) - 0.5 + token_num = bs * seq_len + + qkv = paddle.concat( + [ + q.transpose([0, 2, 1, 3]).reshape([token_num, q_num_head * head_dim]), + k.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + v.transpose([0, 2, 1, 3]).reshape([token_num, kv_num_head * head_dim]), + ], + axis=1, + ).reshape([token_num, -1]) + return q, k, v, qkv + + +class TestDecodeUnifiedAttentionC16(unittest.TestCase): + """Base test class for decode append attention with cache_quant_type='none' (fp16/bf16 KV cache). + + Uses append_attention for prefill (verified correct by test_append_attention_c16.py) + and then tests decode_unified_attention (new split ops) against the same naive reference. + + Subclasses override setUp to vary batch_size, max_tokens_per_batch, dtype, etc. 
+ """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + + # Use small seq_len for fast testing; can increase later + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def init_tensor(self): + self.rope = RopeEmbedding(self.use_neox_rotary_style) + tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1)) + self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim) + + # block_table + self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size + self.max_block_num = self.block_num_per_seq * self.batch_size + self.free_list = list(range(self.max_block_num - 1, -1, -1)) + self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32") + for i in range(self.batch_size): + need_block_num = (self.max_model_len + self.block_size - 1) // self.block_size + for j in range(need_block_num): + self.block_tables[i, j] = self.free_list.pop() + + # cache + self.cache_shape = ( + self.max_block_num, + self.kv_num_head, + self.block_size, + self.head_dim, + ) + self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype) + + # Encoder phase: prefill with seq_len tokens + self.enc_q, self.enc_k, self.enc_v, self.enc_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.seq_len, + self.head_dim, + self.place, + self.dtype, + ) + + # Decoder phase: max_tokens_per_batch decode tokens 
+ self.dec_q, self.dec_k, self.dec_v, self.dec_qkv = get_qkv_and_qkv_concat_tensor( + self.batch_size, + self.q_num_head, + self.kv_num_head, + self.max_tokens_per_batch, + self.head_dim, + self.place, + self.dtype, + ) + + def _get_block_shape_buffers(self, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time): + max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16 + decoder_batch_ids = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_tile_ids_per_batch = paddle.full([max_num_block_dec], 0, dtype="int32") + decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + decoder_num_blocks_device = paddle.full([1], 0, dtype="int32") + decoder_chunk_size_device = paddle.full([1], 64, dtype="int32") + + max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64 + encoder_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + encoder_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + encoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").cpu() + + kv_batch_ids = paddle.full([max_num_block], 0, dtype="int32") + kv_tile_ids_per_batch = paddle.full([max_num_block], 0, dtype="int32") + kv_num_blocks_x_cpu = paddle.full([1], 0, dtype="int32").cpu() + max_len_tensor_cpu = paddle.full([6], 0, dtype="int32").cpu() + + get_block_shape_and_split_kv_block( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + decoder_batch_ids, + decoder_tile_ids_per_batch, + decoder_num_blocks_cpu, + decoder_num_blocks_device, + decoder_chunk_size_device, + max_len_tensor_cpu, + encoder_batch_ids, + encoder_tile_ids_per_batch, + encoder_num_blocks_cpu, + kv_batch_ids, + kv_tile_ids_per_batch, + kv_num_blocks_x_cpu, + 64, + 16, + self.group_size, + self.block_size, + ) + return { + "decoder_batch_ids": decoder_batch_ids, + "decoder_tile_ids_per_batch": decoder_tile_ids_per_batch, + "decoder_num_blocks_cpu": decoder_num_blocks_cpu, + "encoder_batch_ids": encoder_batch_ids, 
+ "encoder_tile_ids_per_batch": encoder_tile_ids_per_batch, + "encoder_num_blocks_cpu": encoder_num_blocks_cpu, + "kv_batch_ids": kv_batch_ids, + "kv_tile_ids_per_batch": kv_tile_ids_per_batch, + "kv_num_blocks_x_cpu": kv_num_blocks_x_cpu, + "max_len_tensor_cpu": max_len_tensor_cpu, + } + + def run_append_attention( + self, + qkv, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run append_attention op.""" + buffers = self._get_block_shape_buffers(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time) + + qkv_copy = copy.deepcopy(qkv) + cache_k_copy = copy.deepcopy(cache_k) + cache_v_copy = copy.deepcopy(cache_v) + + out = append_attention_op( + qkv_copy, + cache_k_copy, + cache_v_copy, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffers["encoder_batch_ids"], + buffers["encoder_tile_ids_per_batch"], + buffers["encoder_num_blocks_cpu"], + buffers["kv_batch_ids"], + buffers["kv_tile_ids_per_batch"], + buffers["kv_num_blocks_x_cpu"], + buffers["decoder_batch_ids"], + buffers["decoder_tile_ids_per_batch"], + buffers["decoder_num_blocks_cpu"], + buffers["max_len_tensor_cpu"], + self.rotary_embs, + None, # attn_mask + None, # qkv_bias + None, # qkv_out_scales + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # linear_shift + None, # linear_smooth + None, # mask_offset + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + None, # sinks + self.rms_norm_eps, + "bf16", + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + -1, + 64, + 16, + 1024, + self.max_model_len, + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, + self.max_tokens_per_batch > 1, # 
speculate_decoder + ) + return out, cache_k_copy, cache_v_copy + + def _build_decode_buffer(self): + """Build buffer for new split decode ops.""" + buffer = {} + min_chunk_size = 512 + max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size + q_tile_size = 16 + q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size + buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu() + buffer["block_indices"] = paddle.full( + [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32" + ) + buffer["num_blocks"] = paddle.full([1], 0, dtype="int32") + buffer["chunk_size"] = paddle.full([1], 0, dtype="int32") + buffer["tmp_workspace"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head * self.head_dim], + 0, + dtype=self.dtype, + ) + buffer["tmp_m"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + buffer["tmp_d"] = paddle.full( + [self.batch_size * self.max_tokens_per_batch, max_num_chunk, self.q_num_head], 0, dtype="float32" + ) + return buffer + + def _run_decode_unified_attention( + self, + cache_k, + cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ): + """Run config_for_attention + decoder_write_cache_with_rope + decode_unified_attention.""" + buffer = self._build_decode_buffer() + + config_for_attention( + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + self.cache_quant_type, + self.group_size, + self.kv_num_head, + self.max_tokens_per_batch, + ) + + dec_cache_k = copy.deepcopy(cache_k) + dec_cache_v = copy.deepcopy(cache_v) + dec_qkv = copy.deepcopy(self.dec_qkv) + + decoder_write_cache_with_rope( + dec_qkv, + dec_cache_k, + dec_cache_v, + seq_lens_encoder, + seq_lens_decoder, + 
seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["max_len_tensor_cpu"], + self.rotary_embs, + None, # qkv_bias + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # kv_signal_data + None, # q_norm_weight + None, # k_norm_weight + self.rms_norm_eps, + self.cache_quant_type, + self.use_neox_rotary_style, + self.rope_3d, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch > 1, # speculate_decoder + ) + + out = decode_unified_attention( + dec_qkv, + dec_cache_k, + dec_cache_v, + buffer["tmp_workspace"], + buffer["tmp_m"], + buffer["tmp_d"], + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + self.block_tables, + buffer["block_indices"], + buffer["num_blocks"], + buffer["chunk_size"], + buffer["max_len_tensor_cpu"], + None, # attn_mask + None, # cache_k_quant_scales + None, # cache_v_quant_scales + None, # cache_k_dequant_scales + None, # cache_v_dequant_scales + None, # cache_k_zp + None, # cache_v_zp + None, # mask_offset + None, # sinks + paddle.empty([dec_qkv.shape[0], self.q_num_head * self.head_dim], dtype=dec_qkv.dtype), # fmha_out + self.cache_quant_type, + self.max_model_len, + 0.0, # quant_max_bound + 0.0, # quant_min_bound + self.max_tokens_per_batch, # speculate_max_draft_token_num + self.causal, # causal + ) + return out, dec_cache_k, dec_cache_v + + def do_prefill_with_append_attention(self): + """Prefill using append_attention. 
Returns cache_k, cache_v after prefill.""" + seq_lens_encoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + seq_lens_decoder = paddle.to_tensor([0] * self.batch_size, "int32") + seq_lens_this_time = copy.deepcopy(seq_lens_encoder) + + batch_id_per_token, cu_seqlens_q, _ = get_padding_offset(self.batch_size, seq_lens_this_time) + + _, cache_k, cache_v = self.run_append_attention( + self.enc_qkv, + self.cache_k, + self.cache_v, + seq_lens_encoder, + seq_lens_decoder, + seq_lens_this_time, + batch_id_per_token, + cu_seqlens_q, + ) + return cache_k, cache_v + + def compute_naive_decode_ref(self, cache_k, cache_v): + """Compute naive reference for decode step using cache from paged cache.""" + # Read K/V from paged cache + naive_cache_k, naive_cache_v = block_cache_to_naive_cache( + cache_k, cache_v, self.batch_size, self.block_tables, self.seq_len + ) + + # Only use the first decode token (seq_lens_this_time=1 per batch) + dec_q = self.dec_q[:, :, :1, :] + dec_k = self.dec_k[:, :, :1, :] + dec_v = self.dec_v[:, :, :1, :] + + # Apply RoPE to decode Q/K at position seq_len + dec_q_rope, dec_k_rope = self.rope._apply_rope(self.rotary_embs, dec_q, dec_k, start_pos=self.seq_len) + + # Compute naive attention + out_ref = naive_attention_impl( + dec_q_rope, + dec_k_rope, + dec_v, + cache_k=naive_cache_k, + cache_v=naive_cache_v, + scale=self.softmax_scale, + ) + + dec_seq_lens_this_time = paddle.to_tensor([1] * self.batch_size, "int32") + dec_token_num = self.batch_size + _, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + out_ref = remove_padding(dec_seq_lens_this_time, dec_cu_seqlens_q, out_ref, dec_token_num) + return out_ref + + def test_naive_vs_decode_unified_attention(self): + """Test: prefill with append_attention, then decode with new split decode ops.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Naive reference for decode + out_ref = 
self.compute_naive_decode_ref(cache_k, cache_v) + + # Step 3: Decode with new split ops + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_ref_f = out_ref.astype("float32").numpy() + out_decode_f = out.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_ref_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention output doesn't match naive reference", + ) + + def test_append_vs_decode_unified_attention(self): + """Test: append_attention decode vs new split decode ops should produce same result.""" + # Step 1: Prefill + cache_k, cache_v = self.do_prefill_with_append_attention() + + # Step 2: Decode with append_attention + # seq_lens_this_time must match qkv rows: batch_size * max_tokens_per_batch + dec_seq_lens_encoder = paddle.to_tensor([0] * self.batch_size, "int32") + dec_seq_lens_decoder = paddle.to_tensor([self.seq_len] * self.batch_size, "int32") + dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, "int32") + dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time) + + out_append, _, _ = self.run_append_attention( + self.dec_qkv, + copy.deepcopy(cache_k), + copy.deepcopy(cache_v), + 
dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + # Step 3: Decode with new split ops + out_decode, _, _ = self._run_decode_unified_attention( + cache_k, + cache_v, + dec_seq_lens_encoder, + dec_seq_lens_decoder, + dec_seq_lens_this_time, + dec_batch_id_per_token, + dec_cu_seqlens_q, + ) + + out_append_f = out_append.astype("float32").numpy() + out_decode_f = out_decode.astype("float32").numpy() + + # Truncate to actual token count (output may be padded to max_tokens_per_batch) + dec_token_num = self.batch_size + out_append_f = out_append_f[:dec_token_num] + out_decode_f = out_decode_f[:dec_token_num] + + np.testing.assert_allclose( + out_decode_f, + out_append_f, + rtol=1e-02, + atol=1e-02, + err_msg="decode_unified_attention doesn't match append_attention decode", + ) + + +class TestDecodeUnifiedAttentionC16Speculate(TestDecodeUnifiedAttentionC16): + """Test with speculate decode: max_tokens_per_batch=2. + + When max_tokens_per_batch > 1, naive ref only computes 1 token while ops + compute multiple tokens. So naive comparison tests are skipped; only + append_attention vs decode_unified_attention comparison is kept. 
+ """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiBatch(TestDecodeUnifiedAttentionC16): + """Test with multiple batches.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiHead(TestDecodeUnifiedAttentionC16): + """Test with multiple KV heads (GQA).""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 16 + self.kv_num_head = 2 + self.batch_size = 2 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class 
TestDecodeUnifiedAttentionC16FP16(TestDecodeUnifiedAttentionC16): + """Test with float16 dtype.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "float16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16NoCausal(TestDecodeUnifiedAttentionC16): + """Test with causal=False.""" + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 1 + self.max_tokens_per_batch = 1 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = False + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + +class TestDecodeUnifiedAttentionC16MultiBatchSpeculate(TestDecodeUnifiedAttentionC16): + """Test with multi-batch + speculate decode. + + When max_tokens_per_batch > 1, the naive reference only computes 1 token + while ops compute multiple tokens. So we only compare append_attention vs + decode_unified_attention (both should produce same result), and skip the + naive comparison tests. 
+ """ + + def setUp(self): + paddle.disable_static() + self.place = paddle.CUDAPlace(0) + self.q_num_head = 14 + self.kv_num_head = 1 + self.batch_size = 4 + self.max_tokens_per_batch = 2 + self.head_dim = 128 + self.block_size = 64 + self.dtype = "bfloat16" + self.cache_quant_type = "none" + self.use_neox_rotary_style = False + self.rope_3d = False + self.softmax_scale = self.head_dim**-0.5 + self.rms_norm_eps = 1e-6 + self.causal = True + self.group_size = self.q_num_head // self.kv_num_head + self.seq_len = 6400 + self.max_model_len = self.seq_len + 128 + self.init_tensor() + + def test_naive_vs_decode_unified_attention(self): + """Skip: naive ref only computes 1 token, but ops compute max_tokens_per_batch tokens.""" + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/operators/attention/test_decode_unified_attention_c8.py b/tests/operators/attention/test_decode_unified_attention_c8.py new file mode 100644 index 00000000000..d5ec0e5354c --- /dev/null +++ b/tests/operators/attention/test_decode_unified_attention_c8.py @@ -0,0 +1,921 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import copy
import random
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.layers.attention.ops import (
    append_attention,
    config_for_attention,
    decode_unified_attention,
    decoder_write_cache_with_rope,
    get_block_shape_and_split_kv_block,
    gqa_rope_write_cache,
    pre_cache_len_concat,
)

seed = 1000

# Seed every RNG used below so the generated inputs are reproducible.
random.seed(seed)
np.random.seed(seed)
paddle.seed(seed)


class RopeEmbedding:
    """Builds the cos/sin rotary position embedding table used by the tests."""

    def __init__(self, use_neox_rotary_style=False):
        # The flag is stored but not consulted when building the table.
        self.use_neox_rotary_style = use_neox_rotary_style
        self.base = 10000

    def get_rotary_position_embedding(self, position_ids, head_dim):
        """Return a float32 tensor of shape [2, B, S, 1, head_dim // 2].

        Index 0 of the first axis holds cos(pos * inv_freq) and index 1 holds
        sin(pos * inv_freq), where B and S are the first two dimensions of
        ``position_ids``.
        """
        bsz, max_seq_len = position_ids.shape[:2]
        rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim // 2), dtype="float32")
        inv_freq = self.base ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim)

        # shape: [B, S, D/2]
        freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
        # shape: [B, S, D/2]
        emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, head_dim // 2))
        # shape: [B, S, 1, D/2]
        emb = paddle.unsqueeze(emb, 2)

        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)
        return rot_emb


def get_padding_offset(bsz, seq_lens_this_time):
    """Compute the per-token batch index and the cumulative sequence lengths.

    Args:
        bsz: Number of sequences in the batch.
        seq_lens_this_time: int tensor holding the number of tokens of each sequence.

    Returns:
        Tuple ``(batch_id_per_token, cu_seqlens_q, cu_seqlens_k)`` of int32 tensors.
        ``batch_id_per_token[t]`` is the batch index owning token ``t``;
        ``cu_seqlens_q`` and ``cu_seqlens_k`` both have ``bsz + 1`` entries where
        entry ``i + 1`` is the number of tokens in sequences ``0..i``.
    """
    # Build everything on the host with NumPy and upload once; assigning tensor
    # elements one by one from Python is very slow for long prefill batches.
    lens = np.asarray(seq_lens_this_time.numpy(), dtype=np.int64).reshape(-1)
    token_num = int(lens.sum())
    per_batch = lens[:bsz]

    batch_ids = np.zeros([token_num], dtype=np.int32)
    owners = np.repeat(np.arange(bsz, dtype=np.int32), per_batch)
    batch_ids[: owners.size] = owners

    cu_seqlens = np.zeros([bsz + 1], dtype=np.int32)
    cu_seqlens[1:] = np.cumsum(per_batch)

    batch_id_per_token = paddle.to_tensor(batch_ids, dtype="int32")
    cu_seqlens_q = paddle.to_tensor(cu_seqlens, dtype="int32")
    cu_seqlens_k = paddle.to_tensor(cu_seqlens, dtype="int32")
    return batch_id_per_token, cu_seqlens_q, cu_seqlens_k


def get_qkv_and_qkv_concat_tensor(bs, q_num_head, kv_num_head, seq_len, head_dim, place, dtype):
    """Create random q/k/v tensors and their token-major concatenation.

    Returns:
        ``(q, k, v, qkv)`` where q/k/v have shape [bs, num_head, seq_len, head_dim]
        with values drawn from [-0.5, 0.5), and ``qkv`` has shape
        [bs * seq_len, (q_num_head + 2 * kv_num_head) * head_dim].
    """

    def _centered(num_head):
        data = np.random.random([bs, num_head, seq_len, head_dim])
        return paddle.to_tensor(data, place=place, dtype=dtype, stop_gradient=False) - 0.5

    # Draw order (q, k, v) is kept fixed so the seeded data stays reproducible.
    q = _centered(q_num_head)
    k = _centered(kv_num_head)
    v = _centered(kv_num_head)
    token_num = bs * seq_len

    def _token_major(x, num_head):
        return x.transpose([0, 2, 1, 3]).reshape([token_num, num_head * head_dim])

    qkv = paddle.concat(
        [
            _token_major(q, q_num_head),
            _token_major(k, kv_num_head),
            _token_major(v, kv_num_head),
        ],
        axis=1,
    ).reshape([token_num, -1])
    return q, k, v, qkv


class TestDecodeUnifiedAttention(unittest.TestCase):
    """Checks that decode_unified_attention matches append_attention on a quantized KV cache.

    Subclasses change the scenario by listing the attributes that differ from
    the defaults in ``config_overrides``; ``setUp`` applies them before any
    derived value or tensor is created.
    """

    # Attribute name -> value, applied on top of the defaults set in setUp().
    config_overrides = {}

    def setUp(self):
        paddle.disable_static()
        self.name = "TestDecodeUnifiedAttention"
        self.place = paddle.CUDAPlace(0)
        self.q_num_head = 14
        self.kv_num_head = 1
        self.batch_size = 1
        self.max_tokens_per_batch = 1
        self.cache_len = 500
        self.seq_len_dec = None
        self.seq_lens_this_time = None
        self.max_model_len = 131072
        self.head_dim = 128
        self.rms_norm_eps = 1e-6
        self.rope_3d = False
        self.block_size = 64
        self.use_neox_rotary_style = False
        self.rope_theta = 10000
        self.sliding_window = 0
        self.dtype = "bfloat16"
        self.cache_quant_type = "cache_fp8"
        self.use_qk_norm = False
        self.use_mask_offset = False
        self.causal = True
        self.quant_min_bound = -448.0
        self.quant_max_bound = 448.0

        for attr, value in self.config_overrides.items():
            setattr(self, attr, value)

        # Derived values must be computed after the overrides are in place.
        self.q_hid_dim = self.q_num_head * self.head_dim
        self.kv_hid_dim = self.kv_num_head * self.head_dim
        self.softmax_scale = self.head_dim**-0.5
        self.init_tensor()

    def init_tensor(self):
        """Create all inputs, caches, scales and scratch buffers used by the ops."""
        # seq_lens
        if self.seq_len_dec is None:
            self.seq_lens_dec = [self.cache_len] * self.batch_size
        else:
            # Explicit per-sequence decoder lengths also define the batch size.
            self.seq_lens_dec = list(self.seq_len_dec)
            self.batch_size = len(self.seq_lens_dec)
        self.seq_lens_decoder = paddle.to_tensor(self.seq_lens_dec, "int32")
        if self.seq_lens_this_time is None:
            self.seq_lens_this_time = [self.max_tokens_per_batch] * self.batch_size
        self.token_num = sum(self.seq_lens_this_time)
        self.seq_lens_this_time = paddle.to_tensor(self.seq_lens_this_time, "int32")

        self.seq_lens_enc = [0] * self.batch_size
        self.seq_lens_encoder = paddle.to_tensor(self.seq_lens_enc, "int32")

        self.q, self.k, self.v, self.qkv = get_qkv_and_qkv_concat_tensor(
            self.batch_size,
            self.q_num_head,
            self.kv_num_head,
            self.max_tokens_per_batch,
            self.head_dim,
            self.place,
            self.dtype,
        )
        self.qkv = paddle.to_tensor(self.qkv, dtype=self.dtype)

        # qk_norm
        self.q_norm_weight = None
        self.k_norm_weight = None
        if self.use_qk_norm:
            q_norm_weight_np = np.random.random([self.head_dim]) / 10
            k_norm_weight_np = np.random.random([self.head_dim]) / 10
            self.q_norm_weight = paddle.to_tensor(q_norm_weight_np, dtype="float32")
            self.k_norm_weight = paddle.to_tensor(k_norm_weight_np, dtype="float32")

        # rotary embedding
        self.rope = RopeEmbedding(False)
        tmp_position_ids = paddle.arange(self.max_model_len).reshape((1, -1))
        self.rotary_embs = self.rope.get_rotary_position_embedding(tmp_position_ids, self.head_dim)

        # block_table: sequence i owns the consecutive block ids
        # [i * block_num_per_seq, (i + 1) * block_num_per_seq), i.e. every block
        # of the pool is handed out in ascending order and none is left free.
        self.block_num_per_seq = (self.max_model_len + self.block_size - 1) // self.block_size
        self.max_block_num = self.block_num_per_seq * self.batch_size
        block_ids = np.arange(self.max_block_num, dtype=np.int32)
        self.block_tables = paddle.to_tensor(
            block_ids.reshape([self.batch_size, self.block_num_per_seq]), dtype="int32"
        )
        self.free_list = []

        # cache_kv && scale
        self.cache_shape = (
            self.max_block_num,
            self.kv_num_head,
            self.block_size,
            self.head_dim,
        )
        self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
        self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")

        if self.cache_quant_type == "block_wise_fp8":
            # One scale per cached token slot; no separate dequant scales are passed.
            self.cache_scale_shape = (
                self.max_block_num,
                self.kv_num_head,
                self.block_size,
            )
            self.cache_k_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
            self.cache_v_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
            self.cache_k_out_scale = None
            self.cache_v_out_scale = None
        else:
            # Per-KV-head scales derived from the absolute maximum of the decode k/v.
            k_absmax = self.k.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1)
            v_absmax = self.v.transpose([1, 0, 2, 3]).reshape([self.kv_num_head, -1]).abs().max(axis=1)
            self.cache_k_scale = (self.quant_max_bound / k_absmax).astype(self.dtype)
            self.cache_v_scale = (self.quant_max_bound / v_absmax).astype(self.dtype)
            self.cache_k_out_scale = (k_absmax / self.quant_max_bound).astype(self.dtype)
            self.cache_v_out_scale = (v_absmax / self.quant_max_bound).astype(self.dtype)

        (
            self.batch_id_per_token,
            self.cu_seqlens_q,
            self.cu_seqlens_k,
        ) = get_padding_offset(self.batch_size, self.seq_lens_this_time)

        # mask offset
        self.mask_offset = None
        if self.use_mask_offset:
            self.mask_offset = paddle.full(self.batch_size * 2, 0, "int32")
            for i in range(self.batch_size):
                self.mask_offset[i * 2] = 0
                self.mask_offset[i * 2 + 1] = self.seq_lens_dec[i] + 1

        # buffer
        self.buffer = {}
        min_chunk_size = 512
        max_num_chunk = (self.max_model_len + min_chunk_size - 1) // min_chunk_size
        self.group_size = self.q_num_head // self.kv_num_head
        q_tile_size = 16
        q_tile_num = (self.max_tokens_per_batch * self.group_size + q_tile_size - 1) // q_tile_size
        token_slots = self.batch_size * self.max_tokens_per_batch
        self.buffer["max_len_tensor_cpu"] = paddle.full([6], 0, dtype="int32").cpu()
        # block_indices: indices of the launched blocks, 4 values per block
        # [batch_idx, kv_head_idx, chunk_idx, q_tile_idx]
        self.buffer["block_indices"] = paddle.full(
            [self.batch_size * self.kv_num_head * max_num_chunk * q_tile_num, 4], 0, dtype="int32"
        )
        # num_blocks: number of launched blocks, computed by the config_for_attention op
        self.buffer["num_blocks"] = paddle.full([1], 0, dtype="int32")
        # chunk_size: chunk size used to split the kv cache, computed by the config_for_attention op
        self.buffer["chunk_size"] = paddle.full([1], 0, dtype="int32")
        # tmp_workspace: per-chunk partial outputs stored before merging
        self.buffer["tmp_workspace"] = paddle.full(
            [token_slots, max_num_chunk, self.q_num_head * self.head_dim],
            0,
            dtype=self.dtype,
        )
        # tmp_m: per-chunk max values stored before merging
        self.buffer["tmp_m"] = paddle.full([token_slots, max_num_chunk, self.q_num_head], 0, dtype="float32")
        # tmp_d: per-chunk exponential sums stored before merging
        self.buffer["tmp_d"] = paddle.full([token_slots, max_num_chunk, self.q_num_head], 0, dtype="float32")

    def _plan_tiles(self, seq_lens_encoder, seq_lens_decoder, seq_lens_this_time):
        """Allocate the tile-planning buffers and fill them via get_block_shape_and_split_kv_block.

        Returns:
            dict mapping buffer name to tensor, holding every buffer passed to the op.
        """
        max_num_block_dec = self.batch_size * (self.max_model_len * self.group_size + 16 - 1) // 16
        max_num_block = self.batch_size * (self.max_model_len * self.group_size + 64 - 1) // 64
        plan = {
            "decoder_batch_ids": paddle.full([max_num_block_dec], 0, dtype="int32"),
            "decoder_tile_ids_per_batch": paddle.full([max_num_block_dec], 0, dtype="int32"),
            "decoder_num_blocks_cpu": paddle.full([1], 0, dtype="int32").cpu(),
            "decoder_num_blocks_device": paddle.full([1], 0, dtype="int32"),
            "decoder_chunk_size_device": paddle.full([1], 64, dtype="int32"),
            "encoder_batch_ids": paddle.full([max_num_block], 0, dtype="int32"),
            "encoder_tile_ids_per_batch": paddle.full([max_num_block], 0, dtype="int32"),
            "encoder_num_blocks_cpu": paddle.full([1], 0, dtype="int32").cpu(),
            "kv_batch_ids": paddle.full([max_num_block], 0, dtype="int32"),
            "kv_tile_ids_per_batch": paddle.full([max_num_block], 0, dtype="int32"),
            "kv_num_blocks_x_cpu": paddle.full([1], 0, dtype="int32").cpu(),
            "max_len_tensor_cpu": paddle.full([6], 0, dtype="int32").cpu(),
        }
        get_block_shape_and_split_kv_block(
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            plan["decoder_batch_ids"],
            plan["decoder_tile_ids_per_batch"],
            plan["decoder_num_blocks_cpu"],
            plan["decoder_num_blocks_device"],
            plan["decoder_chunk_size_device"],
            plan["max_len_tensor_cpu"],
            plan["encoder_batch_ids"],
            plan["encoder_tile_ids_per_batch"],
            plan["encoder_num_blocks_cpu"],
            plan["kv_batch_ids"],
            plan["kv_tile_ids_per_batch"],
            plan["kv_num_blocks_x_cpu"],
            64,
            16,
            self.group_size,
            self.block_size,
        )
        return plan

    def append_attention_with_args(
        self,
        qkv,
        cache_k,
        cache_v,
        seq_lens_encoder,
        seq_lens_decoder,
        seq_lens_this_time,
        batch_id_per_token,
        cu_seqlens_q,
    ):
        """Run append_attention with explicit arguments.

        Returns:
            ``(out, cache_k, cache_v)``: the op output and the cache tensors passed in.
        """
        plan = self._plan_tiles(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time)
        out = append_attention(
            qkv,
            cache_k,
            cache_v,
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            batch_id_per_token,
            cu_seqlens_q,
            self.block_tables,
            plan["encoder_batch_ids"],
            plan["encoder_tile_ids_per_batch"],
            plan["encoder_num_blocks_cpu"],
            plan["kv_batch_ids"],
            plan["kv_tile_ids_per_batch"],
            plan["kv_num_blocks_x_cpu"],
            plan["decoder_batch_ids"],
            plan["decoder_tile_ids_per_batch"],
            plan["decoder_num_blocks_cpu"],
            plan["max_len_tensor_cpu"],
            self.rotary_embs,
            None,  # attn_mask
            None,  # qkv_bias
            None,  # qkv_out_scales
            self.cache_k_scale,  # cache_k_quant_scales
            self.cache_v_scale,  # cache_v_quant_scales
            self.cache_k_out_scale,  # cache_k_dequant_scales
            self.cache_v_out_scale,  # cache_v_dequant_scales
            None,  # cache_k_zp
            None,  # cache_v_zp
            None,  # linear_shift
            None,  # linear_smooth
            None,  # mask_offset
            None,  # kv_signal_data
            self.q_norm_weight,
            self.k_norm_weight,
            None,  # sinks
            self.rms_norm_eps,
            "bf16",
            self.cache_quant_type,
            False,  # use_neox_rotary_style
            self.rope_3d,
            self.max_model_len,
            self.quant_max_bound,  # quant_max_bound
            self.quant_min_bound,  # quant_min_bound
            -1,
            64,
            16,
            self.max_model_len,
            1024,
            self.max_tokens_per_batch,
            self.causal,
            self.max_tokens_per_batch > 1,
            self.sliding_window,
        )
        return out, cache_k, cache_v

    def append_attention(self):
        """Convenience wrapper using default self members (qkv and caches are deep-copied)."""
        return self.append_attention_with_args(
            copy.deepcopy(self.qkv),
            copy.deepcopy(self.cache_k),
            copy.deepcopy(self.cache_v),
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.batch_id_per_token,
            self.cu_seqlens_q,
        )

    def decode_unified_attention(self):
        """Run config_for_attention, decoder_write_cache_with_rope and decode_unified_attention.

        All three ops receive ``self.qkv`` and ``self.cache_k``/``self.cache_v`` directly
        (no copies are made here).

        Returns:
            ``(self.qkv, out)`` where ``out`` is the decode_unified_attention result.
        """
        paddle.disable_static()

        config_for_attention(
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.buffer["block_indices"],
            self.buffer["num_blocks"],
            self.buffer["chunk_size"],
            self.buffer["max_len_tensor_cpu"],
            self.cache_quant_type,
            self.group_size,
            self.kv_num_head,
            self.max_tokens_per_batch,
        )
        decoder_write_cache_with_rope(
            self.qkv,
            self.cache_k,
            self.cache_v,
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.batch_id_per_token,
            self.cu_seqlens_q,
            self.block_tables,
            self.buffer["max_len_tensor_cpu"],
            self.rotary_embs,  # rotary_embs
            None,  # qkv_bias
            self.cache_k_scale,  # cache_k_quant_scales
            self.cache_v_scale,  # cache_v_quant_scales
            self.cache_k_out_scale,  # cache_k_dequant_scales
            self.cache_v_out_scale,  # cache_v_dequant_scales
            None,  # cache_k_zp
            None,  # cache_v_zp
            None,  # kv_signal_data
            self.q_norm_weight,  # q_norm_weight
            self.k_norm_weight,  # k_norm_weight
            self.rms_norm_eps,
            self.cache_quant_type,
            False,  # use_neox_rotary_style
            self.rope_3d,
            self.max_model_len,
            self.quant_max_bound,  # quant_max_bound
            self.quant_min_bound,  # quant_min_bound
            self.max_tokens_per_batch > 1,  # speculate_decoder
        )

        fmha_out = paddle.empty([self.qkv.shape[0], self.q_num_head * self.head_dim], dtype=self.qkv.dtype)
        out = decode_unified_attention(
            self.qkv,
            self.cache_k,
            self.cache_v,
            self.buffer["tmp_workspace"],
            self.buffer["tmp_m"],
            self.buffer["tmp_d"],
            self.seq_lens_encoder,
            self.seq_lens_decoder,
            self.seq_lens_this_time,
            self.batch_id_per_token,
            self.cu_seqlens_q,
            self.block_tables,
            self.buffer["block_indices"],
            self.buffer["num_blocks"],
            self.buffer["chunk_size"],
            self.buffer["max_len_tensor_cpu"],  # set_max_lengths
            None,  # attn_mask
            self.cache_k_scale,  # cache_k_quant_scales
            self.cache_v_scale,  # cache_v_quant_scales
            self.cache_k_out_scale,  # cache_k_dequant_scales
            self.cache_v_out_scale,  # cache_v_dequant_scales
            None,  # cache_k_zp
            None,  # cache_v_zp
            None,  # mask_offset
            None,  # sinks
            fmha_out,  # fmha_out
            self.cache_quant_type,
            self.max_model_len,
            self.quant_max_bound,  # quant_max_bound
            self.quant_min_bound,  # quant_min_bound
            self.max_tokens_per_batch,  # max_tokens_per_batch
            self.causal,  # causal
            self.sliding_window,
        )
        return self.qkv, out

    def prefill(self):
        """Fill self.cache_k / self.cache_v by running gqa_rope_write_cache on random prefill tokens.

        The prefill length of each sequence equals its decoder length
        (``self.seq_lens_decoder``), so the cache holds exactly the history the
        decode step attends to.

        Returns:
            ``(k, v)`` returned by gqa_rope_write_cache, reshaped to
            [batch_size, kv_num_head, -1, head_dim].
        """
        # init seq_len
        seq_lens_encoder = copy.deepcopy(self.seq_lens_decoder)
        seq_lens_decoder = paddle.zeros([self.batch_size], dtype="int32")
        seq_lens_this_time = seq_lens_encoder
        token_num = seq_lens_this_time.sum().item()
        qkv_np = np.random.random([token_num, (self.q_num_head + 2 * self.kv_num_head) * self.head_dim]) - 0.5
        qkv = paddle.to_tensor(qkv_np, dtype=self.dtype)

        batch_id_per_token, cu_seqlens_q, _ = get_padding_offset(self.batch_size, seq_lens_this_time)

        plan = self._plan_tiles(seq_lens_encoder, seq_lens_decoder, seq_lens_this_time)
        (
            cu_seqlens_k,
            pre_cache_batch_ids,
            pre_cache_tile_ids_per_batch,
            pre_cache_num_blocks_cpu,
            kv_token_num_cpu,
        ) = pre_cache_len_concat(
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            plan["max_len_tensor_cpu"][2],
            self.block_size,
        )
        _, k, v, _ = gqa_rope_write_cache(
            qkv,
            self.cache_k,
            self.cache_v,
            cu_seqlens_q,
            cu_seqlens_k,
            self.rotary_embs,
            seq_lens_this_time,
            seq_lens_encoder,
            seq_lens_decoder,
            batch_id_per_token,
            self.block_tables,
            plan["kv_batch_ids"],
            plan["kv_tile_ids_per_batch"],
            plan["kv_num_blocks_x_cpu"],
            pre_cache_batch_ids,
            pre_cache_tile_ids_per_batch,
            pre_cache_num_blocks_cpu,
            self.q_norm_weight,
            self.k_norm_weight,
            self.cache_k_scale,  # cache_k_quant_scales
            self.cache_v_scale,  # cache_v_quant_scales
            self.cache_k_out_scale,  # cache_k_dequant_scales
            self.cache_v_out_scale,  # cache_v_dequant_scales
            None,  # cache_k_zp
            None,  # cache_v_zp
            None,  # kv_signal_data
            kv_token_num_cpu[0].item(),
            self.max_model_len,
            self.rms_norm_eps,
            False,  # use_neox_rotary_style
            self.cache_quant_type,
            self.rope_3d,
        )

        k = k.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3])
        v = v.reshape([self.batch_size, -1, self.kv_num_head, self.head_dim]).transpose([0, 2, 1, 3])
        return k, v

    def test_all(self):
        """Compare append_attention vs decode_unified_attention output for consistency."""
        # Step 1: Prefill - just write K/V to cache via gqa_rope_write_cache
        self.prefill()

        # Step 2: Decode with append_attention (copy cache so it's not modified)
        dec_seq_lens_encoder = paddle.zeros([self.batch_size], dtype="int32")
        dec_seq_lens_decoder = copy.deepcopy(self.seq_lens_decoder)

        dec_seq_lens_this_time = paddle.to_tensor([self.max_tokens_per_batch] * self.batch_size, dtype="int32")
        dec_batch_id_per_token, dec_cu_seqlens_q, _ = get_padding_offset(self.batch_size, dec_seq_lens_this_time)

        out_append_dec, _, _ = self.append_attention_with_args(
            copy.deepcopy(self.qkv),
            copy.deepcopy(self.cache_k),
            copy.deepcopy(self.cache_v),
            dec_seq_lens_encoder,
            dec_seq_lens_decoder,
            dec_seq_lens_this_time,
            dec_batch_id_per_token,
            dec_cu_seqlens_q,
        )

        # Step 3: Decode with decode_unified_attention (uses self.cache_k/v directly)
        _, out_decode = self.decode_unified_attention()

        # Step 4: Compare
        out_append_f = out_append_dec.astype("float32").numpy()
        out_decode_f = out_decode.astype("float32").numpy()

        np.testing.assert_allclose(
            out_decode_f,
            out_append_f,
            rtol=1e-02,
            atol=1e-02,
            err_msg="decode_unified_attention output doesn't match append_attention output",
        )


class TestDecodeUnifiedAttentionMultiBatch(TestDecodeUnifiedAttention):
    """batch_size=60 with 2 tokens per sequence."""

    config_overrides = {"batch_size": 60, "max_tokens_per_batch": 2}


class TestDecodeUnifiedAttentionSpeculate(TestDecodeUnifiedAttention):
    """batch_size=6 with 2 tokens per sequence."""

    config_overrides = {"batch_size": 6, "max_tokens_per_batch": 2}


class TestDecodeUnifiedAttentionMultiHead(TestDecodeUnifiedAttention):
    """16 query heads over 2 KV heads, batch_size=6, 2 tokens per sequence."""

    config_overrides = {
        "q_num_head": 16,
        "kv_num_head": 2,
        "batch_size": 6,
        "max_tokens_per_batch": 2,
    }


class TestDecodeUnifiedAttentionMultiSpeculate(TestDecodeUnifiedAttention):
    """batch_size=6 with 4 tokens per sequence."""

    config_overrides = {"batch_size": 6, "max_tokens_per_batch": 4}


class TestDecodeUnifiedAttentionSpeculateBs128Mtp4(TestDecodeUnifiedAttention):
    """batch_size=128, 4 tokens per sequence, cache_len=508, max_model_len=2048."""

    config_overrides = {
        "batch_size": 128,
        "max_tokens_per_batch": 4,
        "cache_len": 508,
        "max_model_len": 2048,
    }


class TestDecodeUnifiedAttentionDynamicC8(TestDecodeUnifiedAttention):
    """block_wise_fp8 cache, batch_size=6, 2 tokens per sequence."""

    config_overrides = {
        "batch_size": 6,
        "max_tokens_per_batch": 2,
        "cache_quant_type": "block_wise_fp8",
    }


class TestDecodeUnifiedAttentionDynamicC8MultiBatch(TestDecodeUnifiedAttention):
    """block_wise_fp8 cache, batch_size=60, 2 tokens per sequence."""

    config_overrides = {
        "batch_size": 60,
        "max_tokens_per_batch": 2,
        "cache_quant_type": "block_wise_fp8",
    }


class TestDecodeUnifiedAttentionDynamicC8Speculate(TestDecodeUnifiedAttention):
    """block_wise_fp8 cache, batch_size=6, 4 tokens per sequence."""

    config_overrides = {
        "batch_size": 6,
        "max_tokens_per_batch": 4,
        "cache_quant_type": "block_wise_fp8",
    }


class TestDecodeUnifiedAttentionQKNorm(TestDecodeUnifiedAttention):
    """q/k norm weights enabled, batch_size=6, 2 tokens per sequence."""

    config_overrides = {
        "batch_size": 6,
        "max_tokens_per_batch": 2,
        "use_qk_norm": True,
    }


if __name__ == "__main__":
    unittest.main()