@@ -430,8 +430,7 @@ template <typename T,
430430 uint32_t num_frags_x,
431431 uint32_t num_frags_z,
432432 uint32_t num_frags_y,
433- typename OutT = T,
434- bool ENABLE_PREFILL = true >
433+ typename OutT = T>
435434__global__ void multi_query_append_attention_warp1_4_kernel (
436435 T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
437436 T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
@@ -525,17 +524,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
525524 if (!partition_kv || num_chunks_this_seq <= 1 ) {
526525 o_base_ptr_int8 = out + o_offset;
527526 } else {
528- if (ENABLE_PREFILL) {
529- o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
530- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
531- tid % 8 * num_elems_per_128b<T>();
532- } else {
533- o_base_ptr_T =
534- tmp_workspace +
535- batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
536- chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
537- tid % 8 * num_elems_per_128b<T>();
538- }
527+ o_base_ptr_T =
528+ tmp_workspace +
529+ batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
530+ chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
531+ tid % 8 * num_elems_per_128b<T>();
539532 }
540533 const int *mask_offset_this_seq =
541534 mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr ;
@@ -799,18 +792,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
799792 const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
800793
801794 if (qo_idx - q_start_seq_id < q_len) {
802- uint32_t offset;
803- if (ENABLE_PREFILL) {
804- offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
805- qo_head_idx;
806- } else {
807- offset = ((batch_id * speculate_max_draft_token_num +
808- qo_idx_now / GROUP_SIZE) *
809- num_chunks +
810- chunk_idx) *
811- q_num_heads +
812- qo_head_idx;
813- }
795+ const uint32_t offset = ((batch_id * speculate_max_draft_token_num +
796+ qo_idx_now / GROUP_SIZE) *
797+ num_chunks +
798+ chunk_idx) *
799+ q_num_heads +
800+ qo_head_idx;
814801 tmp_m[offset] = m_frag[fx][j];
815802 tmp_d[offset] = d_frag[fx][j];
816803 }
@@ -1123,8 +1110,7 @@ void MultiQueryAppendAttention(
11231110 num_frags_x,
11241111 num_frags_z,
11251112 num_frags_y,
1126- OUT_NV_TYPE,
1127- ENABLE_PREFILL>;
1113+ OUT_NV_TYPE>;
11281114 if (smem_size >= 48 * 1024 ) {
11291115 cudaFuncSetAttribute (split_kv_kernel,
11301116 cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1169,8 +1155,7 @@ void MultiQueryAppendAttention(
11691155 num_frags_x,
11701156 num_frags_z,
11711157 num_frags_y,
1172- OUT_NV_TYPE,
1173- ENABLE_PREFILL>;
1158+ OUT_NV_TYPE>;
11741159 if (smem_size >= 48 * 1024 ) {
11751160 cudaFuncSetAttribute (nosplit_kv_kernel,
11761161 cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1222,43 +1207,18 @@ void MultiQueryAppendAttention(
12221207 sink_size);
12231208 } else {
12241209 phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
1225- if (is_decoder) {
1226- tmp_workspace = allocator->Allocate (
1227- phi::SizeOf (qkv.dtype ()) *
1228- static_cast <size_t >(bsz * num_chunks * num_heads * HEAD_DIM));
1229- tmp_m = allocator->Allocate (
1230- phi::SizeOf (paddle::DataType::FLOAT32) *
1231- static_cast <size_t >(bsz * num_chunks * num_heads));
1232- tmp_d = allocator->Allocate (
1233- phi::SizeOf (paddle::DataType::FLOAT32) *
1234- static_cast <size_t >(bsz * num_chunks * num_heads));
1235- } else {
1236- if (ENABLE_PREFILL) {
1237- tmp_workspace =
1238- allocator->Allocate (phi::SizeOf (qkv.dtype ()) *
1239- static_cast <size_t >(token_num * num_chunks *
1240- num_heads * HEAD_DIM));
1241- tmp_m = allocator->Allocate (
1242- phi::SizeOf (paddle::DataType::FLOAT32) *
1243- static_cast <size_t >(token_num * num_chunks * num_heads));
1244- tmp_d = allocator->Allocate (
1245- phi::SizeOf (paddle::DataType::FLOAT32) *
1246- static_cast <size_t >(token_num * num_chunks * num_heads));
1247- } else {
1248- tmp_workspace = allocator->Allocate (
1249- phi::SizeOf (qkv.dtype ()) *
1250- static_cast <size_t >(speculate_max_draft_token_num * bsz *
1251- num_chunks * num_heads * HEAD_DIM));
1252- tmp_m = allocator->Allocate (
1253- phi::SizeOf (paddle::DataType::FLOAT32) *
1254- static_cast <size_t >(speculate_max_draft_token_num * bsz *
1255- num_chunks * num_heads));
1256- tmp_d = allocator->Allocate (
1257- phi::SizeOf (paddle::DataType::FLOAT32) *
1258- static_cast <size_t >(speculate_max_draft_token_num * bsz *
1259- num_chunks * num_heads));
1260- }
1261- }
1210+ tmp_workspace = allocator->Allocate (
1211+ phi::SizeOf (qkv.dtype ()) *
1212+ static_cast <size_t >(speculate_max_draft_token_num * bsz * num_chunks *
1213+ num_heads * HEAD_DIM));
1214+ tmp_m = allocator->Allocate (
1215+ phi::SizeOf (paddle::DataType::FLOAT32) *
1216+ static_cast <size_t >(speculate_max_draft_token_num * bsz * num_chunks *
1217+ num_heads));
1218+ tmp_d = allocator->Allocate (
1219+ phi::SizeOf (paddle::DataType::FLOAT32) *
1220+ static_cast <size_t >(speculate_max_draft_token_num * bsz * num_chunks *
1221+ num_heads));
12621222 launchWithPdlWhenEnabled (
12631223 split_kv_kernel,
12641224 grids,
0 commit comments