
Commit 1c3030b

Refactor prefill instantiation to match decode/splitdecode pattern
- Add FmhaPrefillRunner<HEAD_DIM> struct functor to xe_fmha_fwd_prefill_runner.hpp
- Create xe_fmha_fwd_prefill_kernel.cpp.in template for CMake-generated instantiation
- Create xe_fmha_fwd_prefill_dispatch.hpp with extern template declarations
- Create FMHAPrefillXe20.cmake to generate per-HEAD_DIM .cpp files
- Move prefill::mha_fwd() from header to flash_attention.cpp with DISPATCH_PREFILL_KERNEL macro
- Update CMakeLists.txt to include FMHAPrefillXe20.cmake

Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/33c51854-a2bc-41c9-af5c-cda629558e85
Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
1 parent 63165db commit 1c3030b

6 files changed

Lines changed: 448 additions & 285 deletions


src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ foreach(file ${device_cpp})
 endforeach()
 
 include(FMHADecodeXe20.cmake)
+include(FMHAPrefillXe20.cmake)
 
 message(STATUS "BMG files: ${device_cpp_xe20}")
 message(STATUS "Common files: ${device_cpp_common}")

src/FMHAPrefillXe20.cmake

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+# Generate FMHA prefill kernel instantiation files.
+# Each HEAD_DIM is compiled as a separate translation unit to parallelize
+# and speed up compilation.
+#
+# Tile shape mapping (HEAD_DIM -> TILED_Q, TILED_KV, NUM_SG):
+#   64  -> 128, 64, 8
+#   96  -> 128, 64, 8
+#   128 -> 256, 32, 16
+#   192 -> 256, 64, 32
+#   256 -> 256, 64, 32
+#   512 -> 256, 64, 32
+
+set(FMHA_PREFILL_TEMPLATE
+    "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_prefill_kernel.cpp.in")
+
+# Define the per-HEAD_DIM tile configurations.
+# Format: HEAD_DIM;TILED_Q;TILED_KV;NUM_SG
+set(FMHA_PREFILL_CONFIGS
+    "64;128;64;8"
+    "96;128;64;8"
+    "128;256;32;16"
+    "192;256;64;32"
+    "256;256;64;32"
+    "512;256;64;32"
+)
+
+foreach(CONFIG ${FMHA_PREFILL_CONFIGS})
+  list(GET CONFIG 0 HEAD_DIM)
+  list(GET CONFIG 1 TILED_Q)
+  list(GET CONFIG 2 TILED_KV)
+  list(GET CONFIG 3 NUM_SG)
+
+  set(GENERATED_FILE
+      "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_prefill_kernel_${HEAD_DIM}.cpp")
+  configure_file(${FMHA_PREFILL_TEMPLATE} ${GENERATED_FILE} @ONLY)
+  list(APPEND device_cpp_common ${GENERATED_FILE})
+endforeach()
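Note: the xe_fmha_fwd_prefill_kernel.cpp.in template that configure_file() consumes is created elsewhere in this commit and is not shown in this view. Because @ONLY restricts substitution to @VAR@ placeholders, a minimal sketch of what such a template could contain is given below; the include path and the exact instantiation form are assumptions for illustration, not the actual file contents.

// Hypothetical sketch of sycl/xe_fmha_fwd_prefill_kernel.cpp.in; the real
// template is not shown in this view, so the include path and instantiation
// form are assumptions. configure_file(... @ONLY) replaces the @...@
// placeholders with the values parsed from one FMHA_PREFILL_CONFIGS entry,
// producing xe_fmha_fwd_prefill_kernel_64.cpp, ..._96.cpp, and so on.
#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp"

// One explicit instantiation per generated translation unit: the kernel body
// for this HEAD_DIM is compiled here and nowhere else. The tile shape
// (@TILED_Q@ x @TILED_KV@) and subgroup count (@NUM_SG@) would be consumed
// here as well, e.g. as template arguments on an inner kernel type.
template struct FmhaPrefillRunner<@HEAD_DIM@>;

Since each generated file is appended to device_cpp_common, the six instantiations build as independent translation units and can compile in parallel, which is the stated motivation for the per-HEAD_DIM split.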

src/sycl/flash_attention.cpp

Lines changed: 289 additions & 1 deletion
@@ -37,7 +37,7 @@
 
 #include "kernels/chunk_prefill/chunk_prefill_runner.hpp"
 #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp"
-#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp"
+#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp"
 
 namespace decode {
 
@@ -423,6 +423,294 @@ std::vector<at::Tensor> mha_fwd(
 
 } // namespace decode
 
+namespace prefill {
+
+// Dispatch macro following the same pattern as decode.
+// Directly call the struct's operator() - no function pointers.
+
+#define DISPATCH_PREFILL_KERNEL(HD) FmhaPrefillRunner<HD>{}(params)
+
+std::vector<at::Tensor> mha_fwd(
+    const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                         // h_k, d) if there is page_table.
+    const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                         // page_size, h_k, dv) if there is page_table.
+    std::optional<const at::Tensor>& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    const at::Tensor& cu_seqlens_q, // b+1
+    const at::Tensor& cu_seqlens_k, // b+1
+    int max_seqlen_q,
+    int max_seqlen_k,
+    std::optional<const at::Tensor>& page_table, // (b_k, max_num_pages_per_seq)
+    std::optional<const at::Tensor>& kv_batch_idx_, // b. indices to index into the KV cache
+    std::optional<const at::Tensor>& leftpad_k_, // b
+    std::optional<const at::Tensor>& rotary_cos_, // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& rotary_sin_, // seqlen_ro x (rotary_dim / 2)
+    std::optional<const at::Tensor>& seqlens_rotary_, // b
+    std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_, // (b, h_k)
+    std::optional<at::Tensor>& v_descale_, // (b, h_k)
+    const float softmax_scale_,
+    std::optional<const at::Tensor>& sinks_,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
+    int num_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin) {
+  auto q_type = q.scalar_type();
+  TORCH_CHECK(
+      q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
+      "mha_fwd only supports Half and BFloat16, got ",
+      q_type);
+
+  TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
+  TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v);
+
+  TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32");
+  TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension");
+
+  TORCH_CHECK(q.dim() == 3, "query must be in ragged format");
+  CHECK_INPUT(cu_seqlens_q);
+  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
+
+  CHECK_INPUT(cu_seqlens_k);
+  TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
+
+  auto const sizes = q.sizes();
+  const int batch_size = cu_seqlens_q.size(0) - 1;
+  int seqlen_q = max_seqlen_q;
+  int total_q = q.size(0);
+  int num_heads = q.size(-2);
+  int const head_size = q.size(-1);
+  int const head_size_v = v.size(-1);
+  int const max_num_pages_per_seq = page_table.value().size(1);
+  int const num_pages = k.size(0);
+  int const page_size = k.size(1);
+  int const seqlen_k = max_num_pages_per_seq * page_size;
+  int const total_k = num_pages * page_size;
+  int const num_heads_k = k.size(-2);
+
+  int const batch_size_k = page_table.value().size(0);
+  float softmax_scale = softmax_scale_;
+
+  if (!kv_batch_idx_.has_value()) {
+    TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k");
+  }
+
+  // Currently only support head dims <= 512
+  static constexpr int max_headdim = 512;
+  TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim);
+  TORCH_CHECK(num_heads == num_heads_k, "Number of heads in key/value must equal the number of heads in query");
+
+  // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
+  // TODO: check this
+
+  if (window_size_left >= seqlen_k - 1) {
+    window_size_left = -1;
+  }
+  window_size_right = min(window_size_right, seqlen_q);
+  // causal=true is the same as causal=false in this case
+  if (is_causal) {
+    window_size_right = 0;
+  }
+
+  CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size);
+  CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v);
+  CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq);
+
+  if (leftpad_k_.has_value()) {
+    auto leftpad_k = leftpad_k_.value();
+    TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32");
+    CHECK_INPUT(leftpad_k);
+    CHECK_SHAPE(leftpad_k, batch_size);
+  }
+
+  static constexpr int alignment = 8;
+  TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment));
+  TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment));
+
+  auto opts = q.options();
+  at::Tensor out;
+  out = torch::empty({total_q, num_heads, head_size_v}, opts);
+
+  int const head_size_rounded = round_up_headdim(head_size);
+  int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v);
+
+  // Otherwise the kernel will be launched from cuda:0 device
+  // The guard scopes the kernel launches below to q's device.
+  c10::DeviceGuard device_guard(q.device());
+
+  at::Tensor softmax_lse;
+  softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
+
+  // align with FA3
+  Arguments params;
+  params.is_bf16 = q.dtype() == torch::kBFloat16;
+
+  // Set the pointers and strides.
+  params.q_ptr = q.data_ptr();
+  params.k_ptr = k.data_ptr();
+  params.v_ptr = v.data_ptr();
+  // All strides are in elements, not bytes.
+  params.q_row_stride = q.stride(-3);
+  params.k_row_stride = k.stride(-3);
+  params.v_row_stride = v.stride(-3);
+  params.q_head_stride = q.stride(-2);
+  params.k_head_stride = k.stride(-2);
+  params.v_head_stride = v.stride(-2);
+  params.v_dim_stride = v.stride(-1);
+  params.o_ptr = out.data_ptr();
+  params.o_row_stride = out.stride(-3);
+  params.o_head_stride = out.stride(-2);
+
+  params.cu_seqlens_q = cu_seqlens_q.data_ptr<int>();
+  params.cu_seqlens_k = cu_seqlens_k.data_ptr<int>();
+
+  // Softmax sum
+  params.softmax_lse_ptr = softmax_lse.data_ptr();
+
+  // Set the dimensions.
+  params.b = batch_size;
+  params.h = num_heads;
+  params.h_k = num_heads_k;
+  params.q_group_size = 1;
+  params.seqlen_q = seqlen_q;
+  params.seqlen_k = seqlen_k;
+  params.d = head_size;
+  params.d_rounded = head_size_rounded;
+
+  // Set the different scale values.
+  params.softmax_scale = softmax_scale;
+  params.softmax_sink_ptr = sinks_.has_value() ? sinks_.value().data_ptr() : nullptr;
+
+  params.softcap = softcap;
+
+  // Set this to probability of keeping an element to simplify things.
+  params.p_dropout = 1.f;
+
+  // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+  // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+  params.is_causal = window_size_left < 0 && window_size_right == 0;
+  params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal;
+
+  // TODO: check this
+  if (window_size_left < 0) {
+    window_size_left = seqlen_k - 1;
+  }
+  if (window_size_right < 0) {
+    window_size_right = seqlen_q - 1;
+  }
+  params.window_size_left = window_size_left;
+  params.window_size_right = window_size_right;
+  params.total_q = total_q;
+  params.total_k = total_k;
+  params.b_k = batch_size_k;
+  params.dv = head_size_v;
+  params.page_table = page_table.value().data_ptr<int>();
+  params.page_table_batch_stride = page_table.value().stride(0);
+  params.max_num_pages_per_seq = max_num_pages_per_seq;
+  params.page_size = page_size;
+  params.num_pages = num_pages;
+
+  if (q_v_.has_value()) {
+    TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64");
+    TORCH_CHECK(
+        q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
+        "q_v is only supported for fp16 and bf16 data type");
+    TORCH_CHECK(false, "q_v is not supported yet");
+    at::Tensor q_v = q_v_.value();
+    TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query");
+    TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension");
+    CHECK_SHAPE(q_v, total_q, num_heads, head_size_v);
+    params.qv_ptr = q_v.data_ptr();
+    // All strides are in elements, not bytes.
+    params.qv_row_stride = q_v.stride(-3);
+    params.qv_head_stride = q_v.stride(-2);
+  }
+
+  if (rotary_cos_.has_value()) {
+    auto rotary_cos = rotary_cos_.value();
+    CHECK_INPUT(rotary_cos);
+    params.rotary_dim = rotary_cos.size(1) * 2;
+    TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim");
+    TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported");
+    const int seqlen_ro = rotary_cos.size(0);
+    TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache");
+    CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2);
+    TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query");
+
+    TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided");
+    auto rotary_sin = rotary_sin_.value();
+    CHECK_INPUT(rotary_sin);
+    CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2);
+    TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_sin must have the same dtype as query");
+    params.rotary_cos_ptr = rotary_cos.data_ptr();
+    params.rotary_sin_ptr = rotary_sin.data_ptr();
+    params.is_rotary_interleaved = is_rotary_interleaved;
+    if (seqlens_rotary_.has_value()) {
+      at::Tensor seqlens_rotary = seqlens_rotary_.value();
+      CHECK_INPUT(seqlens_rotary);
+      TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32");
+      CHECK_SHAPE(seqlens_rotary, batch_size);
+      params.seqlens_rotary = seqlens_rotary.data_ptr<int>();
+    }
+  } else {
+    params.rotary_dim = 0;
+  }
+
+  if (kv_batch_idx_.has_value()) {
+    auto kv_batch_idx = kv_batch_idx_.value();
+    CHECK_INPUT(kv_batch_idx);
+    TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32");
+    params.kv_batch_idx = reinterpret_cast<int*>(kv_batch_idx.data_ptr());
+  }
+
+  params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device());
+
+  at::Tensor out_accum, softmax_lse_accum;
+
+  TORCH_CHECK(
+      params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256 || params.d == 512,
+      "Unsupported head size for prefill attention: ",
+      params.d);
+
+  switch (params.d) {
+    case 64:
+      DISPATCH_PREFILL_KERNEL(64);
+      break;
+    case 96:
+      DISPATCH_PREFILL_KERNEL(96);
+      break;
+    case 128:
+      DISPATCH_PREFILL_KERNEL(128);
+      break;
+    case 192:
+      DISPATCH_PREFILL_KERNEL(192);
+      break;
+    case 256:
+      DISPATCH_PREFILL_KERNEL(256);
+      break;
+    case 512:
+      DISPATCH_PREFILL_KERNEL(512);
+      break;
+    default:
+      TORCH_CHECK(false, "Unsupported head size for prefill attention: ", params.d);
+  }
+
+  return {out, softmax_lse, out_accum, softmax_lse_accum};
+}
+
+#undef DISPATCH_PREFILL_KERNEL
+
+} // namespace prefill
+
 std::vector<at::Tensor> mha_fwd(
     const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
     const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
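Note: DISPATCH_PREFILL_KERNEL relies on two pieces introduced by this commit but not shown in this view: the FmhaPrefillRunner<HEAD_DIM> struct functor in xe_fmha_fwd_prefill_runner.hpp and the extern template declarations in xe_fmha_fwd_prefill_dispatch.hpp. A minimal sketch of the pattern follows; the operator() signature is assumed from the call site FmhaPrefillRunner<HD>{}(params), not taken from the actual headers.

// Hypothetical sketch; the actual headers are not shown in this view.

// xe_fmha_fwd_prefill_runner.hpp: a struct functor templated on head dim.
template <int HEAD_DIM>
struct FmhaPrefillRunner {
  // Launches the prefill kernel; signature assumed from the call site
  // FmhaPrefillRunner<HD>{}(params) in flash_attention.cpp.
  void operator()(Arguments& params);
};

// xe_fmha_fwd_prefill_dispatch.hpp: extern template declarations promise the
// compiler that each instantiation exists in one of the CMake-generated
// xe_fmha_fwd_prefill_kernel_<HEAD_DIM>.cpp translation units, so including
// this header never re-instantiates the kernel body.
extern template struct FmhaPrefillRunner<64>;
extern template struct FmhaPrefillRunner<96>;
extern template struct FmhaPrefillRunner<128>;
extern template struct FmhaPrefillRunner<192>;
extern template struct FmhaPrefillRunner<256>;
extern template struct FmhaPrefillRunner<512>;

With these declarations visible, FmhaPrefillRunner<HD>{}(params) in the switch statement compiles to a direct call into the matching pre-built instantiation, with no function-pointer indirection; this mirrors the existing decode/splitdecode dispatch pattern.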
