Skip to content

Commit 18f0124

Browse files
[OP][Optimization] Remove ENABLE_PREFILL template parameter in multi_query_append_attention_warp1_4_kernel (#7201)
1 parent 8cb417e commit 18f0124

4 files changed

Lines changed: 32 additions & 66 deletions

File tree

custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh

Lines changed: 26 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,7 @@ template <typename T,
430430
uint32_t num_frags_x,
431431
uint32_t num_frags_z,
432432
uint32_t num_frags_y,
433-
typename OutT = T,
434-
bool ENABLE_PREFILL = true>
433+
typename OutT = T>
435434
__global__ void multi_query_append_attention_warp1_4_kernel(
436435
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
437436
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
@@ -525,17 +524,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
525524
if (!partition_kv || num_chunks_this_seq <= 1) {
526525
o_base_ptr_int8 = out + o_offset;
527526
} else {
528-
if (ENABLE_PREFILL) {
529-
o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
530-
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
531-
tid % 8 * num_elems_per_128b<T>();
532-
} else {
533-
o_base_ptr_T =
534-
tmp_workspace +
535-
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
536-
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
537-
tid % 8 * num_elems_per_128b<T>();
538-
}
527+
o_base_ptr_T =
528+
tmp_workspace +
529+
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
530+
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
531+
tid % 8 * num_elems_per_128b<T>();
539532
}
540533
const int *mask_offset_this_seq =
541534
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
@@ -799,18 +792,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
799792
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
800793

801794
if (qo_idx - q_start_seq_id < q_len) {
802-
uint32_t offset;
803-
if (ENABLE_PREFILL) {
804-
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
805-
qo_head_idx;
806-
} else {
807-
offset = ((batch_id * speculate_max_draft_token_num +
808-
qo_idx_now / GROUP_SIZE) *
809-
num_chunks +
810-
chunk_idx) *
811-
q_num_heads +
812-
qo_head_idx;
813-
}
795+
const uint32_t offset = ((batch_id * speculate_max_draft_token_num +
796+
qo_idx_now / GROUP_SIZE) *
797+
num_chunks +
798+
chunk_idx) *
799+
q_num_heads +
800+
qo_head_idx;
814801
tmp_m[offset] = m_frag[fx][j];
815802
tmp_d[offset] = d_frag[fx][j];
816803
}
@@ -1123,8 +1110,7 @@ void MultiQueryAppendAttention(
11231110
num_frags_x,
11241111
num_frags_z,
11251112
num_frags_y,
1126-
OUT_NV_TYPE,
1127-
ENABLE_PREFILL>;
1113+
OUT_NV_TYPE>;
11281114
if (smem_size >= 48 * 1024) {
11291115
cudaFuncSetAttribute(split_kv_kernel,
11301116
cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1169,8 +1155,7 @@ void MultiQueryAppendAttention(
11691155
num_frags_x,
11701156
num_frags_z,
11711157
num_frags_y,
1172-
OUT_NV_TYPE,
1173-
ENABLE_PREFILL>;
1158+
OUT_NV_TYPE>;
11741159
if (smem_size >= 48 * 1024) {
11751160
cudaFuncSetAttribute(nosplit_kv_kernel,
11761161
cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -1222,43 +1207,18 @@ void MultiQueryAppendAttention(
12221207
sink_size);
12231208
} else {
12241209
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
1225-
if (is_decoder) {
1226-
tmp_workspace = allocator->Allocate(
1227-
phi::SizeOf(qkv.dtype()) *
1228-
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM));
1229-
tmp_m = allocator->Allocate(
1230-
phi::SizeOf(paddle::DataType::FLOAT32) *
1231-
static_cast<size_t>(bsz * num_chunks * num_heads));
1232-
tmp_d = allocator->Allocate(
1233-
phi::SizeOf(paddle::DataType::FLOAT32) *
1234-
static_cast<size_t>(bsz * num_chunks * num_heads));
1235-
} else {
1236-
if (ENABLE_PREFILL) {
1237-
tmp_workspace =
1238-
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
1239-
static_cast<size_t>(token_num * num_chunks *
1240-
num_heads * HEAD_DIM));
1241-
tmp_m = allocator->Allocate(
1242-
phi::SizeOf(paddle::DataType::FLOAT32) *
1243-
static_cast<size_t>(token_num * num_chunks * num_heads));
1244-
tmp_d = allocator->Allocate(
1245-
phi::SizeOf(paddle::DataType::FLOAT32) *
1246-
static_cast<size_t>(token_num * num_chunks * num_heads));
1247-
} else {
1248-
tmp_workspace = allocator->Allocate(
1249-
phi::SizeOf(qkv.dtype()) *
1250-
static_cast<size_t>(speculate_max_draft_token_num * bsz *
1251-
num_chunks * num_heads * HEAD_DIM));
1252-
tmp_m = allocator->Allocate(
1253-
phi::SizeOf(paddle::DataType::FLOAT32) *
1254-
static_cast<size_t>(speculate_max_draft_token_num * bsz *
1255-
num_chunks * num_heads));
1256-
tmp_d = allocator->Allocate(
1257-
phi::SizeOf(paddle::DataType::FLOAT32) *
1258-
static_cast<size_t>(speculate_max_draft_token_num * bsz *
1259-
num_chunks * num_heads));
1260-
}
1261-
}
1210+
tmp_workspace = allocator->Allocate(
1211+
phi::SizeOf(qkv.dtype()) *
1212+
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
1213+
num_heads * HEAD_DIM));
1214+
tmp_m = allocator->Allocate(
1215+
phi::SizeOf(paddle::DataType::FLOAT32) *
1216+
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
1217+
num_heads));
1218+
tmp_d = allocator->Allocate(
1219+
phi::SizeOf(paddle::DataType::FLOAT32) *
1220+
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
1221+
num_heads));
12621222
launchWithPdlWhenEnabled(
12631223
split_kv_kernel,
12641224
grids,

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ def __init__(
146146
self.causal: bool = getattr(fd_config.model_config, "causal", True)
147147
self.speculative_method = fd_config.speculative_config.method
148148
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
149+
if self.speculative_method is None:
150+
self.speculate_max_draft_token_num = 0
149151
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
150152
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
151153

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,8 @@ def __init__(
258258
self.speculative_method = fd_config.speculative_config.method
259259
self.use_speculate = self.speculative_method is not None
260260
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
261+
if not self.use_speculate:
262+
self.speculate_max_draft_token_num = 0
261263
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
262264
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
263265

fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ def __init__(
109109
self.speculative_method = fd_config.speculative_config.method
110110
self.use_speculate = self.speculative_method is not None
111111
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
112+
if not self.use_speculate:
113+
self.speculate_max_draft_token_num = 0
112114
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
113115
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
114116

0 commit comments

Comments
 (0)