Skip to content

Commit b0e2e01

Browse files
remove unused code from mask_s (#7961)
1 parent 92fdcf7 commit b0e2e01

4 files changed

Lines changed: 93 additions & 154 deletions

File tree

custom_ops/gpu_ops/append_attn/append_attention_func.cuh

Lines changed: 45 additions & 67 deletions
Original file line number | Diff line number | Diff line change
@@ -1001,14 +1001,11 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
10011001
}
10021002

10031003
template <typename T,
1004-
bool partition_kv,
10051004
bool causal,
10061005
uint32_t group_size,
10071006
uint32_t num_warps,
10081007
uint32_t num_frags_x,
1009-
uint32_t num_frags_y,
1010-
uint32_t num_frags_z,
1011-
bool IS_SYSTEM = false>
1008+
uint32_t num_frags_z>
10121009
__device__ __forceinline__ void mask_s(const bool* attn_mask,
10131010
const uint32_t qo_idx_base,
10141011
const uint32_t kv_idx_base,
@@ -1027,74 +1024,55 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
10271024
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
10281025
#pragma unroll
10291026
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
1030-
if constexpr (!IS_SYSTEM) {
1031-
const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 +
1032-
8 * ((reg_id % 4) / 2)) /
1033-
group_size,
1034-
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
1035-
8 * (reg_id / 4) + reg_id % 2;
1036-
bool out_of_boundary;
1037-
if (mask_offset) {
1038-
if (sliding_window > 0) {
1039-
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
1040-
if (swa_part < 0) swa_part = 0;
1041-
int sink_part =
1042-
mask_offset[q_idx * 2] + sink_size; // sink_size = 128
1043-
out_of_boundary =
1044-
q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
1045-
kv_idx < mask_offset[q_idx * 2] ||
1046-
(kv_idx >= sink_part && kv_idx < swa_part))
1047-
: true;
1048-
} else {
1049-
out_of_boundary = q_idx < qo_len
1050-
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
1051-
kv_idx < mask_offset[q_idx * 2])
1052-
: true;
1053-
}
1054-
} else if (sliding_window > 0) {
1055-
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
1056-
(int)qo_len -
1057-
sliding_window;
1058-
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
1059-
out_of_window || (kv_idx >= chunk_end))
1060-
: kv_idx >= chunk_end);
1027+
const uint32_t q_idx = (qo_idx_base + fx * 16 + tx / 4 +
1028+
8 * ((reg_id % 4) / 2)) /
1029+
group_size,
1030+
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
1031+
8 * (reg_id / 4) + reg_id % 2;
1032+
bool out_of_boundary;
1033+
if (mask_offset) {
1034+
if (sliding_window > 0) {
1035+
int swa_part = mask_offset[q_idx * 2 + 1] - sliding_window;
1036+
if (swa_part < 0) swa_part = 0;
1037+
int sink_part =
1038+
mask_offset[q_idx * 2] + sink_size; // sink_size = 128
1039+
out_of_boundary = q_idx < qo_len
1040+
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
1041+
kv_idx < mask_offset[q_idx * 2] ||
1042+
(kv_idx >= sink_part && kv_idx < swa_part))
1043+
: true;
10611044
} else {
1062-
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
1063-
(kv_idx >= chunk_end))
1064-
: kv_idx >= chunk_end);
1065-
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
1066-
kv_idx < chunk_end && q_idx < attn_mask_len) {
1067-
const int32_t mask_idx =
1068-
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
1069-
bool mask = attn_mask[mask_idx];
1070-
out_of_boundary |= mask;
1071-
}
1045+
out_of_boundary = q_idx < qo_len
1046+
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
1047+
kv_idx < mask_offset[q_idx * 2])
1048+
: true;
10721049
}
1073-
1074-
if constexpr (std::is_same<T, half>::value) {
1075-
s_frag[fx][fz][reg_id] =
1076-
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
1077-
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
1078-
s_frag[fx][fz][reg_id] =
1079-
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
1080-
}
1081-
1050+
} else if (sliding_window > 0) {
1051+
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
1052+
(int)qo_len - sliding_window;
1053+
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
1054+
out_of_window || (kv_idx >= chunk_end))
1055+
: kv_idx >= chunk_end);
10821056
} else {
1083-
const uint32_t q_idx = qo_idx_base,
1084-
kv_idx = kv_idx_base + fz * 16 + 2 * (tx % 4) +
1085-
8 * (reg_id / 4) + reg_id % 2;
1086-
const bool out_of_boundary =
1087-
(causal
1088-
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
1089-
: kv_idx >= chunk_end);
1090-
if constexpr (std::is_same<T, half>::value) {
1091-
s_frag[fx][fz][reg_id] =
1092-
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
1093-
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
1094-
s_frag[fx][fz][reg_id] =
1095-
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
1057+
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
1058+
(kv_idx >= chunk_end))
1059+
: kv_idx >= chunk_end);
1060+
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
1061+
kv_idx < chunk_end && q_idx < attn_mask_len) {
1062+
const int32_t mask_idx =
1063+
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
1064+
bool mask = attn_mask[mask_idx];
1065+
out_of_boundary |= mask;
10961066
}
10971067
}
1068+
1069+
if constexpr (std::is_same<T, half>::value) {
1070+
s_frag[fx][fz][reg_id] =
1071+
out_of_boundary ? -5e4f : s_frag[fx][fz][reg_id];
1072+
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
1073+
s_frag[fx][fz][reg_id] =
1074+
out_of_boundary ? -3.0e+30f : s_frag[fx][fz][reg_id];
1075+
}
10981076
}
10991077
}
11001078
}

custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh

Lines changed: 22 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -253,24 +253,18 @@ __global__ void multi_query_append_attention_kernel(
253253
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
254254
// mask according to kv_idx and q_idx
255255
if (iter >= mask_check_iteration || sliding_window > 0) {
256-
mask_s<T,
257-
partition_kv,
258-
CAUSAL,
259-
GROUP_SIZE,
260-
NUM_WARPS,
261-
num_frags_x,
262-
num_frags_y,
263-
num_frags_z>(nullptr,
264-
q_base_seq_id_this_block,
265-
kv_idx_base,
266-
q_len,
267-
kv_len,
268-
chunk_end,
269-
-1,
270-
s_frag,
271-
mask_offset_this_seq,
272-
sliding_window,
273-
sink_size);
256+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
257+
nullptr,
258+
q_base_seq_id_this_block,
259+
kv_idx_base,
260+
q_len,
261+
kv_len,
262+
chunk_end,
263+
-1,
264+
s_frag,
265+
mask_offset_this_seq,
266+
sliding_window,
267+
sink_size);
274268
}
275269

276270
// update m,d
@@ -565,22 +559,22 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
565559

566560
const uint32_t num_iterations = div_up(
567561
CAUSAL
568-
? (min(chunk_len,
569-
sub_if_greater_or_zero(
570-
kv_len - q_len +
571-
div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
572-
chunk_start)))
562+
? min(chunk_len,
563+
sub_if_greater_or_zero(
564+
kv_len - q_len +
565+
div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE),
566+
chunk_start))
573567
: chunk_len,
574568
BLOCK_SIZE);
575569
const uint32_t mask_check_iteration =
576-
(CAUSAL ? (min(chunk_len,
577-
sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
570+
(CAUSAL
571+
? min(chunk_len, sub_if_greater_or_zero(kv_len - q_len, chunk_start))
578572
: mask_offset ? 0
579573
: chunk_len) /
580-
(BLOCK_SIZE);
574+
BLOCK_SIZE;
581575

582576
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
583-
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
577+
wid * num_frags_z * 16 + tid / 16 * 8 + tid % 8, tid % 16 / 8);
584578

585579
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
586580
wid * num_frags_z * 16 + tid % 16, tid / 16);
@@ -637,14 +631,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
637631
&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
638632
// mask according to kv_idx and q_idx
639633
if (iter >= mask_check_iteration || sliding_window > 0) {
640-
mask_s<T,
641-
partition_kv,
642-
CAUSAL,
643-
GROUP_SIZE,
644-
NUM_WARPS,
645-
num_frags_x,
646-
num_frags_y,
647-
num_frags_z>(
634+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
648635
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
649636
: nullptr,
650637
q_base_seq_id_this_block,

custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh

Lines changed: 13 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -338,24 +338,18 @@ __global__ void multi_query_append_attention_c4_kernel(
338338
cache_k_zp_frag);
339339

340340
if (iter >= mask_check_iteration || sliding_window > 0) {
341-
mask_s<T,
342-
partition_kv,
343-
CAUSAL,
344-
GROUP_SIZE,
345-
NUM_WARPS,
346-
num_frags_x,
347-
num_frags_y,
348-
num_frags_z>(nullptr,
349-
q_base_seq_id_this_block,
350-
kv_idx_base,
351-
q_len,
352-
kv_len,
353-
chunk_end,
354-
-1,
355-
s_frag,
356-
mask_offset_this_seq,
357-
sliding_window,
358-
sink_size);
341+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
342+
nullptr,
343+
q_base_seq_id_this_block,
344+
kv_idx_base,
345+
q_len,
346+
kv_len,
347+
chunk_end,
348+
-1,
349+
s_frag,
350+
mask_offset_this_seq,
351+
sliding_window,
352+
sink_size);
359353
}
360354

361355
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -837,14 +831,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
837831
cache_k_scale_frag,
838832
cache_k_zp_frag);
839833
if (iter >= mask_check_iteration || sliding_window > 0) {
840-
mask_s<T,
841-
partition_kv,
842-
CAUSAL,
843-
GROUP_SIZE,
844-
NUM_WARPS,
845-
num_frags_x,
846-
num_frags_y,
847-
num_frags_z>(
834+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
848835
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
849836
: nullptr,
850837
q_base_seq_id_this_block,

custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh

Lines changed: 13 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -354,24 +354,18 @@ __global__ void multi_query_append_attention_c8_kernel(
354354

355355
// mask according to kv_idx and q_idx
356356
if (iter >= mask_check_iteration || sliding_window > 0) {
357-
mask_s<T,
358-
partition_kv,
359-
CAUSAL,
360-
GROUP_SIZE,
361-
NUM_WARPS,
362-
num_frags_x,
363-
num_frags_y,
364-
num_frags_z>(nullptr,
365-
q_base_seq_id_this_block,
366-
kv_idx_base,
367-
q_len,
368-
kv_len,
369-
chunk_end,
370-
-1,
371-
s_frag,
372-
mask_offset_this_seq,
373-
sliding_window,
374-
sink_size);
357+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
358+
nullptr,
359+
q_base_seq_id_this_block,
360+
kv_idx_base,
361+
q_len,
362+
kv_len,
363+
chunk_end,
364+
-1,
365+
s_frag,
366+
mask_offset_this_seq,
367+
sliding_window,
368+
sink_size);
375369
}
376370

377371
// update m,d
@@ -903,14 +897,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
903897
s_frag);
904898
// mask according to kv_idx and q_idx
905899
if (iter >= mask_check_iteration || sliding_window > 0) {
906-
mask_s<T,
907-
partition_kv,
908-
CAUSAL,
909-
GROUP_SIZE,
910-
NUM_WARPS,
911-
num_frags_x,
912-
num_frags_y,
913-
num_frags_z>(
900+
mask_s<T, CAUSAL, GROUP_SIZE, NUM_WARPS, num_frags_x, num_frags_z>(
914901
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
915902
: nullptr,
916903
q_base_seq_id_this_block,

0 commit comments

Comments
 (0)