From 4421989a6a9f1aef7293715079102ec551ee4ae6 Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Wed, 18 Mar 2026 14:13:06 +0800 Subject: [PATCH 01/23] Split FMHA decode and GroupGemm template instantiations into per-kernel compilation units (#140) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split the monolithic template instantiation of xe_fmha_fwd_decode_runner.hpp into 72 separate .cpp files (one per QG_SZ × HEAD_DIM × PAGE_SIZE combination), each compiled as its own library. This enables parallel compilation and significantly speeds up build times. Changes: - Create xe_fmha_fwd_decode_kernel.cpp.in template for per-combination compilation - Create xe_fmha_fwd_decode_dispatch.hpp with function declarations for all 72 kernels - Move decode::mha_fwd() from header to flash_attention.cpp with dispatch table - Update src/CMakeLists.txt to generate .cpp files via configure_file() - Remove mha_fwd() definition from xe_fmha_fwd_decode_runner.hpp header Co-authored-by: airMeng <39229107+airMeng@users.noreply.github.com> Co-authored-by: jiwei1.sun --- CMakeLists.txt | 2 +- Dockerfile.xpu_kernel | 5 +- src/CMakeLists.txt | 4 + src/FMHADecodeXe20.cmake | 23 ++ src/GroupGemmXe20.cmake | 32 ++ src/sycl/GroupGemmXe20.cpp | 123 +++--- src/sycl/GroupGemmXe20LauncherInstance.cpp.in | 109 ++++++ src/sycl/flash_attention.cpp | 368 ++++++++++++++++++ .../xe_fmha_fwd_decode_dispatch.hpp | 73 ++++ .../xe_fmha_fwd_decode_runner.hpp | 325 ---------------- src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 56 +++ 11 files changed, 725 insertions(+), 395 deletions(-) create mode 100644 src/FMHADecodeXe20.cmake create mode 100644 src/GroupGemmXe20.cmake create mode 100644 src/sycl/GroupGemmXe20LauncherInstance.cpp.in create mode 100644 src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp create mode 100644 src/sycl/xe_fmha_fwd_decode_kernel.cpp.in diff --git a/CMakeLists.txt b/CMakeLists.txt index 1ba0e2ef..8147a33e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,7 +42,7 @@ set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable headers only mode in cutla FetchContent_Declare( repo-cutlass-sycl GIT_REPOSITORY https://github.com/intel/sycl-tla.git - GIT_TAG 482b40e8bed0e9204311d1569c876b4573dfb952 + GIT_TAG 64584484b4279b1b4184b508af445698a4a1b603 GIT_SHALLOW OFF ) FetchContent_MakeAvailable(repo-cutlass-sycl) diff --git a/Dockerfile.xpu_kernel b/Dockerfile.xpu_kernel index 3c34ba22..8e2a25b7 100644 --- a/Dockerfile.xpu_kernel +++ b/Dockerfile.xpu_kernel @@ -22,10 +22,10 @@ ARG SG_LANG_KERNEL_BRANCH=main # Install the latest UMD driver for SYCL-TLA RUN apt-get install -y software-properties-common && \ add-apt-repository -y ppa:kobuk-team/intel-graphics && \ + apt-get update && \ apt-get install -y libze-intel-gpu1 libze1 intel-metrics-discovery intel-opencl-icd clinfo intel-gsc && \ apt-get install -y intel-media-va-driver-non-free libmfx-gen1 libvpl2 libvpl-tools libva-glx2 va-driver-all vainfo && \ - apt-get install -y libze-dev intel-ocloc && \ - apt-get update + apt-get install -y libze-dev intel-ocloc # Install Miniforge & PyTorch/Triton RUN curl -fsSL -v -o miniforge.sh -O https://github.com/conda-forge/miniforge/releases/download/25.1.1-0/Miniforge3-Linux-x86_64.sh && \ @@ -66,3 +66,4 @@ RUN --mount=type=secret,id=github_token \ # Set the default shell to bash SHELL ["bash", "-c"] CMD ["bash", "-c", "source /root/.bashrc && exec bash"] +USER root diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 
115dcbd4..8f59214c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -19,6 +19,8 @@ foreach(file ${device_cpp}) endforeach() +include(FMHADecodeXe20.cmake) + message(STATUS "BMG files: ${device_cpp_xe20}") message(STATUS "Common files: ${device_cpp_common}") @@ -26,6 +28,8 @@ list(APPEND ATen_XPU_CPP_SRCS ${host_cpp}) list(APPEND ATen_XPU_SYCL_COMMON ${device_cpp_common}) list(APPEND ATen_XPU_SYCL_XE20 ${device_cpp_xe20}) +include(${CMAKE_CURRENT_SOURCE_DIR}/GroupGemmXe20.cmake) + set(ATen_XPU_CPP_SRCS ${ATen_XPU_CPP_SRCS} PARENT_SCOPE) set(ATen_XPU_SYCL_COMMON ${ATen_XPU_SYCL_COMMON} PARENT_SCOPE) set(ATen_XPU_SYCL_XE20 ${ATen_XPU_SYCL_XE20} PARENT_SCOPE) diff --git a/src/FMHADecodeXe20.cmake b/src/FMHADecodeXe20.cmake new file mode 100644 index 00000000..fe996ee5 --- /dev/null +++ b/src/FMHADecodeXe20.cmake @@ -0,0 +1,23 @@ +# Generate FMHA decode kernel instantiation files. +# Each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination is compiled as a separate +# library to parallelize and speed up compilation. + +set(FMHA_DECODE_QG_SIZES 1 2 4 8 16 32) +set(FMHA_DECODE_HEAD_DIMS 64 96 128 192) +set(FMHA_DECODE_PAGE_SIZES 32 64 128) + +set(FMHA_DECODE_TEMPLATE + "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in") + +foreach(QG_SZ ${FMHA_DECODE_QG_SIZES}) + foreach(HEAD_DIM ${FMHA_DECODE_HEAD_DIMS}) + foreach(PAGE_SIZE ${FMHA_DECODE_PAGE_SIZES}) + math(EXPR NUM_SG "${PAGE_SIZE} / 16") + + set(GENERATED_FILE + "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp") + configure_file(${FMHA_DECODE_TEMPLATE} ${GENERATED_FILE} @ONLY) + list(APPEND device_cpp_common ${GENERATED_FILE}) + endforeach() + endforeach() +endforeach() diff --git a/src/GroupGemmXe20.cmake b/src/GroupGemmXe20.cmake new file mode 100644 index 00000000..267d7233 --- /dev/null +++ b/src/GroupGemmXe20.cmake @@ -0,0 +1,32 @@ +set(GROUP_GEMM_XE20_TEMPLATE "${CMAKE_CURRENT_SOURCE_DIR}/sycl/GroupGemmXe20LauncherInstance.cpp.in") +set(GROUP_GEMM_XE20_GEN_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated/group_gemm_xe20") +set(GROUP_GEMM_XE20_INST_SRCS) +file(MAKE_DIRECTORY ${GROUP_GEMM_XE20_GEN_DIR}) + +function(add_group_gemm_xe20_inst TILE_M TILE_N TILE_K SG_SHAPE SG_STRIDE ACT_TYPE FUSE_ACT WITH_BIAS) + set(TILE "Shape<${TILE_M}, ${TILE_N}, ${TILE_K}>") + set(SGLAYOUT "Layout, Stride<${SG_STRIDE}>>") + set(GEN_SRC + "${GROUP_GEMM_XE20_GEN_DIR}/GroupGemmXe20_inst_${TILE_M}_${TILE_N}_${TILE_K}_a${ACT_TYPE}_f${FUSE_ACT}_b${WITH_BIAS}.cpp") + + configure_file(${GROUP_GEMM_XE20_TEMPLATE} ${GEN_SRC} @ONLY) + list(APPEND GROUP_GEMM_XE20_INST_SRCS ${GEN_SRC}) + set(GROUP_GEMM_XE20_INST_SRCS ${GROUP_GEMM_XE20_INST_SRCS} PARENT_SCOPE) +endfunction() + +foreach(act_type 0 1) + foreach(with_bias true false) + foreach(fuse_act true false) + add_group_gemm_xe20_inst("_8" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + add_group_gemm_xe20_inst("_16" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + add_group_gemm_xe20_inst("_32" "_64" "_32" "_1, _4, _1" "_4, _1, _0" ${act_type} ${fuse_act} ${with_bias}) + endforeach() + + add_group_gemm_xe20_inst("_128" "_64" "_32" "_4, _2, _1" "_2, _1, _0" ${act_type} true ${with_bias}) + add_group_gemm_xe20_inst("_128" "_128" "_32" "_4, _2, _1" "_2, _1, _0" ${act_type} false ${with_bias}) + add_group_gemm_xe20_inst("_256" "_64" "_32" "_8, _2, _1" "_2, _1, _0" ${act_type} true ${with_bias}) + add_group_gemm_xe20_inst("_256" "_256" "_32" "_8, _4, _1" "_4, _1, _0" ${act_type} false ${with_bias}) + 
endforeach() +endforeach() + +list(APPEND ATen_XPU_SYCL_XE20 ${GROUP_GEMM_XE20_INST_SRCS}) diff --git a/src/sycl/GroupGemmXe20.cpp b/src/sycl/GroupGemmXe20.cpp index f2e8ab2d..e0c7a364 100644 --- a/src/sycl/GroupGemmXe20.cpp +++ b/src/sycl/GroupGemmXe20.cpp @@ -15,10 +15,6 @@ using namespace cute; using ElementAccumulator = float; // <- data type of accumulator -template -class GemmXe20Name; - -// ActType: 0=silu, 1=gelu template void Xe20MoEGEMMLauncher( sycl::queue q, @@ -31,69 +27,62 @@ void Xe20MoEGEMMLauncher( const int gemm_k, const int* num_rows_per_expert_device, const int num_experts, - int* workspace) { - using Element = cutlass::bfloat16_t; - - auto make_dummy_tensor = [&](auto val, auto stride) { - return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); - }; - auto make_dummy_bias = [&](auto val) { - return make_tensor(make_gmem_ptr(&val), make_layout(Shape{}, Stride<_1>{})); - }; - using StrideA = Stride; - using StrideB = Stride; - using StrideD = Stride; - using TensorA = decltype(make_dummy_tensor(Element{}, StrideA{})); - using TensorB = decltype(make_dummy_tensor(Element{}, StrideB{})); - using TensorD = decltype(make_dummy_tensor(Element{}, StrideD{})); - using TensorBias = decltype(make_dummy_bias(Element{})); - - using ElementA_non_CV = cutlass::platform::remove_cv_t; - using MMA = - typename TiledMMAHelper>, Layout, SGLayout>::TiledMMA; - auto mma = MMA{}; - - int sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); - auto MaxThreadsPerWorkgroup = size(mma); - - static constexpr int MaxThreadsPerSM = 512; - - TORCH_CHECK( - MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0, "MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup") - - sycl::range<3> local(1, 1, MaxThreadsPerWorkgroup); - sycl::range<3> global(1, sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1); - - namespace syclex = sycl::ext::oneapi::experimental; - namespace intelex = sycl::ext::intel::experimental; - - syclex::properties kernel_props{syclex::sub_group_size<16>, intelex::grf_size<256>}; - - using Kernel = - MoE::MoEGEMM; - typename Kernel::Params params{ - static_cast(activations), - static_cast(weights), - static_cast(bias), - static_cast(outputs), - num_rows_per_expert_device, - gemm_n, - gemm_k, - num_experts, - workspace, - mma, - }; - - auto event = q.submit([&](sycl::handler& h) { - sycl::local_accessor local_mem(sycl::range<1>(1), h); - h.parallel_for>( - sycl::nd_range<3>(global * local, local), kernel_props, [=](sycl::nd_item<3> item) { - int32_t* slm_mem = - static_cast(local_mem.template get_multi_ptr().get()); - Kernel{}(params, item, slm_mem); - }); - }); -} + int* workspace); + +using Tile_8_64_32 = Shape<_8, _64, _32>; +using Tile_16_64_32 = Shape<_16, _64, _32>; +using Tile_32_64_32 = Shape<_32, _64, _32>; +using Tile_128_64_32 = Shape<_128, _64, _32>; +using Tile_128_128_32 = Shape<_128, _128, _32>; +using Tile_256_64_32 = Shape<_256, _64, _32>; +using Tile_256_256_32 = Shape<_256, _256, _32>; + +using SG_1_4_1 = Layout, Stride<_4, _1, _0>>; +using SG_4_2_1 = Layout, Stride<_2, _1, _0>>; +using SG_8_2_1 = Layout, Stride<_2, _1, _0>>; +using SG_8_4_1 = Layout, Stride<_4, _1, _0>>; + +#define DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, ActType, FuseAct, WithBias) \ + extern template void Xe20MoEGEMMLauncher( \ + sycl::queue, \ + const void*, \ + const void*, \ + const void*, \ + const void*, \ + void*, \ + const int, \ + const int, \ + const int*, \ + const int, \ + int*); + +#define DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile, SGLayout) \ + 
DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, true, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, true, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, false, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, false, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, true, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, true, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, false, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, false, false) + +#define DECLARE_XE20_MOE_TILE_FUSE(Tile, SGLayout, FuseAct) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, FuseAct, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 0, FuseAct, false) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, FuseAct, true) \ + DECLARE_XE20_MOE_EXTERN(Tile, SGLayout, 1, FuseAct, false) + +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_8_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_16_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_ALL_FUSES(Tile_32_64_32, SG_1_4_1) +DECLARE_XE20_MOE_TILE_FUSE(Tile_128_64_32, SG_4_2_1, true) +DECLARE_XE20_MOE_TILE_FUSE(Tile_128_128_32, SG_4_2_1, false) +DECLARE_XE20_MOE_TILE_FUSE(Tile_256_64_32, SG_8_2_1, true) +DECLARE_XE20_MOE_TILE_FUSE(Tile_256_256_32, SG_8_4_1, false) + +#undef DECLARE_XE20_MOE_TILE_FUSE +#undef DECLARE_XE20_MOE_TILE_ALL_FUSES +#undef DECLARE_XE20_MOE_EXTERN #define LAUNCH_MOE(...) \ Xe20MoEGEMMLauncher<__VA_ARGS__>( \ diff --git a/src/sycl/GroupGemmXe20LauncherInstance.cpp.in b/src/sycl/GroupGemmXe20LauncherInstance.cpp.in new file mode 100644 index 00000000..e9a0b0f1 --- /dev/null +++ b/src/sycl/GroupGemmXe20LauncherInstance.cpp.in @@ -0,0 +1,109 @@ +#define SYCL_INTEL_TARGET 20 + +#include +#include +#include + +#include + +#include "sycl/Utils.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "sycl/kernels/moe/xe20/moe_kernel.hpp" + +using namespace cute; + +template +class GemmXe20Name; + +template +void Xe20MoEGEMMLauncher( + sycl::queue q, + const void* activations, + const void* weights, + const void* scales, + const void* bias, + void* outputs, + const int gemm_n, + const int gemm_k, + const int* num_rows_per_expert_device, + const int num_experts, + int* workspace) { + (void)scales; + using Element = cutlass::bfloat16_t; + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); + }; + auto make_dummy_bias = [&](auto val) { + return make_tensor(make_gmem_ptr(&val), make_layout(Shape{}, Stride<_1>{})); + }; + using StrideA = Stride; + using StrideB = Stride; + using StrideD = Stride; + using TensorA = decltype(make_dummy_tensor(Element{}, StrideA{})); + using TensorB = decltype(make_dummy_tensor(Element{}, StrideB{})); + using TensorD = decltype(make_dummy_tensor(Element{}, StrideD{})); + using TensorBias = decltype(make_dummy_bias(Element{})); + + using ElementA_non_CV = cutlass::platform::remove_cv_t; + using MMA = + typename TiledMMAHelper>, Layout, SGLayout>::TiledMMA; + auto mma = MMA{}; + + int sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(0); + auto MaxThreadsPerWorkgroup = size(mma); + + static constexpr int MaxThreadsPerSM = 512; + + TORCH_CHECK( + MaxThreadsPerSM % MaxThreadsPerWorkgroup == 0, "MaxThreadsPerSM must be divisible by MaxThreadsPerWorkgroup"); + + sycl::range<3> local(1, 1, MaxThreadsPerWorkgroup); + sycl::range<3> global(1, sm_count * MaxThreadsPerSM / MaxThreadsPerWorkgroup, 1); + + namespace syclex = sycl::ext::oneapi::experimental; + 
namespace intelex = sycl::ext::intel::experimental; + + syclex::properties kernel_props{syclex::sub_group_size<16>, intelex::grf_size<256>}; + + using Kernel = + MoE::MoEGEMM; + typename Kernel::Params params{ + static_cast(activations), + static_cast(weights), + static_cast(bias), + static_cast(outputs), + num_rows_per_expert_device, + gemm_n, + gemm_k, + num_experts, + workspace, + mma, + }; + + q.submit([&](sycl::handler& h) { + sycl::local_accessor local_mem(sycl::range<1>(1), h); + h.parallel_for>( + sycl::nd_range<3>(global * local, local), kernel_props, [=](sycl::nd_item<3> item) { + int32_t* slm_mem = + static_cast(local_mem.template get_multi_ptr().get()); + Kernel{}(params, item, slm_mem); + }); + }); +} + +template void Xe20MoEGEMMLauncher<@TILE@, @SGLAYOUT@, @ACT_TYPE@, @FUSE_ACT@, @WITH_BIAS@>( + sycl::queue, + const void*, + const void*, + const void*, + const void*, + void*, + const int, + const int, + const int*, + const int, + int*); + +#undef SYCL_INTEL_TARGET diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 4e2a7ad9..729f0d65 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -38,8 +38,376 @@ #include #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" +#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" +namespace decode { + +namespace { + +using launch_fn_t = void (*)(bool use_sink, const Arguments& params); + +#define LAUNCH_FN_ENTRY(QG, HD, PS) &launch_fmha_decode_##QG##_##HD##_##PS + +launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size) { + // Dispatch table indexed by (qg_sz, head_dim, page_size). + // qg_sz index: {1->0, 2->1, 4->2, 8->3, 16->4, 32->5} + // head_dim index: {64->0, 96->1, 128->2, 192->3} + // page_size index: {32->0, 64->1, 128->2} + +#define PAGE_ENTRIES(QG, HD) \ + { LAUNCH_FN_ENTRY(QG, HD, 32), LAUNCH_FN_ENTRY(QG, HD, 64), LAUNCH_FN_ENTRY(QG, HD, 128) } + +#define HD_ENTRIES(QG) \ + { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192) } + + static const launch_fn_t table[6][4][3] = { + HD_ENTRIES(1), + HD_ENTRIES(2), + HD_ENTRIES(4), + HD_ENTRIES(8), + HD_ENTRIES(16), + HD_ENTRIES(32), + }; + +#undef HD_ENTRIES +#undef PAGE_ENTRIES + + int qg_idx = -1; + switch (qg_sz) { + case 1: + qg_idx = 0; + break; + case 2: + qg_idx = 1; + break; + case 4: + qg_idx = 2; + break; + case 8: + qg_idx = 3; + break; + case 16: + qg_idx = 4; + break; + case 32: + qg_idx = 5; + break; + default: + return nullptr; + } + + int hd_idx = -1; + switch (head_dim) { + case 64: + hd_idx = 0; + break; + case 96: + hd_idx = 1; + break; + case 128: + hd_idx = 2; + break; + case 192: + hd_idx = 3; + break; + default: + return nullptr; + } + + int ps_idx = -1; + switch (page_size) { + case 32: + ps_idx = 0; + break; + case 64: + ps_idx = 1; + break; + case 128: + ps_idx = 2; + break; + default: + return nullptr; + } + + return table[qg_idx][hd_idx][ps_idx]; +} + +#undef LAUNCH_FN_ENTRY + +} // namespace + +std::vector mha_fwd( + at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, + // h_k, d) if there is page_table. + const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, + // page_size, h_k, dv) if there is page_table. 
+ std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + int max_seqlen_q, + int max_seqlen_k, + std::optional& page_table, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) + const float softmax_scale_, + std::optional& sinks_, + bool is_causal, + int window_size_left, + int window_size_right, + float const softcap, + bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional& scheduler_metadata_, // (b + 1) + int num_splits, + std::optional pack_gqa_, + int const sm_margin) { + auto q_type = q.scalar_type(); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "mha_fwd only supports Half and BFloat16, got", + q_type); + + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + CHECK_INPUT(q); + CHECK_INPUT(k); + CHECK_INPUT(v); + TORCH_CHECK( + q.stride(-1) == 1 && k.stride(-1) == 1 && v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); + + TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); + CHECK_INPUT(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + + CHECK_INPUT(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + + auto const sizes = q.sizes(); + const int batch_size = cu_seqlens_q.size(0) - 1; + int seqlen_q = max_seqlen_q; + int total_q = q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = page_table.value().size(1); + int const num_pages = k.size(0); + int const page_size = k.size(1); + int const seqlen_k = max_num_pages_per_seq * page_size; + int const total_k = num_pages * page_size; + int const num_heads_k = k.size(-2); + int q_group_size = num_heads / num_heads_k; + + int const batch_size_k = page_table.value().size(0); + float softmax_scale = softmax_scale_; + + if (!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + + // Currently only support head dims <= 256 + static constexpr int max_headdim = 256; + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + window_size_right = min(window_size_right, seqlen_q); + // causal=true is the same as causal=false in 
this case + if (is_causal) { + window_size_right = 0; + } + + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_INPUT(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + static constexpr int alignment = 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + at::Tensor out; + out = torch::empty({total_q, num_heads, head_size_v}, opts); + + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + c10::DeviceGuard device_guard(q.device()); + + at::Tensor softmax_lse; + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + // align with FA3 + Arguments params; + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); + params.cu_seqlens_k = cu_seqlens_k.data_ptr(); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.q_group_size = num_heads / num_heads_k; + params.seqlen_q = seqlen_q * q_group_size; + params.seqlen_k = seqlen_k; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.softmax_scale = softmax_scale; + bool use_sink = sinks_.has_value(); + params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; + + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.page_table = page_table.value().data_ptr(); + params.page_table_batch_stride = page_table.value().stride(0); + params.max_num_pages_per_seq = max_num_pages_per_seq; + params.page_size = page_size; + params.num_pages = num_pages; + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(false, "q_v is not supported yet"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + } + + if (rotary_cos_.has_value()) { + auto rotary_cos = rotary_cos_.value(); + CHECK_INPUT(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_INPUT(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_INPUT(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_INPUT(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); + + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + + int qg_sz = nextPowerOf2(max_seqlen_q); + TORCH_CHECK(qg_sz >= 1 && qg_sz <= 32, "Unsupported qgroup_size for decode 
attention: ", max_seqlen_q); + TORCH_CHECK( + params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192, + "Unsupported head size for decode attention: ", + params.d); + TORCH_CHECK( + params.page_size == 32 || params.page_size == 64 || params.page_size == 128, + "Unsupported page size for decode attention: ", + params.page_size); + + auto fn = get_launch_fn(qg_sz, params.d, params.page_size); + TORCH_CHECK(fn != nullptr, "No FMHA decode kernel for qg=", qg_sz, " hd=", params.d, " ps=", params.page_size); + fn(use_sink, params); + + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +} // namespace decode + std::vector mha_fwd( at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp new file mode 100644 index 00000000..113b287b --- /dev/null +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace decode { + +struct Arguments; + +// Declarations for generated FMHA decode kernel launch functions. +// Each function is defined in a separate generated .cpp file from +// xe_fmha_fwd_decode_kernel.cpp.in, compiled as its own library. 
+// +// Naming: launch_fmha_decode___ +// Parameters: +// QG_SZ in {1, 2, 4, 8, 16, 32} +// HEAD_DIM in {64, 96, 128, 192} +// PAGE_SIZE in {32, 64, 128} (with NUM_SG = PAGE_SIZE / 16) + +#define DECLARE_LAUNCH_FMHA_DECODE(QG, HD, PS) \ + void launch_fmha_decode_##QG##_##HD##_##PS(bool use_sink, const Arguments& params); + +#define DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ + DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 32) \ + DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 64) \ + DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 128) + +#define DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(1, HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) \ + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(32, HD) + +DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(64) +DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(96) +DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(128) +DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(192) + +#undef DECLARE_LAUNCH_FMHA_DECODE +#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES +#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_QG + +} // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index bd9a3f57..4c9cae96 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -441,329 +441,4 @@ struct FMHAConfig { } } }; -std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, - // h_k, d) if there is page_table. - const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, - // page_size, h_k, dv) if there is page_table. - std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - const at::Tensor& cu_seqlens_q, // b+1 - const at::Tensor& cu_seqlens_k, // b+1 - int max_seqlen_q, - int max_seqlen_k, - std::optional& page_table, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) - const float softmax_scale_, - std::optional& sinks_, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional& scheduler_metadata_, // (b + 1) - int num_splits, - std::optional pack_gqa_, - int const sm_margin) { - auto q_type = q.scalar_type(); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "mha_fwd only supports Half and BFloat16, got", - q_type); - - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); - CHECK_INPUT(q); - CHECK_INPUT(k); - CHECK_INPUT(v); - TORCH_CHECK( - q.stride(-1) == 1 && k.stride(-1) == 1 && v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); - - TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); - CHECK_INPUT(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - - CHECK_INPUT(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - - auto const sizes = q.sizes(); - const int batch_size = cu_seqlens_q.size(0) - 1; - int seqlen_q = max_seqlen_q; - int total_q = q.size(0); - int num_heads = q.size(-2); - int const head_size = q.size(-1); - int const head_size_v = v.size(-1); - int const max_num_pages_per_seq = page_table.value().size(1); - int const num_pages = k.size(0); - int const page_size = k.size(1); - int const seqlen_k = max_num_pages_per_seq * page_size; - int const total_k = num_pages * page_size; - int const num_heads_k = k.size(-2); - int q_group_size = num_heads / num_heads_k; - - int const batch_size_k = page_table.value().size(0); - float softmax_scale = softmax_scale_; - - if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); - } - - // Currently only support head dims <= 256 - static constexpr int max_headdim = 256; - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); - - // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - // TODO: check this - - if (window_size_left >= seqlen_k - 1) { - window_size_left = -1; - } - window_size_right = min(window_size_right, seqlen_q); - // causal=true is the same as causal=false in this case - if (is_causal) { - window_size_right = 0; - } - - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); - - if (leftpad_k_.has_value()) { - auto leftpad_k = 
leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_INPUT(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - - static constexpr int alignment = 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); - - auto opts = q.options(); - at::Tensor out; - out = torch::empty({total_q, num_heads, head_size_v}, opts); - - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - c10::DeviceGuard device_guard(q.device()); - - at::Tensor softmax_lse; - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - - // align with FA3 - Arguments params; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.v_dim_stride = v.stride(-1); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - params.cu_seqlens_q = cu_seqlens_q.data_ptr(); - params.cu_seqlens_k = cu_seqlens_k.data_ptr(); - - // Softmax sum - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - // Set the dimensions. - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.q_group_size = num_heads / num_heads_k; - params.seqlen_q = seqlen_q * q_group_size; - params.seqlen_k = seqlen_k; - params.d = head_size; - params.d_rounded = head_size_rounded; - - // Set the different scale values. - params.softmax_scale = softmax_scale; - bool use_sink = sinks_.has_value(); - params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; - - params.softcap = softcap; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f; - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
- params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - - // TODO: check this - if (window_size_left < 0) { - window_size_left = seqlen_k - 1; - } - if (window_size_right < 0) { - window_size_right = seqlen_q - 1; - } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - params.total_q = total_q; - params.total_k = total_k; - params.b_k = batch_size_k; - params.dv = head_size_v; - params.page_table = page_table.value().data_ptr(); - params.page_table_batch_stride = page_table.value().stride(0); - params.max_num_pages_per_seq = max_num_pages_per_seq; - params.page_size = page_size; - params.num_pages = num_pages; - - if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(false, "q_v is not supported yet"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); - CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); - params.qv_ptr = q_v.data_ptr(); - // All stride are in elements, not bytes. - params.qv_row_stride = q_v.stride(-3); - params.qv_head_stride = q_v.stride(-2); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_INPUT(rotary_cos); - params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); - const int seqlen_ro = rotary_cos.size(0); - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); - CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); - auto rotary_sin = rotary_sin_.value(); - CHECK_INPUT(rotary_sin); - CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - params.rotary_cos_ptr = rotary_cos.data_ptr(); - params.rotary_sin_ptr = rotary_sin.data_ptr(); - params.is_rotary_interleaved = is_rotary_interleaved; - if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); - CHECK_INPUT(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); - CHECK_SHAPE(seqlens_rotary, batch_size); - params.seqlens_rotary = seqlens_rotary.data_ptr(); - } - } else { - params.rotary_dim = 0; - } - - if (kv_batch_idx_.has_value()) { - auto kv_batch_idx = kv_batch_idx_.value(); - CHECK_INPUT(kv_batch_idx); - TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); - params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); - } - - params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); - - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - - constexpr bool Causal = false; // The decode kernel does not support causal mode. It must be set to false. 
- - auto launch_kernel = [&](auto _QG_SZ, auto _HEAD_DIM, auto _PAGE_SIZE, auto _NUM_SG) { - using TileShapeQK = cute::Shape; - using TileShapePV = cute::Shape; - using TileShapeOutput = cute::Shape; - using SubgroupLayoutQK = cute::Layout>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - FMHAConfig::run(params); - }); - }); - }; - - auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { - switch (params.page_size) { - case 32: - launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); - break; - case 64: - launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); - break; - case 128: - launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{}); - break; - default: - TORCH_CHECK(false, "Unsupported page size for decode attention: ", params.page_size); - } - }; - - auto dispatch_q_group = [&](auto _HEAD_DIM) { - switch (nextPowerOf2(max_seqlen_q)) { - case 1: - dispatch_page_size(_1{}, _HEAD_DIM); - break; - case 2: - dispatch_page_size(_2{}, _HEAD_DIM); - break; - case 4: - dispatch_page_size(_4{}, _HEAD_DIM); - break; - case 8: - dispatch_page_size(_8{}, _HEAD_DIM); - break; - case 16: - dispatch_page_size(_16{}, _HEAD_DIM); - break; - case 32: - dispatch_page_size(_32{}, _HEAD_DIM); - break; - default: - TORCH_CHECK(false, "Unsupported qgroup_size for decode attention: ", max_seqlen_q); - } - }; - - switch (params.d) { - case 64: - dispatch_q_group(_64{}); - break; - case 96: - dispatch_q_group(_96{}); - break; - case 128: - dispatch_q_group(_128{}); - break; - case 192: - dispatch_q_group(_192{}); - break; - default: - TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); - } - return {out, softmax_lse, out_accum, softmax_lse_accum}; -} } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in new file mode 100644 index 00000000..c2ac6ba2 --- /dev/null +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// Auto-generated from xe_fmha_fwd_decode_kernel.cpp.in +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@, NUM_SG=@NUM_SG@ +#define SYCL_INTEL_TARGET 20 + +#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" + +namespace decode { + +void launch_fmha_decode_@QG_SZ@_@HEAD_DIM@_@PAGE_SIZE@(bool use_sink, const Arguments& params) { + using namespace cute; + + constexpr bool Causal = false; + using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; + using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + FMHAConfig::run(params); + }); + }); +} + +} // namespace decode From b6f0d63dd7990c9a8f11d55eb98e532f221fba54 Mon Sep 17 00:00:00 2001 From: Yang Wang Date: Thu, 19 Mar 2026 19:15:34 -0700 Subject: [PATCH 02/23] Fix noncontiguous input for rmsnorm (#117) * fix norm with noncontiguous input * remove comment out test * support in kernel --- src/sycl/Norm.h | 17 ++++++-- src/sycl/RMSNorm.cpp | 98 +++++++++++++++++++++++++++++++------------- tests/test_norm.py | 41 ++++++++++++++++++ 3 files changed, 124 insertions(+), 32 deletions(-) diff --git a/src/sycl/Norm.h b/src/sycl/Norm.h index b8e39a15..e432f5db 100644 --- a/src/sycl/Norm.h +++ b/src/sycl/Norm.h @@ -13,7 +13,7 @@ constexpr int NUM_REDUCE_STAGES = 16; #define DECLARE_SYCL_GLOBAL_FENCE sycl::access::fence_space::global_space #define DECLARE_SYCL_GLOBAL_AND_LOCAL_FENCE dpcpp_global_and_local_fence = sycl::access::fence_space::global_and_local -inline std::pair _check_layer_norm_inputs( +inline std::tuple _check_layer_norm_inputs( const torch::Tensor& input, IntArrayRef normalized_shape, std::optional& weight /* optional */, @@ -34,8 +34,9 @@ inline std::pair _check_layer_norm_inputs( unsigned int batch_size = input.size(0); unsigned int hidden_size = input.size(1); + unsigned int batch_stride = input.stride(0); - return std::make_pair(batch_size, hidden_size); + return std::make_tuple(batch_size, hidden_size, batch_stride); } template @@ -98,8 +99,14 @@ static inline void norm_group_reduce( class NormConfig { public: - NormConfig(int Batch, int Plane, int problem_dim, int element_size_bytes) - : Batch(Batch), Plane(Plane), problem_dim(problem_dim), element_size_bytes(element_size_bytes) { + NormConfig( + int Batch, int Plane, int problem_dim, int element_size_bytes, int input_batch_stride, int output_batch_stride) + : Batch(Batch), + Plane(Plane), + problem_dim(problem_dim), + element_size_bytes(element_size_bytes), + input_batch_stride(input_batch_stride), + output_batch_stride(output_batch_stride) { semaphores_ptr = nullptr; scratchpad_ptr = nullptr; sub_group_num_global = 1; @@ -126,6 +133,8 @@ class 
NormConfig { int workgroup_size; int sub_group_num; + int input_batch_stride; + int output_batch_stride; int* semaphores_ptr; void* scratchpad_ptr; int sub_group_num_global; diff --git a/src/sycl/RMSNorm.cpp b/src/sycl/RMSNorm.cpp index 4b0e06ee..32b5ce5f 100644 --- a/src/sycl/RMSNorm.cpp +++ b/src/sycl/RMSNorm.cpp @@ -36,7 +36,7 @@ class RMSNormForward : public NormForward { auto group_id = item_id.get_group(0); auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t group_offset = group_id * cfg.input_batch_stride; for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; @@ -63,7 +63,8 @@ class RMSNormForward : public NormForward { auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t x_group_offset = group_id * cfg.input_batch_stride; + index_t y_group_offset = group_id * cfg.output_batch_stride; if (cfg.workgroup_num_foreach == 1) { if (local_id == 0) { reduce_project(item_id, sum_value, sum_tmp, cfg); @@ -75,14 +76,14 @@ class RMSNormForward : public NormForward { for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; if (plane_offset < cfg.Plane) { - vec_t X_val = *(reinterpret_cast(NF::X_data + group_offset + plane_offset)); + vec_t X_val = *(reinterpret_cast(NF::X_data + x_group_offset + plane_offset)); vec_t Y_val; weight_vec_t gamma_val = *(reinterpret_cast(NF::gamma_data + plane_offset)); for (int v = 0; v < vec_size; ++v) { Y_val[v] = static_cast(gamma_val[v] * var_val * X_val[v]); } - *(reinterpret_cast(NF::Y_data + group_offset + plane_offset)) = Y_val; + *(reinterpret_cast(NF::Y_data + y_group_offset + plane_offset)) = Y_val; } } } @@ -113,7 +114,7 @@ class AddRMSNormForward : public RMSNormForward { auto group_id = item_id.get_group(0); auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t group_offset = group_id * cfg.input_batch_stride; for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; @@ -148,7 +149,8 @@ class GemmaRMSNormForward : public RMSNormForward { auto group_id_foreach = item_id.get_group(1); auto local_id = item_id.get_local_id(2); - index_t group_offset = group_id * cfg.Plane; + index_t x_group_offset = group_id * cfg.input_batch_stride; + index_t y_group_offset = group_id * cfg.output_batch_stride; if (cfg.workgroup_num_foreach == 1) { if (local_id == 0) { RNF::reduce_project(item_id, sum_value, sum_tmp, cfg); @@ -160,14 +162,14 @@ class GemmaRMSNormForward : public RMSNormForward { for (index_t j = local_id * vec_size; j < cfg.WGPlane; j += cfg.workgroup_size * vec_size) { index_t plane_offset = group_id_foreach * cfg.WGPlane + j; if (plane_offset < cfg.Plane) { - vec_t X_val = *(reinterpret_cast(NF::X_data + group_offset + plane_offset)); + vec_t X_val = *(reinterpret_cast(NF::X_data + x_group_offset + plane_offset)); vec_t Y_val; weight_vec_t gamma_val = *(reinterpret_cast(NF::gamma_data + plane_offset)); for (int v = 0; v < vec_size; ++v) { Y_val[v] = static_cast((accscalar_t(1.0) + gamma_val[v]) * var_val * X_val[v]); } - *(reinterpret_cast(NF::Y_data + group_offset + plane_offset)) = Y_val; + 
*(reinterpret_cast(NF::Y_data + y_group_offset + plane_offset)) = Y_val; } } } @@ -310,13 +312,21 @@ void launch_vectorized_fused_norm_kernel(Norm& norm, template void RMSNormKernelImplInternal( - const Tensor& X, const Tensor& gemma, int64_t M, int64_t N, acc_type eps, Tensor& Y, Tensor& rstd) { + const Tensor& X, + const Tensor& gemma, + int64_t M, + int64_t N, + acc_type eps, + Tensor& Y, + Tensor& rstd, + int64_t input_batch_stride, + int64_t output_batch_stride) { scalar_t* X_data = X.data_ptr(); scalar_t* Y_data = Y.data_ptr(); mean_t* var_data = rstd.data_ptr(); weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), input_batch_stride, output_batch_stride); RMSNormForward rms_norm_forward(X_data, Y_data, var_data, gemma_data, eps, M, N); config.workgroup_num_foreach = 1; config.WGPlane = config.Plane; @@ -338,7 +348,7 @@ void FusedAddRMSNormKernelImplInternal( weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; scalar_t* residual_data = residual.data_ptr(); - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), N, N); AddRMSNormForward add_rms_norm_forward( X_data, X_data, var_data, gemma_data, eps, residual_data, M, N); config.workgroup_num_foreach = 1; @@ -349,13 +359,21 @@ void FusedAddRMSNormKernelImplInternal( template void GemmaRMSNormKernelImplInternal( - const Tensor& X, const Tensor& gemma, int64_t M, int64_t N, acc_type eps, Tensor& Y, Tensor& rstd) { + const Tensor& X, + const Tensor& gemma, + int64_t M, + int64_t N, + acc_type eps, + Tensor& Y, + Tensor& rstd, + int64_t input_batch_stride, + int64_t output_batch_stride) { scalar_t* X_data = X.data_ptr(); scalar_t* Y_data = Y.data_ptr(); mean_t* var_data = rstd.data_ptr(); weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), input_batch_stride, output_batch_stride); GemmaRMSNormForward gemma_rms_norm_forward(X_data, Y_data, var_data, gemma_data, eps, M, N); config.workgroup_num_foreach = 1; config.WGPlane = config.Plane; @@ -377,7 +395,7 @@ void GemmaFusedAddRMSNormKernelImplInternal( weight_t* gemma_data = gemma.defined() ? gemma.data_ptr() : nullptr; scalar_t* residual_data = residual.data_ptr(); - auto config = NormConfig(M, N, 1, sizeof(scalar_t)); + auto config = NormConfig(M, N, 1, sizeof(scalar_t), N, N); GemmaAddRMSNormForward gemma_add_rms_norm_forward( X_data, X_data, var_data, gemma_data, eps, residual_data, M, N); config.workgroup_num_foreach = 1; @@ -390,28 +408,40 @@ void GemmaFusedAddRMSNormKernelImplInternal( void rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); + auto input_batch_stride = std::get<2>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output; Tensor weight_ = (weight.dim() == 1) ? 
weight.reshape({N}) : weight; Tensor rstd = at::empty({M}, input_.options().dtype(kFloat)); + int64_t output_batch_stride = (output.dim() >= 2) ? output.stride(0) : N; SYCL_DISPATCH_FLOATING_TYPES( at::ScalarType::Half, at::ScalarType::BFloat16, input_.scalar_type(), "RMSNormKernelImpl", [&]() { RMSNormKernelImplInternal( - input_, weight_, M, N, static_cast>(eps), output_, rstd); + input_, + weight_, + M, + N, + static_cast>(eps), + output_, + rstd, + input_batch_stride, + output_batch_stride); }); } void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) { + TORCH_CHECK(input.is_contiguous(), "fused_add_rmsnorm: input must be contiguous"); + TORCH_CHECK(residual.is_contiguous(), "fused_add_rmsnorm: residual must be contiguous"); std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); Tensor rstd = at::empty({M}, input.options().dtype(kFloat)); @@ -425,28 +455,40 @@ void fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tenso void gemma_rmsnorm(torch::Tensor& output, torch::Tensor& input, torch::Tensor& weight, double eps) { std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); + auto input_batch_stride = std::get<2>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor output_ = (output.dim() == 1) ? output.reshape({M, N}) : output; Tensor weight_ = (weight.dim() == 1) ? weight.reshape({N}) : weight; Tensor rstd = at::empty({M}, input_.options().dtype(kFloat)); + int64_t output_batch_stride = (output.dim() >= 2) ? output.stride(0) : N; SYCL_DISPATCH_FLOATING_TYPES( at::ScalarType::Half, at::ScalarType::BFloat16, input_.scalar_type(), "GemmaRMSNormKernelImpl", [&]() { GemmaRMSNormKernelImplInternal( - input_, weight_, M, N, static_cast>(eps), output_, rstd); + input_, + weight_, + M, + N, + static_cast>(eps), + output_, + rstd, + input_batch_stride, + output_batch_stride); }); } void gemma_fused_add_rmsnorm(torch::Tensor& input, torch::Tensor& residual, torch::Tensor& weight, double eps) { + TORCH_CHECK(input.is_contiguous(), "gemma_fused_add_rmsnorm: input must be contiguous"); + TORCH_CHECK(residual.is_contiguous(), "gemma_fused_add_rmsnorm: residual must be contiguous"); std::optional opt_weight = weight; std::optional opt_bias; - auto M_N = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); - auto M = M_N.first; - auto N = M_N.second; + auto M_N_S = _check_layer_norm_inputs(input, c10::IntArrayRef({input.size(-1)}), opt_weight, opt_bias); + auto M = std::get<0>(M_N_S); + auto N = std::get<1>(M_N_S); Tensor input_ = (input.dim() == 1) ? input.reshape({M, N}) : input; Tensor residual_ = (residual.dim() == 1) ? 
residual.reshape({M, N}) : residual; diff --git a/tests/test_norm.py b/tests/test_norm.py index 6b0623db..a49ed003 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -134,5 +134,46 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) +############################################################################### +# Non-contiguous input tests (DeepSeek split pattern: stride[0] != hidden_size) +############################################################################### + + +def _make_non_contiguous(batch_size, hidden_size, dtype, extra=64): + """Create a non-contiguous tensor by slicing a larger tensor, + mimicking latent_cache.split([hidden_size, extra], dim=-1)[0].""" + full = torch.randn(batch_size, hidden_size + extra, device=device, dtype=dtype) + view = full[:, :hidden_size] # stride = (hidden_size + extra, 1) + # assert not view.is_contiguous() + assert view.stride(0) == hidden_size + extra + return view + + +@pytest.mark.parametrize("batch_size", [1, 19, 99]) +@pytest.mark.parametrize("hidden_size", [512, 1024, 3072]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_norm_non_contiguous(batch_size, hidden_size, dtype): + x_nc = _make_non_contiguous(batch_size, hidden_size, dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) + + y_ref = llama_rms_norm(x_nc.clone(), w) + y = sgl_kernel.rmsnorm(x_nc, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99]) +@pytest.mark.parametrize("hidden_size", [512, 1024, 3072]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gemma_norm_non_contiguous(batch_size, hidden_size, dtype): + x_nc = _make_non_contiguous(batch_size, hidden_size, dtype) + w = torch.randn(hidden_size, device=device, dtype=dtype) + + y_ref = gemma_rms_norm(x_nc.clone(), w) + y = sgl_kernel.gemma_rmsnorm(x_nc, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": sys.exit(pytest.main([__file__])) From 3796408db93a1b05f15a8e6a80e0336a4c18179a Mon Sep 17 00:00:00 2001 From: Suryaprakash Shanmugam Date: Mon, 23 Mar 2026 11:41:52 +0530 Subject: [PATCH 03/23] Add MXFP4 Per Token Group Quant kernel and tests (#106) * Add MXFP4 Per Token Group Quant kernel and tests Remove commented out fp8 blockwise group gemm registration * Add benchmarking for per token group quant mxfp4 * Add test to run_suite.py * Fix group size constraint for mxfp4; Add benchmark test to CI flow * Remove reference provider from the benchmark script - Add check for quantized and scale values separately - Include eps value in ref quant function call * Fix MXFP4 quantization to match OCP MX spec - Replace ceil(log2(max/6.0)) scale computation with floor(log2(max)) - E2M1_EMAX per OCP MX spec - Fix roundTiesToEven at midpoints in SYCL kernel (change <= to < at odd-mantissa boundaries) - Replace naive argmin-based quantize_to_e2m1 reference with microxcaling _quantize_elemwise_core algorithm - Normalize signed zeros (+0.0 vs -0.0) before packed byte comparison in tests and benchmark * Fix lint issues * Remove unsupported group sizes * Apply formatting check * Add TODO for quantize_to_e2m1 * Update tests/test_per_token_group_quant_mxfp4.py Co-authored-by: Meng, Hengyu * trigger CI --------- Co-authored-by: Meng, Hengyu --- .github/workflows/pr-test-xpu.yml | 2 + .../bench_per_token_group_quant_mxfp4.py | 547 ++++++++++++++++++ include/sgl_kernel_ops.h | 2 + 
python/sgl_kernel/__init__.py | 1 + python/sgl_kernel/gemm.py | 61 ++ src/sycl/per_token_group_quant_fp4.cpp | 334 +++++++++++ src/torch_extension_sycl.cc | 4 + tests/run_suite.py | 1 + tests/test_per_token_group_quant_mxfp4.py | 546 +++++++++++++++++ 9 files changed, 1498 insertions(+) create mode 100644 benchmark/bench_per_token_group_quant_mxfp4.py create mode 100644 src/sycl/per_token_group_quant_fp4.cpp create mode 100644 tests/test_per_token_group_quant_mxfp4.py diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml index 015766d2..c777a1ec 100644 --- a/.github/workflows/pr-test-xpu.yml +++ b/.github/workflows/pr-test-xpu.yml @@ -67,6 +67,8 @@ jobs: python3 bench_merge_states_v2.py 2>&1 | tee merge_states.py.log \ python3 bench_swiglu_alpha_limit.py 2>&1 | tee swiglu_alpha_limit.py.log \ python3 bench_fused_qk_norm_rope.py 2>&1 | tee fused_qk_norm_rope.py.log \ + python3 bench_per_token_group_quant_8bit.py 2>&1 | tee per_token_group_quant_8bit.py.log \ + python3 bench_per_token_group_quant_mxfp4.py 2>&1 | tee per_token_group_quant_mxfp4.py.log \ " - name: Copy logs from container diff --git a/benchmark/bench_per_token_group_quant_mxfp4.py b/benchmark/bench_per_token_group_quant_mxfp4.py new file mode 100644 index 00000000..ae1e4671 --- /dev/null +++ b/benchmark/bench_per_token_group_quant_mxfp4.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark script for MXFP4 (E2M1) per-token group quantization on Intel XPU. + +Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses +roundTiesToEven — at midpoints between representable values, the +value with even mantissa (mantissa bit = 0) is chosen. +""" + +import itertools +import os + +import pandas as pd +import torch +import triton + +MXFP4_BLOCK_SIZE = 32 +FLOAT4_E2M1_MAX = 6.0 + +# E2M1 format parameters (from Microsoft microxcaling formats.py) +E2M1_EBITS = 2 +E2M1_MBITS = 3 # includes sign bit and implicit one +E2M1_EMAX = 2 ** (E2M1_EBITS - 1) # = 2 +E2M1_MAX_NORM = ( + 2**E2M1_EMAX * float(2 ** (E2M1_MBITS - 1) - 1) / 2 ** (E2M1_MBITS - 2) +) # = 6.0 + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) # 2^(-126) + +IS_CI = ( + os.getenv("CI", "false").lower() == "true" + or os.getenv("GITHUB_ACTIONS", "false").lower() == "true" +) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + + +def is_xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _round_mantissa_even(A: torch.Tensor) -> torch.Tensor: + """Round mantissa using roundTiesToEven (from Microsoft microxcaling). + + At exact 0.5 midpoints (i.e., values like 0.5, 2.5, 4.5, ...), + round to the nearest even integer (the one whose LSB is 0). + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + absA = torch.abs(A) + # Identify exact midpoints: 0.5, 2.5, 4.5, ... i.e. (absA - 0.5) % 2 == 0 + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + # round half up, then subtract 1 at midpoints to get even + return torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + + +def _quantize_elemwise_core_e2m1( + A: torch.Tensor, saturate_normals: bool = True +) -> torch.Tensor: + """Element-wise quantization to E2M1 using Microsoft microxcaling's + _quantize_elemwise_core algorithm with round='even'. + + E2M1 format: ebits=2, mbits=3, emax=2, max_norm=6.0 + min_exp = -(2^(ebits-1)) + 2 = 0 + + Algorithm (from Microsoft microxcaling elemwise_ops.py): + 1. 
Compute per-element private exponent = floor(log2(|A|)), + clamped to min_exp. + 2. Left-shift: out = A / 2^private_exp * 2^(mbits-2) + 3. Round mantissa with roundTiesToEven + 4. Right-shift: out = out / 2^(mbits-2) * 2^private_exp + 5. Clamp to [-max_norm, max_norm] if saturate_normals + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + ebits = E2M1_EBITS # 2 + mbits = E2M1_MBITS # 3 + max_norm = E2M1_MAX_NORM # 6.0 + + # min representable exponent: -(2^(ebits-1)) + 2 = 0 + min_exp = -(2 ** (ebits - 1)) + 2 # 0 + + out = A.clone() + + # Per-element private exponent: floor(log2(|A|)) + # Add guard for zeros: log2(0) is -inf, we use (A==0) to avoid that + private_exp = torch.floor(torch.log2(torch.abs(A) + (A == 0).type(A.dtype))) + private_exp = private_exp.clip(min=min_exp) + + # Left-shift: scale up so mantissa bits land in integer portion + # out = A / 2^private_exp * 2^(mbits-2) + shift = mbits - 2 # = 1 + out = out / (2**private_exp) * (2**shift) + + # Round mantissa with roundTiesToEven + out = _round_mantissa_even(out) + + # Right-shift: undo scaling + # out = out / 2^(mbits-2) * 2^private_exp + out = out / (2**shift) * (2**private_exp) + + # Saturate to [-max_norm, max_norm] + if saturate_normals: + out = torch.clamp(out, min=-max_norm, max=max_norm) + + return out + + +def _float_to_e2m1_code(val: torch.Tensor) -> torch.Tensor: + """Convert quantized float values back to E2M1 4-bit codes. + + After _quantize_elemwise_core_e2m1, values are one of the 8 representable + E2M1 magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}. + This maps them to 4-bit codes (sign in bit 3, magnitude in bits 0-2). + """ + sign = (val < 0).to(torch.uint8) + abs_val = val.abs() + + # Map representable magnitudes to 3-bit indices via the kE2M1ToFloat LUT. + indices = torch.zeros_like(abs_val, dtype=torch.uint8) + lut = kE2M1ToFloat.to(device=val.device) + for i in range(8): + indices = torch.where( + torch.isclose(abs_val, lut[i].expand_as(abs_val), atol=1e-6, rtol=0), + torch.tensor(i, dtype=torch.uint8, device=val.device), + indices, + ) + + return (sign << 3) | indices + + +def quantize_to_e2m1(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor values to E2M1 format (4-bit indices). + + Uses the Microsoft microxcaling _quantize_elemwise_core algorithm + with roundTiesToEven, then maps the resulting float values to 4-bit codes. + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + quantized_float = _quantize_elemwise_core_e2m1( + tensor.float(), saturate_normals=True + ) + return _float_to_e2m1_code(quantized_float) + + +def pack_fp4(tensor: torch.Tensor) -> torch.Tensor: + """Pack two 4-bit values into one uint8.""" + assert tensor.shape[-1] % 2 == 0 + shape = tensor.shape[:-1] + (tensor.shape[-1] // 2, 2) + paired = tensor.reshape(shape) + packed = (paired[..., 0] & 0x0F) | ((paired[..., 1] & 0x0F) << 4) + return packed.to(torch.uint8) + + +def _normalize_packed_fp4_signed_zero(packed: torch.Tensor) -> torch.Tensor: + """Canonicalize signed zeros in packed FP4 bytes. + + In E2M1, code 0x0 is +0.0 and code 0x8 is -0.0. Both represent + the same value, but different implementations may emit either form. + This helper rewrites every -0.0 nibble (0x8) to +0.0 (0x0) so that + byte-level comparisons are not tripped up by this harmless difference. 
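+
+    For example, the packed byte 0x80 (low nibble 0x0 = +0.0, high nibble
+    0x8 = -0.0) is rewritten to 0x00; nibbles with a non-zero magnitude are
+    left untouched.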
+ """ + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + lo = torch.where(lo == 0x08, torch.zeros_like(lo), lo) + hi = torch.where(hi == 0x08, torch.zeros_like(hi), hi) + return (lo | (hi << 4)).to(torch.uint8) + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + """Unpack uint8 into two 4-bit values.""" + low = packed & 0x0F + high = (packed >> 4) & 0x0F + return torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], -1) + + +def dequantize_e2m1( + quantized: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + """Dequantize E2M1 values back to float.""" + sign = ((quantized >> 3) & 1).to(torch.bool) + magnitude_idx = (quantized & 0x07).to(torch.long) + kE2M1 = kE2M1ToFloat.to(device=quantized.device) + magnitude = kE2M1[magnitude_idx] + result = torch.where(sign, -magnitude, magnitude) + return result.to(dtype) + + +def _shared_exponents(A: torch.Tensor, axis: int) -> torch.Tensor: + """Compute shared exponents per block using Microsoft microxcaling's + _shared_exponents algorithm with method="max". + + Algorithm: + 1. shared_exp = max(|A|) along axis (per block) + 2. shared_exp = floor(log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0))) + The FP32_MIN_NORMAL guard ensures log2(0) doesn't produce -inf. + 3. Offset by emax: shared_exp = shared_exp - emax + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + shared_exp = torch.max(torch.abs(A), dim=axis, keepdim=True).values + + # floor(log2(...)) with zero-guard from microxcaling + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Offset by the largest representable exponent in E2M1 + shared_exp = shared_exp - E2M1_EMAX + + return shared_exp + + +def quantize_to_mxfp4_ref( + tensor: torch.Tensor, block_size: int = MXFP4_BLOCK_SIZE, eps: float = 1e-10 +) -> tuple: + """Reference implementation for MXFP4 quantization using Microsoft + microxcaling's _quantize_mx algorithm. + + Algorithm (from mx_ops.py _quantize_mx): + 1. Reshape into blocks + 2. Compute shared exponent per block via _shared_exponents + 3. Clamp shared_exp to scale_emax range [-127, 127] + 4. Scale elements: A = A / 2^shared_exp + 5. Quantize element-wise with _quantize_elemwise_core (saturate_normals=True) + 6. 
Rescale: A = A * 2^shared_exp (implicitly stored in UE8M0 scale) + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + assert tensor.dim() == 2 + m, k = tensor.shape + assert k % block_size == 0 + assert k % 2 == 0 + + tensor_fp32 = tensor.float() + num_blocks = k // block_size + tensor_blocks = tensor_fp32.reshape(m, num_blocks, block_size) + + # Compute shared exponents (microxcaling _shared_exponents + offset by emax) + shared_exp = _shared_exponents(tensor_blocks, axis=-1) + + # Clamp to UE8M0 scale range: scale_bits=8, scale_emax = 2^(8-1)-1 = 127 + scale_emax = 127 + shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax) + + # Encode as UE8M0: stored_scale = shared_exp + 127 + scales_ue8m0 = (shared_exp.to(torch.int32) + 127).to(torch.uint8).squeeze(-1) + + # Scale elements by shared exponent: A = A / 2^shared_exp + scaled_tensor = tensor_blocks / (2.0**shared_exp) + + # Quantize element-wise with microxcaling core (roundTiesToEven, saturate) + quantized_float = _quantize_elemwise_core_e2m1(scaled_tensor, saturate_normals=True) + + # Convert quantized float values to 4-bit E2M1 codes + quantized_blocks = _float_to_e2m1_code(quantized_float) + + quantized = quantized_blocks.reshape(m, k) + packed = pack_fp4(quantized) + + return packed, scales_ue8m0 + + +def dequantize_mxfp4( + packed: torch.Tensor, + scales: torch.Tensor, + dtype: torch.dtype = torch.float32, + block_size: int = MXFP4_BLOCK_SIZE, +) -> torch.Tensor: + """Dequantize MXFP4 packed values back to float.""" + m, packed_k = packed.shape + k = packed_k * 2 + + unpacked = unpack_fp4(packed) + dequantized = dequantize_e2m1(unpacked, dtype) + + num_blocks = k // block_size + dequantized_blocks = dequantized.reshape(m, num_blocks, block_size) + + scale_exp = scales.to(torch.int32) - 127 + scale_values = torch.pow(2.0, scale_exp.float()).unsqueeze(-1) + scaled = dequantized_blocks * scale_values + + return scaled.reshape(m, k).to(dtype) + + +def reference_per_token_group_quant_mxfp4( + x: torch.Tensor, group_size: int, eps: float = 1e-10 +) -> tuple: + """Reference implementation using PyTorch operations.""" + assert x.shape[-1] % group_size == 0 + assert x.is_contiguous() + + x_cpu = x.cpu().float() + x_q, x_s = quantize_to_mxfp4_ref(x_cpu, group_size, eps) + return x_q.to(x.device), x_s.to(x.device) + + +def sglang_per_token_group_quant_mxfp4( + x: torch.Tensor, group_size: int, eps: float = 1e-10 +) -> tuple: + """SGL kernel wrapper for MXFP4 quantization.""" + from sgl_kernel import sgl_per_token_group_quant_fp4 + + assert x.shape[-1] % group_size == 0 + assert x.is_contiguous() + + x_q, x_s = sgl_per_token_group_quant_fp4(x=x, group_size=group_size, eps=eps) + return x_q, x_s + + +def calculate_diff( + batch_size: int, + seq_len: int, + hidden_dim: int, + group_size: int, + src_dtype: torch.dtype, + eps: float = 1e-10, +): + """Verify correctness by comparing reference and kernel implementations.""" + device = torch.device("xpu") + + x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=src_dtype) + + x_q_ref, x_s_ref = reference_per_token_group_quant_mxfp4(x.clone(), group_size, eps) + x_q_sgl, x_s_sgl = sglang_per_token_group_quant_mxfp4(x.clone(), group_size, eps) + + # Compare quantized outputs directly (packed uint8 and scales). + # Normalise signed zeros first: in E2M1 code 0x0 (+0.0) and 0x8 + # (-0.0) are semantically identical. 
The kernel may preserve the + # sign of the original float while the reference always emits +0.0, + # so we canonicalise before comparing. + x_q_ref_norm = _normalize_packed_fp4_signed_zero(x_q_ref.cpu()) + x_q_sgl_norm = _normalize_packed_fp4_signed_zero(x_q_sgl.cpu()) + q_match = torch.equal(x_q_ref_norm, x_q_sgl_norm) + s_match = torch.equal(x_s_ref.cpu(), x_s_sgl.cpu()) + + if q_match and s_match: + print( + f" \u2705 Quantized values match (batch={batch_size}, seq={seq_len}, hidden={hidden_dim}, group={group_size}, dtype={src_dtype})" + ) + else: + q_mismatches = (x_q_ref_norm != x_q_sgl_norm).sum().item() if not q_match else 0 + s_mismatches = ( + (x_s_ref.cpu() != x_s_sgl.cpu()).sum().item() if not s_match else 0 + ) + print( + f" \u274c Quantized values differ: " + f"packed_q({q_mismatches} mismatches) " + f"scales({s_mismatches} mismatches)" + ) + + # Compare dequantized outputs + x_dq_ref = dequantize_mxfp4(x_q_ref.cpu(), x_s_ref.cpu(), torch.float32, group_size) + x_dq_sgl = dequantize_mxfp4(x_q_sgl.cpu(), x_s_sgl.cpu(), torch.float32, group_size) + + if torch.allclose(x_dq_ref, x_dq_sgl, rtol=0.2, atol=0.5): + print( + f" \u2705 Dequantized values match (batch={batch_size}, seq={seq_len}, hidden={hidden_dim}, group={group_size}, dtype={src_dtype})" + ) + else: + max_diff = (x_dq_ref - x_dq_sgl).abs().max().item() + print(f" \u274c Dequantized values differ (max_diff={max_diff:.4f})") + + +def calculate_flops(num_elements: int, num_groups: int) -> int: + """Calculate FLOPs for MXFP4 per-token-group quantization.""" + flops_per_element = 5 + flops_per_group = 8 + return (num_elements * flops_per_element) + (num_groups * flops_per_group) + + +def calculate_effective_bandwidth( + batch_size: int, + seq_len: int, + hidden_dim: int, + group_size: int, + src_dtype: torch.dtype, + time_ms: float, +) -> dict: + """Calculate effective bandwidth and FLOPs for MXFP4 quantization kernel.""" + num_tokens = batch_size * seq_len + num_elements = num_tokens * hidden_dim + num_groups = num_elements // group_size + + dtype_size = 2 if src_dtype in (torch.float16, torch.bfloat16) else 4 + input_bytes = num_elements * dtype_size + output_bytes = num_elements // 2 + scale_bytes = num_groups + total_bytes = input_bytes + output_bytes + scale_bytes + + time_s = time_ms / 1000.0 + bandwidth_gbs = (total_bytes / 1e9) / time_s + + total_flops = calculate_flops(num_elements, num_groups) + gflops = (total_flops / 1e9) / time_s + + return { + "num_tokens": num_tokens, + "num_elements": num_elements, + "num_groups": num_groups, + "total_bytes": total_bytes, + "bandwidth_gbs": bandwidth_gbs, + "total_flops": total_flops, + "gflops": gflops, + } + + +batch_size_range = [1, 2, 4, 8, 16, 32, 64] if not IS_CI else [1, 4, 16] +seq_len_range = [64, 128, 256, 512, 1024, 2048] if not IS_CI else [64, 256] +# Only group_size=32 is supported for MXFP4 (per OCP MX spec block size) +group_size_range = [32] +src_dtype_range = [torch.bfloat16] + +configs = list( + itertools.product( + batch_size_range, seq_len_range, group_size_range, src_dtype_range + ) +) + +all_results = [] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len", "group_size", "src_dtype"], + x_vals=configs, + line_arg="provider", + line_vals=["sglang"], + line_names=["SGL Kernel"], + styles=[("green", "-")], + ylabel="us", + plot_name="per-token-group-quant-mxfp4-performance", + args={}, + ) +) +def benchmark(batch_size, seq_len, group_size, src_dtype, provider): + device = torch.device("xpu") + hidden_dim = 7168 
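+    # Worked example of the byte accounting done in calculate_effective_bandwidth():
+    # batch_size=1, seq_len=64, hidden_dim=7168 gives 64 * 7168 = 458,752 bf16
+    # elements, i.e. 917,504 B read, plus 458,752 / 2 = 229,376 B of packed FP4
+    # and 458,752 / 32 = 14,336 B of UE8M0 scales written, about 1.16 MB moved
+    # per launch in total.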
+ + x = torch.randn(batch_size * seq_len, hidden_dim, device=device, dtype=src_dtype) + + quantiles = [0.5, 0.2, 0.8] + + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sglang_per_token_group_quant_mxfp4(x, group_size), + quantiles=quantiles, + ) + + bw_metrics = calculate_effective_bandwidth( + batch_size, seq_len, hidden_dim, group_size, src_dtype, ms + ) + + all_results.append( + { + "batch_size": batch_size, + "seq_len": seq_len, + "num_tokens": bw_metrics["num_tokens"], + "hidden_dim": hidden_dim, + "group_size": group_size, + "src_dtype": str(src_dtype), + "provider": provider, + "time_us": 1000 * ms, + "bandwidth_gbs": bw_metrics["bandwidth_gbs"], + "total_bytes_mb": bw_metrics["total_bytes"] / 1e6, + "total_flops_m": bw_metrics["total_flops"] / 1e6, + "gflops": bw_metrics["gflops"], + } + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +def print_summary(results: list): + """Print summary statistics from benchmark results.""" + print("\n" + "=" * 100) + print("MXFP4 Per-Token Group Quantization Benchmark Results") + print("=" * 100) + + df = pd.DataFrame(results) + df["bandwidth_gbs"] = df["bandwidth_gbs"].round(2) + df["total_bytes_mb"] = df["total_bytes_mb"].round(2) + df["time_us"] = df["time_us"].round(2) + df["total_flops_m"] = df["total_flops_m"].round(2) + df["gflops"] = df["gflops"].round(2) + + print("\nDetailed Results:") + print(df.to_markdown(index=False)) + + print("\n" + "=" * 100) + print("Summary Statistics by Provider") + print("=" * 100) + summary = df.groupby("provider").agg( + { + "bandwidth_gbs": ["mean", "min", "max"], + "time_us": ["mean", "min", "max"], + "gflops": ["mean", "min", "max"], + } + ) + print(summary.to_markdown()) + + +def main(): + if not is_xpu_available(): + print("Error: Intel XPU not available") + return + + try: + from sgl_kernel import sgl_per_token_group_quant_fp4 + + assert callable(sgl_per_token_group_quant_fp4) + except ImportError: + print("Error: sgl_per_token_group_quant_fp4 kernel not available") + return + + print("Running MXFP4 Per-Token Group Quantization Benchmark") + print(" Device: Intel XPU") + print(f" MXFP4 block size: {MXFP4_BLOCK_SIZE}") + + print("\n" + "=" * 80) + print("Correctness Verification") + print("=" * 80) + calculate_diff( + batch_size=2, + seq_len=64, + hidden_dim=128, + group_size=32, + src_dtype=torch.bfloat16, + ) + calculate_diff( + batch_size=1, seq_len=32, hidden_dim=128, group_size=32, src_dtype=torch.float32 + ) + + print("\n" + "=" * 80) + print("Performance Benchmark") + print("=" * 80) + benchmark.run(print_data=True) + + print_summary(all_results) + + +if __name__ == "__main__": + main() diff --git a/include/sgl_kernel_ops.h b/include/sgl_kernel_ops.h index 6dec3842..212518fb 100644 --- a/include/sgl_kernel_ops.h +++ b/include/sgl_kernel_ops.h @@ -157,6 +157,8 @@ void fused_qk_norm_rope( double high, double attention_factor, int64_t rotary_dim); +void sgl_per_token_group_quant_fp4( + at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps); } // namespace at::native::xpu void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); diff --git a/python/sgl_kernel/__init__.py b/python/sgl_kernel/__init__.py index a243f4df..9f78df88 100755 --- a/python/sgl_kernel/__init__.py +++ b/python/sgl_kernel/__init__.py @@ -43,6 +43,7 @@ scaled_fp4_quant, sgl_per_tensor_quant_fp8, sgl_per_token_group_quant_8bit, + sgl_per_token_group_quant_fp4, sgl_per_token_group_quant_fp8, 
sgl_per_token_group_quant_int8, sgl_per_token_quant_fp8, diff --git a/python/sgl_kernel/gemm.py b/python/sgl_kernel/gemm.py index 30316544..284d9eac 100644 --- a/python/sgl_kernel/gemm.py +++ b/python/sgl_kernel/gemm.py @@ -124,6 +124,67 @@ def sgl_per_tensor_quant_fp8( ) +def sgl_per_token_group_quant_fp4( + x: torch.Tensor, + group_size: int = 32, + eps: float = 1e-10, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize input tensor to MXFP4 (E2M1) format with per-token group scaling. + + MXFP4 follows the OpenCompute MX (Microscaling) format specification: + - Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) + - Block size: 32 elements per scale factor (default) + - Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + + Args: + x: Input tensor with shape (..., K) where K is divisible by group_size. + Must be contiguous and dtype float16, bfloat16, or float32. + group_size: Number of elements per quantization group. Must be 32 for MXFP4. + eps: Small epsilon to avoid division by zero. Default is 1e-10. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - output_q: Packed FP4 tensor with shape (..., K // 2) and dtype uint8. + Two E2M1 values are packed into each byte. + - output_s: Scale tensor with shape (..., K // group_size) and dtype uint8. + Scales are stored in UE8M0 format (exponent + 127 bias). + """ + assert ( + x.shape[-1] % group_size == 0 + ), f"the last dimension of `x` ({x.shape[-1]}) must be divisible by `group_size` ({group_size})" + assert x.is_contiguous(), "`x` is not contiguous" + assert group_size == 32, f"group_size must be 32 for MXFP4, got {group_size}" + + # Ensure input is 2D for the kernel + original_shape = x.shape + if x.dim() == 1: + x = x.unsqueeze(0) + elif x.dim() > 2: + x = x.view(-1, x.shape[-1]) + + m, k = x.shape + num_groups_per_row = k // group_size + + # Output is packed FP4 (2 values per byte) + output_q = torch.empty((m, k // 2), device=x.device, dtype=torch.uint8) + + # Scales in row-major layout: (m, num_groups_per_row) + # Each row has the scales for that token's groups + output_s = torch.empty((m, num_groups_per_row), device=x.device, dtype=torch.uint8) + + if x.shape[0] > 0: + torch.ops.sgl_kernel.sgl_per_token_group_quant_fp4.default( + x, output_q, output_s, group_size, eps + ) + + # Reshape output to match input shape + output_shape_q = original_shape[:-1] + (original_shape[-1] // 2,) + output_shape_s = original_shape[:-1] + (original_shape[-1] // group_size,) + + return output_q.view(output_shape_q), output_s.view(output_shape_s) + + def sgl_per_token_quant_fp8( input: torch.Tensor, output_q: torch.Tensor, diff --git a/src/sycl/per_token_group_quant_fp4.cpp b/src/sycl/per_token_group_quant_fp4.cpp new file mode 100644 index 00000000..4c276f96 --- /dev/null +++ b/src/sycl/per_token_group_quant_fp4.cpp @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: Apache-2.0 +/* + * SYCL kernel for per-token group quantization to MXFP4 (E2M1) format. + * + * MXFP4 follows the OpenCompute MX (Microscaling) format specification: + * - Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) + * - Block size: 32 elements per scale factor + * - Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + * + * E2M1 representable values (magnitude): 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 + * With sign bit, we have 16 total values. 
+ * + * Bit layout of E2M1: + * Bit 3: Sign (0 = positive, 1 = negative) + * Bits 0-2: Magnitude index (0-7) + * + * Two FP4 values are packed into a single uint8_t: + * - Lower nibble (bits 0-3): First value + * - Upper nibble (bits 4-7): Second value + * + * Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses + * roundTiesToEven — at midpoints between representable values, the + * value with even mantissa (mantissa bit = 0) is chosen. + */ + +#include +#include +#include + +#include +#include + +#include "SYCLHelpers.h" +#include "Utils.h" + +namespace at::native::xpu { + +constexpr float FLOAT4_E2M1_MAX = 6.0f; + +template +inline T QuantGroupReduceMaxFP4(T val, sycl::nd_item<1> item) { + auto sg = item.get_sub_group(); + + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 8)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 4)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 2)); + val = sycl::fmax(val, sycl::permute_group_by_xor(sg, val, 1)); + + return val; +} + +// E2M1 format (4-bit float): 1 sign bit, 2 exponent bits, 1 mantissa bit +// Encoding: exp=00 (subnormal), exp=01/10/11 (normal with bias=1) +// Result: bits[3]=sign, bits[2:1]=exponent, bits[0]=mantissa +// +// Representable values and their codes: +// 0.0 -> 0b000 (subnormal, m=0, even) +// 0.5 -> 0b001 (subnormal, m=1, odd) +// 1.0 -> 0b010 (e=01, m=0, even) +// 1.5 -> 0b011 (e=01, m=1, odd) +// 2.0 -> 0b100 (e=10, m=0, even) +// 3.0 -> 0b101 (e=10, m=1, odd) +// 4.0 -> 0b110 (e=11, m=0, even) +// 6.0 -> 0b111 (e=11, m=1, odd) +// +// RoundTiesToEven: At exact midpoints between two representable values, +// we round to the one whose mantissa bit is 0 (even). +// +// Midpoints and their rounding targets: +// 0.25 -> midpoint of (0.0, 0.5) -> round to 0.0 (m=0, even) +// 0.75 -> midpoint of (0.5, 1.0) -> round to 1.0 (m=0, even) +// 1.25 -> midpoint of (1.0, 1.5) -> round to 1.0 (m=0, even) +// 1.75 -> midpoint of (1.5, 2.0) -> round to 2.0 (m=0, even) +// 2.5 -> midpoint of (2.0, 3.0) -> round to 2.0 (m=0, even) +// 3.5 -> midpoint of (3.0, 4.0) -> round to 4.0 (m=0, even) +// 5.0 -> midpoint of (4.0, 6.0) -> round to 4.0 (m=0, even) +inline uint8_t quantize_to_e2m1(float val) { + uint8_t sign = (val < 0.0f) ? 1 : 0; + float abs_val = sycl::fabs(val); + + uint8_t code; + // RoundTiesToEven: at midpoints, round to the value with even mantissa (m=0). + // Midpoints use strict < for the upper bound so ties go to the even value. + // TODO(sspintel): Optimize this logic under a LUT to avoid branch divergence. 
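+  // One possible branch-free formulation (hypothetical sketch, not used by this
+  // kernel) sums the seven threshold comparisons directly; the alternating
+  // >= / > keeps the roundTiesToEven behaviour of the if-chain below:
+  //   code = (abs_val > 0.25f) + (abs_val >= 0.75f) + (abs_val > 1.25f) +
+  //          (abs_val >= 1.75f) + (abs_val > 2.5f) + (abs_val >= 3.5f) +
+  //          (abs_val > 5.0f);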
+ if (abs_val <= 0.25f) { + code = 0b000; // 0.0 (subnormal: exp=00, m=0) + } else if (abs_val < 0.75f) { + code = 0b001; // 0.5 (subnormal: exp=00, m=1) + } else if (abs_val <= 1.25f) { + code = 0b010; // 1.0 (exp=01, m=0) + } else if (abs_val < 1.75f) { + code = 0b011; // 1.5 (exp=01, m=1) + } else if (abs_val <= 2.5f) { + code = 0b100; // 2.0 (exp=10, m=0) + } else if (abs_val < 3.5f) { + code = 0b101; // 3.0 (exp=10, m=1) + } else if (abs_val <= 5.0f) { + code = 0b110; // 4.0 (exp=11, m=0) + } else { + code = 0b111; // 6.0 (exp=11, m=1) + } + + return (sign << 3) | code; +} + +// Use SYCL native vector type for efficient loading +template +using vec_t = sycl::vec; + +// Compile-time constants for group sizes +template +struct FP4GroupSizeTraits { + static constexpr int THREADS_PER_GROUP = 16; + static constexpr int SUB_GROUP_SIZE = 32; +}; + +template +struct PerTokenGroupQuantFP4Kernel : public __SYCL_KER_CONFIG_CONVENTION__ { + static constexpr uint32_t VEC_SIZE = 16 / sizeof(T); + static constexpr int32_t NUM_VEC_ELEMS = GROUP_SIZE / VEC_SIZE; + static constexpr int32_t THREADS_PER_GROUP = FP4GroupSizeTraits::THREADS_PER_GROUP; + static constexpr int32_t VECS_PER_THREAD = (NUM_VEC_ELEMS + THREADS_PER_GROUP - 1) / THREADS_PER_GROUP; + + PerTokenGroupQuantFP4Kernel( + const T* input, uint8_t* output_q, uint8_t* output_s, int num_groups, int groups_per_block, float eps) + : input(input), + output_q(output_q), + output_s(output_s), + num_groups(num_groups), + groups_per_block(groups_per_block), + eps(eps) {} + + void sycl_ker_config_convention(sycl::handler& cgh) {} + + [[sycl::reqd_sub_group_size(32)]] void operator()(sycl::nd_item<1> item) const { + const int64_t local_group_id = item.get_local_id(0) / THREADS_PER_GROUP; + const int lane_id = item.get_local_id(0) % THREADS_PER_GROUP; + + const int64_t block_group_id = item.get_group(0) * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + + if (global_group_id >= num_groups) return; + + const int64_t block_group_offset = global_group_id * GROUP_SIZE; + + float local_absmax = eps; + + const T* group_input = input + block_group_offset; + // Output is packed FP4 (2 values per byte), so offset is halved + uint8_t* group_output = output_q + (block_group_offset / 2); + + // Calculate scale output position (row-major layout) + // Each row has num_groups_per_row scales, stored contiguously + uint8_t* scale_output = output_s + global_group_id; + + using vec_type = vec_t; + using float_vec_type = vec_t; + + vec_type input_vecs[VECS_PER_THREAD]; + float_vec_type input_vals[VECS_PER_THREAD]; + +#pragma unroll + for (int32_t v = 0; v < VECS_PER_THREAD; ++v) { + const int32_t i = lane_id + v * THREADS_PER_GROUP; + if (i < NUM_VEC_ELEMS) { + input_vecs[v].load( + 0, sycl::multi_ptr(group_input + i * VEC_SIZE)); + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; ++j) { + float val = static_cast(input_vecs[v][j]); + input_vals[v][j] = val; + local_absmax = sycl::fmax(local_absmax, sycl::fabs(val)); + } + } + } + + // Reduce across the threads in the quantization group to find the maximum + local_absmax = QuantGroupReduceMaxFP4(local_absmax, item); + + // Shared exponent per OCP MX spec / Microsoft micro-scaling: + // shared_exp = floor(log2(absmax)) - E2M1_EMAX + // where E2M1_EMAX = 2. eps already lower-limits local_absmax so + // log2 is well-defined. 
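+    // Worked example: absmax = 5.0 -> floor(log2(5.0)) = 2 -> shared exponent 0,
+    // scale = 2^0 = 1.0, stored UE8M0 byte = 0 + 127 = 127. For absmax = 0.9:
+    // floor(log2(0.9)) = -1 -> shared exponent -3, scale = 0.125, stored byte 124.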
+ float log2_scale = sycl::floor(sycl::log2(local_absmax)) - 2.0f; + int clamped_exponent = sycl::clamp(static_cast(log2_scale), -127, 127); + float scale_value = sycl::exp2(static_cast(clamped_exponent)); + + if (lane_id == 0) { + // Store scale as UE8M0: exponent + 127 bias + uint8_t scale_ue8m0 = static_cast(clamped_exponent + 127); + *scale_output = scale_ue8m0; + } + + const float inv_scale = 1.0f / scale_value; + + // Second pass: quantize and pack values + // Each thread processes VEC_SIZE elements at a time + // Two FP4 values are packed into one byte +#pragma unroll + for (int32_t v = 0; v < VECS_PER_THREAD; ++v) { + const int32_t i = lane_id + v * THREADS_PER_GROUP; + if (i < NUM_VEC_ELEMS) { + // Process VEC_SIZE elements, packing pairs into bytes + uint8_t packed_output[VEC_SIZE / 2]; + +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE; j += 2) { + float val0 = input_vals[v][j] * inv_scale; + float val1 = input_vals[v][j + 1] * inv_scale; + + uint8_t q0 = quantize_to_e2m1(val0); + uint8_t q1 = quantize_to_e2m1(val1); + + // Pack: first value in lower nibble, second in upper nibble + // No masking needed — quantize_to_e2m1 returns values in [0, 15] + packed_output[j / 2] = q0 | (q1 << 4); + } + + // Store packed output + // Each vec of VEC_SIZE elements becomes VEC_SIZE/2 packed bytes + uint8_t* out_ptr = group_output + i * (VEC_SIZE / 2); +#pragma unroll + for (uint32_t j = 0; j < VEC_SIZE / 2; ++j) { + out_ptr[j] = packed_output[j]; + } + } + } + } + + private: + const T* input; + uint8_t* output_q; + uint8_t* output_s; + int num_groups; + int groups_per_block; + float eps; +}; + +void sgl_per_token_group_quant_fp4( + torch::Tensor input, torch::Tensor output_q, torch::Tensor output_s, int64_t group_size, double eps) { + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output_q); + + TORCH_CHECK(group_size == 32, "sgl_per_token_group_quant_fp4: group_size must be 32 for MXFP4, got ", group_size); + + TORCH_CHECK( + input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 || + input.scalar_type() == at::ScalarType::Float, + "sgl_per_token_group_quant_fp4: input dtype must be Float16, BFloat16, or Float32, got ", + input.scalar_type()); + + TORCH_CHECK( + output_q.scalar_type() == at::ScalarType::Byte, + "output_q must be uint8 (packed FP4), got ", + output_q.scalar_type()); + TORCH_CHECK( + output_s.scalar_type() == at::ScalarType::Byte, + "output_s must be uint8 (UE8M0 scales), got ", + output_s.scalar_type()); + + TORCH_CHECK(input.dim() >= 1, "input must have at least 1 dimension"); + TORCH_CHECK( + input.size(-1) % group_size == 0, + "sgl_per_token_group_quant_fp4: last dimension of input (", + input.size(-1), + ") must be divisible by group_size (", + group_size, + ")"); + + const int num_groups = input.numel() / group_size; + + // Output should be half the size (2 FP4 values per byte) + CHECK_EQ(output_q.numel(), input.numel() / 2); + + // Ensure eps is positive to prevent NaN from log2(0) + float eps_f = static_cast(eps); + if (eps_f <= 0.0f) { + eps_f = 1e-10f; + } + + auto queue = dpcppGetCurrentQueue(); + + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + 
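+  // Illustrative launch geometry: a 64 x 7168 bf16 input has
+  // 64 * 7168 / 32 = 14,336 groups, so groups_per_block = 16,
+  // num_threads = 16 * 16 = 256 work-items per work-group and
+  // num_blocks = 14,336 / 16 = 896 work-groups.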
sycl::range<1> global_range(num_blocks * num_threads); + sycl::range<1> local_range(num_threads); + +#define LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE(T, GS) \ + do { \ + auto kernel = PerTokenGroupQuantFP4Kernel( \ + static_cast(input.data_ptr()), \ + static_cast(output_q.data_ptr()), \ + static_cast(output_s.data_ptr()), \ + num_groups, \ + groups_per_block, \ + eps_f); \ + sycl_kernel_submit(global_range, local_range, queue, kernel); \ + } while (0) + +#define LAUNCH_FP4_KERNEL(T) \ + do { \ + LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE(T, 32); \ + } while (0) + + // Dispatch based on input type + if (input.scalar_type() == at::ScalarType::Half) { + LAUNCH_FP4_KERNEL(sycl::half); + } else if (input.scalar_type() == at::ScalarType::BFloat16) { + LAUNCH_FP4_KERNEL(sycl::ext::oneapi::bfloat16); + } else if (input.scalar_type() == at::ScalarType::Float) { + LAUNCH_FP4_KERNEL(float); + } + +#undef LAUNCH_FP4_KERNEL +#undef LAUNCH_FP4_KERNEL_WITH_GROUP_SIZE +} + +} // namespace at::native::xpu diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 163e53aa..4e561fef 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -130,6 +130,10 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "sgl_per_token_group_quant_8bit(Tensor input, Tensor output_q, Tensor output_s, int group_size," " float eps, float fp8_min, float fp8_max, bool scale_ue8m0) -> ()"); m.impl("sgl_per_token_group_quant_8bit", torch::kXPU, &at::native::xpu::sgl_per_token_group_quant_8bit); + m.def( + "sgl_per_token_group_quant_fp4(Tensor input, Tensor output_q, Tensor output_s, int group_size," + " float eps) -> ()"); + m.impl("sgl_per_token_group_quant_fp4", torch::kXPU, &at::native::xpu::sgl_per_token_group_quant_fp4); m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()"); m.impl("sgl_per_tensor_quant_fp8", torch::kXPU, &sgl_per_tensor_quant_fp8); diff --git a/tests/run_suite.py b/tests/run_suite.py index 0b3923b4..58539deb 100644 --- a/tests/run_suite.py +++ b/tests/run_suite.py @@ -23,6 +23,7 @@ class TestFile: TestFile("test_moe_prepare_input.py"), TestFile("test_swiglu_with_alpha_limit.py"), TestFile("test_per_token_group_quant_8bit.py"), + TestFile("test_per_token_group_quant_mxfp4.py"), TestFile("test_moe_fused_gate.py"), TestFile("test_per_tensor_quant_fp8.py"), TestFile("test_fused_qk_norm_rope.py"), diff --git a/tests/test_per_token_group_quant_mxfp4.py b/tests/test_per_token_group_quant_mxfp4.py new file mode 100644 index 00000000..2e129719 --- /dev/null +++ b/tests/test_per_token_group_quant_mxfp4.py @@ -0,0 +1,546 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Tests for MXFP4 (E2M1) Per-Token Group Quantization on Intel XPU + +MXFP4 follows the OpenCompute MX (Microscaling) format specification: +- Data type: E2M1 (4-bit float with 2-bit exponent, 1-bit mantissa) +- Block size: 32 elements per scale factor +- Scale format: UE8M0 (unsigned 8-bit exponent-only, no mantissa) + +Rounding: Per OCP MX spec (section 5.3.3), FP4 conversion uses +roundTiesToEven — at midpoints between representable values, the +value with even mantissa (mantissa bit = 0) is chosen. 
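+For example, 2.5 (the midpoint of 2.0 and 3.0) rounds to 2.0, while 3.5
+(the midpoint of 3.0 and 4.0) rounds to 4.0.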
+ +Usage: + pytest test_per_token_group_quant_mxfp4.py -v +""" + +import pytest +import torch + +MXFP4_BLOCK_SIZE = 32 +FLOAT4_E2M1_MAX = 6.0 + +# E2M1 format parameters (from Microsoft microxcaling formats.py) +E2M1_EBITS = 2 +E2M1_MBITS = 3 # includes sign bit and implicit one +E2M1_EMAX = 2 ** (E2M1_EBITS - 1) # = 2 +E2M1_MAX_NORM = ( + 2**E2M1_EMAX * float(2 ** (E2M1_MBITS - 1) - 1) / 2 ** (E2M1_MBITS - 2) +) # = 6.0 + +FP32_EXPONENT_BIAS = 127 +FP32_MIN_NORMAL = 2 ** (-FP32_EXPONENT_BIAS + 1) # 2^(-126) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + + +def is_xpu_available() -> bool: + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def _round_mantissa_even(A: torch.Tensor) -> torch.Tensor: + """Round mantissa using roundTiesToEven (from Microsoft microxcaling). + + At exact 0.5 midpoints (i.e., values like 0.5, 2.5, 4.5, ...), + round to the nearest even integer (the one whose LSB is 0). + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + absA = torch.abs(A) + # Identify exact midpoints: 0.5, 2.5, 4.5, ... i.e. (absA - 0.5) % 2 == 0 + maskA = ((absA - 0.5) % 2 == torch.zeros_like(A)).type(A.dtype) + # round half up, then subtract 1 at midpoints to get even + return torch.sign(A) * (torch.floor(absA + 0.5) - maskA) + + +def _quantize_elemwise_core_e2m1( + A: torch.Tensor, saturate_normals: bool = True +) -> torch.Tensor: + """Element-wise quantization to E2M1 using Microsoft microxcaling's + _quantize_elemwise_core algorithm with round='even'. + + E2M1 format: ebits=2, mbits=3, emax=2, max_norm=6.0 + min_exp = -(2^(ebits-1)) + 2 = 0 + + Algorithm (from Microsoft microxcaling elemwise_ops.py): + 1. Compute per-element private exponent = floor(log2(|A|)), + clamped to min_exp. + 2. Left-shift: out = A / 2^private_exp * 2^(mbits-2) + 3. Round mantissa with roundTiesToEven + 4. Right-shift: out = out / 2^(mbits-2) * 2^private_exp + 5. Clamp to [-max_norm, max_norm] if saturate_normals + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + ebits = E2M1_EBITS # 2 + mbits = E2M1_MBITS # 3 + max_norm = E2M1_MAX_NORM # 6.0 + + # min representable exponent: -(2^(ebits-1)) + 2 = 0 + min_exp = -(2 ** (ebits - 1)) + 2 # 0 + + out = A.clone() + + # Per-element private exponent: floor(log2(|A|)) + # Add guard for zeros: log2(0) is -inf, we use (A==0) to avoid that + private_exp = torch.floor(torch.log2(torch.abs(A) + (A == 0).type(A.dtype))) + private_exp = private_exp.clip(min=min_exp) + + # Left-shift: scale up so mantissa bits land in integer portion + # out = A / 2^private_exp * 2^(mbits-2) + shift = mbits - 2 # = 1 + out = out / (2**private_exp) * (2**shift) + + # Round mantissa with roundTiesToEven + out = _round_mantissa_even(out) + + # Right-shift: undo scaling + # out = out / 2^(mbits-2) * 2^private_exp + out = out / (2**shift) * (2**private_exp) + + # Saturate to [-max_norm, max_norm] + if saturate_normals: + out = torch.clamp(out, min=-max_norm, max=max_norm) + + return out + + +def _float_to_e2m1_code(val: torch.Tensor) -> torch.Tensor: + """Convert quantized float values back to E2M1 4-bit codes. + + After _quantize_elemwise_core_e2m1, values are one of the 8 representable + E2M1 magnitudes: {0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0}. + This maps them to 4-bit codes (sign in bit 3, magnitude in bits 0-2). + """ + sign = (val < 0).to(torch.uint8) + abs_val = val.abs() + + # Map representable magnitudes to 3-bit indices via the kE2M1ToFloat LUT. 
+ # Use a tolerance-based comparison since values are exact after quantization. + indices = torch.zeros_like(abs_val, dtype=torch.uint8) + lut = kE2M1ToFloat.to(device=val.device) + for i in range(8): + indices = torch.where( + torch.isclose(abs_val, lut[i].expand_as(abs_val), atol=1e-6, rtol=0), + torch.tensor(i, dtype=torch.uint8, device=val.device), + indices, + ) + + return (sign << 3) | indices + + +def quantize_to_e2m1(tensor: torch.Tensor) -> torch.Tensor: + """Quantize tensor values to E2M1 format (4-bit indices). + + Uses the Microsoft microxcaling _quantize_elemwise_core algorithm + with roundTiesToEven, then maps the resulting float values to 4-bit codes. + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py + """ + quantized_float = _quantize_elemwise_core_e2m1( + tensor.float(), saturate_normals=True + ) + return _float_to_e2m1_code(quantized_float) + + +def pack_fp4(tensor: torch.Tensor) -> torch.Tensor: + assert tensor.shape[-1] % 2 == 0 + shape = tensor.shape[:-1] + (tensor.shape[-1] // 2, 2) + paired = tensor.reshape(shape) + packed = (paired[..., 0] & 0x0F) | ((paired[..., 1] & 0x0F) << 4) + return packed.to(torch.uint8) + + +def _normalize_packed_fp4_signed_zero(packed: torch.Tensor) -> torch.Tensor: + """Canonicalize signed zeros in packed FP4 bytes. + + In E2M1, code 0x0 is +0.0 and code 0x8 is -0.0. Both represent + the same value, but different implementations may emit either form. + This helper rewrites every -0.0 nibble (0x8) to +0.0 (0x0) so that + byte-level comparisons are not tripped up by this harmless difference. + """ + # For each nibble, 0x8 is the only code that equals -0.0 + # (sign=1, exponent=0, mantissa=0). Clear bit 3 whenever the + # lower 3 bits (magnitude) are zero — i.e. the nibble is 0x0 or 0x8. + lo = packed & 0x0F + hi = (packed >> 4) & 0x0F + lo = torch.where(lo == 0x08, torch.zeros_like(lo), lo) + hi = torch.where(hi == 0x08, torch.zeros_like(hi), hi) + return (lo | (hi << 4)).to(torch.uint8) + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + low = packed & 0x0F + high = (packed >> 4) & 0x0F + unpacked = torch.stack([low, high], dim=-1).reshape(*packed.shape[:-1], -1) + return unpacked + + +def dequantize_e2m1( + quantized: torch.Tensor, dtype: torch.dtype = torch.float32 +) -> torch.Tensor: + sign = ((quantized >> 3) & 1).to(torch.bool) + magnitude_idx = (quantized & 0x07).to(torch.long) + kE2M1 = kE2M1ToFloat.to(device=quantized.device) + magnitude = kE2M1[magnitude_idx] + result = torch.where(sign, -magnitude, magnitude) + return result.to(dtype) + + +def _shared_exponents(A: torch.Tensor, axis: int) -> torch.Tensor: + """Compute shared exponents per block using Microsoft microxcaling's + _shared_exponents algorithm with method="max". + + Algorithm: + 1. shared_exp = max(|A|) along axis (per block) + 2. shared_exp = floor(log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0))) + The FP32_MIN_NORMAL guard ensures log2(0) doesn't produce -inf. + 3. 
Offset by emax: shared_exp = shared_exp - emax + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + shared_exp = torch.max(torch.abs(A), dim=axis, keepdim=True).values + + # floor(log2(...)) with zero-guard from microxcaling + shared_exp = torch.floor( + torch.log2( + shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype) + ) + ) + + # Offset by the largest representable exponent in E2M1 + shared_exp = shared_exp - E2M1_EMAX + + return shared_exp + + +def quantize_to_mxfp4( + tensor: torch.Tensor, block_size: int = MXFP4_BLOCK_SIZE, eps: float = 1e-10 +) -> tuple: + """Quantize to MXFP4 using Microsoft microxcaling's _quantize_mx algorithm. + + Algorithm (from mx_ops.py _quantize_mx): + 1. Reshape into blocks + 2. Compute shared exponent per block via _shared_exponents + 3. Clamp shared_exp to scale_emax range [-127, 127] + 4. Scale elements: A = A / 2^shared_exp + 5. Quantize element-wise with _quantize_elemwise_core (saturate_normals=True) + 6. Rescale: A = A * 2^shared_exp (implicitly stored in UE8M0 scale) + + Ref: https://github.com/microsoft/microxcaling/blob/main/mx/mx_ops.py + """ + assert tensor.dim() == 2 + m, k = tensor.shape + assert k % block_size == 0 + assert k % 2 == 0 + + tensor_fp32 = tensor.float() + num_blocks = k // block_size + tensor_blocks = tensor_fp32.reshape(m, num_blocks, block_size) + + # Compute shared exponents (microxcaling _shared_exponents + offset by emax) + shared_exp = _shared_exponents(tensor_blocks, axis=-1) + + # Clamp to UE8M0 scale range: scale_bits=8, scale_emax = 2^(8-1)-1 = 127 + scale_emax = 127 + shared_exp = shared_exp.clamp(min=-scale_emax, max=scale_emax) + + # Encode as UE8M0: stored_scale = shared_exp + 127 + scales_ue8m0 = (shared_exp.to(torch.int32) + 127).to(torch.uint8).squeeze(-1) + + # Scale elements by shared exponent: A = A / 2^shared_exp + scaled_tensor = tensor_blocks / (2.0**shared_exp) + + # Quantize element-wise with microxcaling core (roundTiesToEven, saturate) + quantized_float = _quantize_elemwise_core_e2m1(scaled_tensor, saturate_normals=True) + + # Convert quantized float values to 4-bit E2M1 codes + quantized_blocks = _float_to_e2m1_code(quantized_float) + + quantized = quantized_blocks.reshape(m, k) + packed = pack_fp4(quantized) + + return packed, scales_ue8m0 + + +def dequantize_mxfp4( + packed: torch.Tensor, + scales: torch.Tensor, + dtype: torch.dtype = torch.float32, + block_size: int = MXFP4_BLOCK_SIZE, +) -> torch.Tensor: + m, packed_k = packed.shape + k = packed_k * 2 + + unpacked = unpack_fp4(packed) + dequantized = dequantize_e2m1(unpacked, dtype) + + num_blocks = k // block_size + dequantized_blocks = dequantized.reshape(m, num_blocks, block_size) + + scale_exp = scales.to(torch.int32) - 127 + scale_values = torch.pow(2.0, scale_exp.float()).unsqueeze(-1) + scaled = dequantized_blocks * scale_values + + return scaled.reshape(m, k).to(dtype) + + +class TestMXFP4ReferenceQuantization: + def test_e2m1_roundtrip(self): + device = torch.device("cpu") + test_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=torch.float32, + device=device, + ) + quantized = quantize_to_e2m1(test_values) + dequantized = dequantize_e2m1(quantized) + torch.testing.assert_close(dequantized, test_values, atol=0.0, rtol=0.0) + + def test_e2m1_round_ties_to_even(self): + """Test that midpoints between representable values round to even (m=0). 
+ + Per OCP MX spec section 5.3.3, FP4 must use roundTiesToEven. + At midpoints, the value with even mantissa (m=0) is chosen. + """ + device = torch.device("cpu") + # Midpoint values and their expected quantized results + # (midpoint_value, expected_dequantized_value) + midpoint_tests = [ + (0.25, 0.0), # midpoint of (0.0, 0.5) -> 0.0 (m=0, even) + (0.75, 1.0), # midpoint of (0.5, 1.0) -> 1.0 (m=0, even) + (1.25, 1.0), # midpoint of (1.0, 1.5) -> 1.0 (m=0, even) + (1.75, 2.0), # midpoint of (1.5, 2.0) -> 2.0 (m=0, even) + (2.5, 2.0), # midpoint of (2.0, 3.0) -> 2.0 (m=0, even) + (3.5, 4.0), # midpoint of (3.0, 4.0) -> 4.0 (m=0, even) + (5.0, 4.0), # midpoint of (4.0, 6.0) -> 4.0 (m=0, even) + # Negative midpoints + (-0.25, 0.0), # -> -0.0 = 0.0 + (-0.75, -1.0), + (-1.25, -1.0), + (-1.75, -2.0), + (-2.5, -2.0), + (-3.5, -4.0), + (-5.0, -4.0), + ] + for midpoint, expected in midpoint_tests: + tensor = torch.tensor([midpoint], dtype=torch.float32, device=device) + quantized = quantize_to_e2m1(tensor) + dequantized = dequantize_e2m1(quantized) + # For -0.25, dequantized is -0.0 which equals 0.0 + assert dequantized.item() == expected or ( + expected == 0.0 and abs(dequantized.item()) == 0.0 + ), f"Midpoint {midpoint}: expected {expected}, got {dequantized.item()}" + + def test_pack_unpack_roundtrip(self): + device = torch.device("cpu") + m, k = 16, 64 + original = torch.randint(0, 16, (m, k), dtype=torch.uint8, device=device) + packed = pack_fp4(original) + unpacked = unpack_fp4(packed) + torch.testing.assert_close(unpacked, original) + + def test_mxfp4_quantization_shape(self): + device = torch.device("cpu") + m, k = 32, 128 + original = torch.randn(m, k, dtype=torch.float32, device=device) + packed, scales = quantize_to_mxfp4(original) + assert packed.shape == (m, k // 2) + assert scales.shape == (m, k // MXFP4_BLOCK_SIZE) + assert packed.dtype == torch.uint8 + assert scales.dtype == torch.uint8 + + def test_mxfp4_dequantization_accuracy(self): + device = torch.device("cpu") + m, k = 32, 128 + original = torch.randn(m, k, dtype=torch.float32, device=device) * 3.0 + packed, scales = quantize_to_mxfp4(original) + dequantized = dequantize_mxfp4(packed, scales, torch.float32) + assert dequantized.shape == original.shape + relative_error = (dequantized - original).abs() / (original.abs() + 1e-6) + mean_error = relative_error.mean().item() + assert mean_error < 0.5 + + +@pytest.mark.skipif(not is_xpu_available(), reason="XPU not available") +class TestPerTokenGroupQuantFP4XPU: + @pytest.fixture(autouse=True) + def setup(self): + import utils + + self.device = utils.get_device() + self.eps = 1e-10 + + def _import_kernel(self): + try: + from sgl_kernel import sgl_per_token_group_quant_fp4 + + return sgl_per_token_group_quant_fp4 + except ImportError: + pytest.skip("sgl_per_token_group_quant_fp4 kernel not available") + + def _test_against_reference( + self, + num_tokens: int, + hidden_dim: int, + src_dtype: torch.dtype = torch.bfloat16, + seed: int = 42, + ): + sgl_per_token_group_quant_fp4 = self._import_kernel() + group_size = MXFP4_BLOCK_SIZE + + torch.manual_seed(seed) + + x_cpu = torch.randn(num_tokens, hidden_dim, dtype=src_dtype, device="cpu") + x_q_ref, scales_ref = quantize_to_mxfp4(x_cpu.float(), group_size, eps=self.eps) + + x_xpu = x_cpu.to(self.device) + x_q_xpu, scales_xpu = sgl_per_token_group_quant_fp4( + x=x_xpu, + group_size=group_size, + eps=self.eps, + ) + + x_q_xpu_cpu = x_q_xpu.cpu() + scales_xpu_cpu = scales_xpu.cpu() + + assert ( + x_q_xpu_cpu.shape == x_q_ref.shape + ), 
f"Quantized shape mismatch: {x_q_xpu_cpu.shape} vs {x_q_ref.shape}" + assert ( + scales_xpu_cpu.shape == scales_ref.shape + ), f"Scales shape mismatch: {scales_xpu_cpu.shape} vs {scales_ref.shape}" + assert x_q_xpu_cpu.dtype == torch.uint8 + assert scales_xpu_cpu.dtype == torch.uint8 + + # Compare quantized values directly (packed uint8). + # Normalise signed zeros first: in E2M1 code 0x0 (+0.0) and 0x8 + # (-0.0) are semantically identical. The kernel may preserve the + # sign of the original float while the reference always emits +0.0, + # so we canonicalise before comparing. + x_q_xpu_norm = _normalize_packed_fp4_signed_zero(x_q_xpu_cpu) + x_q_ref_norm = _normalize_packed_fp4_signed_zero(x_q_ref) + q_match = torch.equal(x_q_xpu_norm, x_q_ref_norm) + if not q_match: + q_mismatches = (x_q_xpu_norm != x_q_ref_norm).sum().item() + total = x_q_ref_norm.numel() + assert ( + q_mismatches / total < 0.05 + ), f"Too many quantized value mismatches: {q_mismatches}/{total}" + + # Compare scale exponents (allow ±1 difference due to rounding) + scale_exp_ref = scales_ref.to(torch.int32) - 127 + scale_exp_xpu = scales_xpu_cpu.to(torch.int32) - 127 + exp_diff = (scale_exp_ref - scale_exp_xpu).abs() + assert exp_diff.max() == 0, f"Scale exponent difference: {exp_diff.max()}" + + # Compare dequantized outputs + x_dq_ref = dequantize_mxfp4(x_q_ref, scales_ref, torch.float32, group_size) + x_dq_xpu = dequantize_mxfp4( + x_q_xpu_cpu, scales_xpu_cpu, torch.float32, group_size + ) + torch.testing.assert_close(x_dq_xpu, x_dq_ref, rtol=0.0, atol=0.0) + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,src_dtype", + [ + (128, 256, torch.bfloat16), + (64, 128, torch.float16), + (64, 128, torch.float32), + (256, 2048, torch.bfloat16), + ], + ) + def test_quantization_vs_reference(self, num_tokens, hidden_dim, src_dtype): + self._test_against_reference(num_tokens, hidden_dim, src_dtype) + + def test_quantize_dequantize_roundtrip(self): + sgl_per_token_group_quant_fp4 = self._import_kernel() + + torch.manual_seed(42) + m, k = 128, 256 + + x_cpu = torch.randn(m, k, dtype=torch.bfloat16, device="cpu") + x_xpu = x_cpu.to(self.device) + + x_q, scales = sgl_per_token_group_quant_fp4( + x=x_xpu, group_size=MXFP4_BLOCK_SIZE + ) + + x_dq = dequantize_mxfp4( + x_q.cpu(), scales.cpu(), torch.float32, MXFP4_BLOCK_SIZE + ) + + correlation = torch.corrcoef( + torch.stack([x_dq.flatten(), x_cpu.float().flatten()]) + )[0, 1] + assert correlation > 0.9, f"Correlation too low: {correlation}" + + def test_round_ties_to_even_on_xpu(self): + """Test that the kernel implements roundTiesToEven at midpoints.""" + sgl_per_token_group_quant_fp4 = self._import_kernel() + + # Create a tensor of exactly 32 elements (one group) containing + # midpoint values. 
Scale will be 1.0 (exponent=0) since max abs is 5.0 + # which maps to scale = 2^(floor(log2(5.0)) - 2) = 2^(2 - 2) = 2^0 = 1.0 + midpoints = [ + 0.25, + 0.75, + 1.25, + 1.75, + 2.5, + 3.5, + 5.0, + -0.25, + -0.75, + -1.25, + -1.75, + -2.5, + -3.5, + -5.0, + ] + # Pad to 32 elements with zeros + padded = midpoints + [0.0] * (32 - len(midpoints)) + x = torch.tensor([padded], dtype=torch.float32, device=self.device) + + x_q, scales = sgl_per_token_group_quant_fp4( + x=x, group_size=MXFP4_BLOCK_SIZE, eps=self.eps + ) + + # Reference + x_q_ref, scales_ref = quantize_to_mxfp4( + x.cpu().float(), MXFP4_BLOCK_SIZE, eps=self.eps + ) + + x_dq_xpu = dequantize_mxfp4( + x_q.cpu(), scales.cpu(), torch.float32, MXFP4_BLOCK_SIZE + ) + x_dq_ref = dequantize_mxfp4( + x_q_ref, scales_ref, torch.float32, MXFP4_BLOCK_SIZE + ) + + torch.testing.assert_close(x_dq_xpu, x_dq_ref, atol=0.0, rtol=0.0) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main([__file__, "-v"])) From 9bdad0b6f784a1d3bc980f80911fa687ca57fcf8 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Mon, 23 Mar 2026 14:58:09 +0800 Subject: [PATCH 04/23] add page 64 --- .../xe_fmha_fwd_decode_runner.hpp | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index 3e3546f0..38dff69e 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -1102,12 +1102,12 @@ std::vector mha_fwd( auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { switch (params.page_size) { - // case 32: - // launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); - // break; - // case 64: - // launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); - // break; + case 32: + launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); + break; + case 64: + launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); + break; case 128: launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{}); break; @@ -1142,21 +1142,21 @@ std::vector mha_fwd( }; switch (params.d) { - // case 64: - // dispatch_q_group(_64{}); - // break; - // case 96: - // dispatch_q_group(_96{}); - // break; + case 64: + dispatch_q_group(_64{}); + break; + case 96: + dispatch_q_group(_96{}); + break; case 128: dispatch_q_group(_128{}); break; - // case 192: - // dispatch_q_group(_192{}); - // break; - // case 256: - // dispatch_q_group(_256{}); - // break; + case 192: + dispatch_q_group(_192{}); + break; + case 256: + dispatch_q_group(_256{}); + break; default: TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); } From 3c1476861127761c39ab0a97c806a0666cc134c9 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 07:17:34 +0000 Subject: [PATCH 05/23] Initial plan From 644ff837e40946c6a775e14557e16dbaf7657da6 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Fri, 13 Mar 2026 10:46:38 +0800 Subject: [PATCH 06/23] add reduce.h --- .../kernel/xe_fhma_fwd_kernel.hpp | 331 ++++++++++++++++++ .../kernel/xe_reduce_split_k.h | 304 ++++++++++++++++ 2 files changed, 635 insertions(+) create mode 100644 src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 711d4337..755a3b5c 100644 --- 
a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -766,4 +766,335 @@ class XeFMHAFwdDynamicSplitKernel { } }; +template +class XeFMHAFwdSplitKVKernel { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) : sizeof(SharedStorage); + + static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; + static constexpr int dpas_max_repeat_count = 8; + static constexpr bool Sink = CollectiveEpilogue::Sink; + using ElementSink = typename CollectiveEpilogue::ElementSink; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ* Q; + StrideQ dQ; + const ElementK* K; + StrideK dK; + const ElementV* V; + StrideV dV; + ElementO* Oaccum; + StrideO dOaccum; + ElementLSE* exp_sums; + StrideO dExp_sums; + ElementLSE* max_logits; + StrideO dMax_logits; + + const ElementSink* sm_sink; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return { + args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const& args) { + if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { + // decode only + return false; + } + + if (args.num_kv_splits > max_num_kv_splits) { + return false; + } + + return CollectiveMainloop::can_implement(args.mainloop) && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(SGPerWG::value * intel::sg_size, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + auto q_len = + cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch); + return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]}; + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto& p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0, 2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits_; 
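// Split-KV decode, roughly: the scheduler adds an extra idx_kv_split
// coordinate, so each work-group below processes only a contiguous slice of
// the K/V blocks for its (head, batch) pair. When num_kv_splits > 1 the
// epilogue stores a partial, locally normalized output tile in Oaccum
// together with the per-split row max (max_logits) and exp-sum (exp_sums);
// the separate ReduceSplitK kernel then rescales each partial tile by
// exp2(local_max - global_max), multiplies the local exp-sum back in, and
// renormalizes by the combined exp-sum to produce the final O.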
+ + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split) + auto blk_qv = make_coord(blk_q, blk_v); + int head_q_start = head * head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + // For decode window_size_right doesn't have effect + const int seq_len = seq_len_kv; + // For decode, all packed GQA heads are at position seq_len_kv - 1. + // Use seq_len - 1 (= seq_len_kv - 1) as the decode position for + // k_block0 to match ReduceSplitK's computation. + const int k_block0 = CollectiveMainloop::LocalMask + ? cute::max(seq_len - 1 - params.mainloop.window_size_left, 0) / get<1>(TileShapeQK{}) + : 0; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + const int windowed_k_blocks = k_blocks - k_block0; + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_o = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + // for gqa packing, seq_len_qo must be 1 + seq_len_qo = 1; + } + + // neglect seq_len_qo since it's always 1 for decode + auto batch_dim = is_var_len ? 
1 : s.batch; + auto shape_Q = make_shape(head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim); + // shape + auto total_seqlen_kv = params.mainloop.total_seqlen_kv; + auto shape_K = make_shape(total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V = make_shape(s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim); + + auto shape_O = make_shape(head_group_q, s.head_size_vo, s.num_heads_kv, num_kv_splits, batch_dim); + auto shape_exp_sums = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_max_logits = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_sink = make_shape(s.num_heads_kv, head_group_q); + + int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits); + int kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split; + int num_effective_kv_blocks = + cute::min(windowed_k_blocks - idx_kv_split * num_blocks_per_split, num_blocks_per_split); + + if (num_effective_kv_blocks <= 0) { + // no need computation + continue; + } + + auto dcQ = const_cast(p.Q + offset_q); + auto dcK = const_cast(p.K); + auto dcV = const_cast(p.V); + auto ptrO = p.Oaccum + offset_o; + auto ptrExp_sums = p.exp_sums + offset_exp_sums; + auto ptrMax_logits = p.max_logits + offset_max_logits; + + auto layout_q = make_ordered_layout(shape_Q, Step<_1, _0, _2, _3>{}); + auto layout_k = make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{}); + auto layout_v = make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{}); + + auto layout_o = make_ordered_layout(shape_O, Step<_1, _0, _2, _3, _4>{}); + auto layout_exp_sums = make_ordered_layout(shape_exp_sums, Step<_1, _0, _2, _3>{}); + auto layout_max_logits = make_ordered_layout(shape_max_logits, Step<_1, _0, _2, _3>{}); + auto layout_sink = make_ordered_layout(shape_sink, Step<_1, _0>{}); + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), layout_q); + Tensor K = make_tensor(make_gmem_ptr(dcK), layout_k); + Tensor V = make_tensor(make_gmem_ptr(dcV), layout_v); + Tensor O = make_tensor(make_gmem_ptr(ptrO), layout_o); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), layout_exp_sums); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), layout_max_logits); + Tensor sinks = make_tensor(make_gmem_ptr(const_cast(p.sm_sink)), layout_sink); + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + int l_coord = is_var_len ? 
0 : idx_b; + + int start_blk = kv_split_offset; + int end_blk = kv_split_offset + num_effective_kv_blocks; + + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + mainloop( + Q(_, _, head, l_coord), + K(_, _, head, l_coord), + V(_, _, head, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + idx_b, + start_blk, + end_blk, + k_blocks, + thr_id, + seq_len, + full_tile_offset, + discard_seq_coord); + + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + if constexpr (Sink) { + auto sinks_per_kv = sinks(head, _); + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks_per_kv, + num_kv_splits); + } else { + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks, + num_kv_splits); + } + } + } +}; + } // namespace cutlass::fmha::kernel diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h new file mode 100644 index 00000000..94452a2e --- /dev/null +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h @@ -0,0 +1,304 @@ +/*************************************************************************************************** + * Copyright (C) 2025-2026 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cute/util/type_traits.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "sycl/kernels/flash_attention_v2/collective/fmha_fusion.hpp" +#include "sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" +#include "sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" +#include "sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReduceSplitK { + public: + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using TileScheduler = TileScheduler_; + static_assert( + is_same_v, + "ReduceSplitK kernel requires XeReduceSplitKTileScheduler"); + using TileSchedulerParams = typename TileScheduler::Params; + + using ElementO = typename FMHAKernel_::ElementO; + using StrideO = typename FMHAKernel_::StrideO; + using TileShapeO = typename FMHAKernel_::TileShapeO; + using TileShapeQK = typename FMHAKernel_::TileShapeQK; + + using ElementLSE = typename FMHAKernel_::ElementLSE; + + using SGPerWG = typename FMHAKernel_::SGPerWG; + + // num values (head_dim) processed by each thread + constexpr static int num_vals_per_thread = int(get<1>(TileShapeO{}) / (SGPerWG::value * intel::sg_size)); + + // + // Types + // + + struct KernelArguments { + ProblemShape shape; + // outputs: + ElementO* O; + StrideO dO; + // below are inputs + // TODO: whether same dtype as output or accum? + const ElementO* Oaccum; + StrideO dOaccum; + const ElementLSE* exp_sums; + StrideO dExp_sums; + const ElementLSE* max_logits; + StrideO dMax_logits; + int window_size_left = -1; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + /// Params structure + struct Params { + KernelParams kernel; + TileSchedulerParams scheduler; + }; + + struct SharedStorage { + cutlass::Array max_logits_slm_array; + cutlass::Array exp_sums_slm_array; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) : sizeof(SharedStorage); + + public: + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return { + args.kernel, + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const& args) { + // only support decode + if (!is_var_len && args.kernel.shape.seq_len_qo > 1) { + return false; + } + + if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) { + return false; + } + return true; + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(SGPerWG::value * intel::sg_size, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + auto q_len = + cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch); + return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]}; + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + /// Perform a reduction + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto& p = params.kernel; + ProblemShape const& s = p.shape; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int tid_in_sg = thr_id % intel::sg_size; + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits; + + auto batch_dim = is_var_len ? 1 : s.batch; + auto num_heads_q = s.num_heads_q; + auto head_size_vo = s.head_size_vo; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord(); + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + + // when varlen enabled, use largest seq_len_qo to decide work group num + if (seq_idx >= seq_len_qo) continue; + + const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{})); + // Sliding window: skip blocks before the window + constexpr bool LocalMask = FMHAKernel_::CollectiveMainloop::LocalMask; + const int k_block0 = LocalMask ? cute::max(seq_len_kv - 1 - p.window_size_left, 0) / get<1>(TileShapeQK{}) : 0; + const int windowed_k_blocks = k_blocks - k_block0; + int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits); + + int offset_o = 0, offset_o_accum = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_o_accum = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b]; + } + + auto shape_O = make_shape(seq_len_qo, head_size_vo, num_heads_q, batch_dim); + auto shape_Oaccum = is_var_len ? 
make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim) + : make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim); + + auto shape_exp_sums = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); + auto shape_max_logits = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim); + + auto dcOaccum = const_cast(p.Oaccum + offset_o_accum); + auto ptrExp_sums = const_cast(p.exp_sums + offset_exp_sums); + auto ptrMax_logits = const_cast(p.max_logits + offset_max_logits); + auto ptrO = p.O + offset_o; + + auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO; + auto stride_o_accum = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_Oaccum) : p.dOaccum; + auto stride_exp_sums = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums; + auto stride_max_logits = + is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits; + + Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum)); + Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o)); + + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); + + int l_coord = is_var_len ? 0 : idx_b; + + // Step 1: reduce max logits across different partitions + // store into SLM for later use + + ElementLSE global_max_logits{cutlass::platform::numeric_limits::lowest()}; + ElementLSE global_exp_sums{0}; + // only first subgroup participates + if (thr_id < num_kv_splits && thr_id * num_blocks_per_split < windowed_k_blocks) { + ElementLSE cur_max_logit = max_logits(seq_idx, thr_id, head_q, l_coord); + global_max_logits = sycl::max(global_max_logits, cur_max_logit); + shared_storage.max_logits_slm_array[thr_id] = cur_max_logit; + + ElementLSE cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, l_coord); + shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum; + } + + // barrier for SLM writes finished + sycl::group_barrier(get_work_group<3>()); + + // reduce across wg + global_max_logits = reduce_over_group(get_work_group<1>(), global_max_logits, sycl::maximum<>()); + + // broadcast to all other threads + global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0); + + for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) { + ElementLSE acc = 0; + global_exp_sums = 0; + for (int i = 0; i < num_kv_splits; ++i) { + if (i * num_blocks_per_split >= windowed_k_blocks) { + break; + } + ElementLSE local_max_logit = shared_storage.max_logits_slm_array[i]; + ElementLSE local_exp_sum = shared_storage.exp_sums_slm_array[i]; + + ElementLSE rescale = sycl::native::exp2(local_max_logit - global_max_logits); + + // in FMHA epilogue, it's divided by local_exp_sum, here we multiply + // back + ElementLSE adjusted_o_accum = + static_cast(Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord)) * local_exp_sum; + acc += adjusted_o_accum * rescale; + + // update global exp sum + global_exp_sums += local_exp_sum * rescale; + } + + ElementLSE inv_global_exp_sums = 1. 
/ global_exp_sums; + acc *= inv_global_exp_sums; + O(seq_idx, idx, head_q, l_coord) = static_cast(acc); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass From ab884e95e2f50181b13c6a81128bfe3c19ded154 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Fri, 13 Mar 2026 20:02:33 +0800 Subject: [PATCH 07/23] add XeFMHAFwdSplitKVKernel --- .../kernel/xe_fhma_fwd_kernel.hpp | 331 ++++++++++++++++++ 1 file changed, 331 insertions(+) diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 755a3b5c..c83d3d19 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -1097,4 +1097,335 @@ class XeFMHAFwdSplitKVKernel { } }; +template +class XeFMHAFwdSplitKVKernel { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; + using ElementQ = typename CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) : sizeof(SharedStorage); + + static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; + static constexpr int dpas_max_repeat_count = 8; + static constexpr bool Sink = CollectiveEpilogue::Sink; + using ElementSink = typename CollectiveEpilogue::ElementSink; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ* Q; + StrideQ dQ; + const ElementK* K; + StrideK dK; + const ElementV* V; + StrideV dV; + ElementO* Oaccum; + StrideO dOaccum; + ElementLSE* exp_sums; + StrideO dExp_sums; + ElementLSE* max_logits; + StrideO dMax_logits; + + const ElementSink* sm_sink; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return { + args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const& args) { + if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { + // decode only + return false; + } + + if (args.num_kv_splits > max_num_kv_splits) { + return false; + } + + return CollectiveMainloop::can_implement(args.mainloop) && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(SGPerWG::value * intel::sg_size, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + auto q_len = + cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch); + return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]}; + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto& p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0, 2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits_; 
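// Illustrative split-partitioning example for the loop below: with
// windowed_k_blocks = 10 and num_kv_splits = 4, num_blocks_per_split =
// ceil_div(10, 4) = 3, so splits 0..3 cover K-block ranges [0,3), [3,6),
// [6,9), [9,10) relative to k_block0 (effective counts 3, 3, 3, 1). Any
// split whose start lands at or beyond windowed_k_blocks gets
// num_effective_kv_blocks <= 0 and simply skips the tile via `continue`.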
+ + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split) + auto blk_qv = make_coord(blk_q, blk_v); + int head_q_start = head * head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + // For decode window_size_right doesn't have effect + const int seq_len = seq_len_kv; + // For decode, all packed GQA heads are at position seq_len_kv - 1. + // Use seq_len - 1 (= seq_len_kv - 1) as the decode position for + // k_block0 to match ReduceSplitK's computation. + const int k_block0 = CollectiveMainloop::LocalMask + ? cute::max(seq_len - 1 - params.mainloop.window_size_left, 0) / get<1>(TileShapeQK{}) + : 0; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + const int windowed_k_blocks = k_blocks - k_block0; + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_o = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + // for gqa packing, seq_len_qo must be 1 + seq_len_qo = 1; + } + + // neglect seq_len_qo since it's always 1 for decode + auto batch_dim = is_var_len ? 
1 : s.batch; + auto shape_Q = make_shape(head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim); + // shape + auto total_seqlen_kv = params.mainloop.total_seqlen_kv; + auto shape_K = make_shape(total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim); + auto shape_V = make_shape(s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim); + + auto shape_O = make_shape(head_group_q, s.head_size_vo, s.num_heads_kv, num_kv_splits, batch_dim); + auto shape_exp_sums = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_max_logits = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); + auto shape_sink = make_shape(s.num_heads_kv, head_group_q); + + int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits); + int kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split; + int num_effective_kv_blocks = + cute::min(windowed_k_blocks - idx_kv_split * num_blocks_per_split, num_blocks_per_split); + + if (num_effective_kv_blocks <= 0) { + // no need computation + continue; + } + + auto dcQ = const_cast(p.Q + offset_q); + auto dcK = const_cast(p.K); + auto dcV = const_cast(p.V); + auto ptrO = p.Oaccum + offset_o; + auto ptrExp_sums = p.exp_sums + offset_exp_sums; + auto ptrMax_logits = p.max_logits + offset_max_logits; + + auto layout_q = make_ordered_layout(shape_Q, Step<_1, _0, _2, _3>{}); + auto layout_k = make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{}); + auto layout_v = make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{}); + + auto layout_o = make_ordered_layout(shape_O, Step<_1, _0, _2, _3, _4>{}); + auto layout_exp_sums = make_ordered_layout(shape_exp_sums, Step<_1, _0, _2, _3>{}); + auto layout_max_logits = make_ordered_layout(shape_max_logits, Step<_1, _0, _2, _3>{}); + auto layout_sink = make_ordered_layout(shape_sink, Step<_1, _0>{}); + + Tensor Q = make_tensor(make_gmem_ptr(dcQ), layout_q); + Tensor K = make_tensor(make_gmem_ptr(dcK), layout_k); + Tensor V = make_tensor(make_gmem_ptr(dcV), layout_v); + Tensor O = make_tensor(make_gmem_ptr(ptrO), layout_o); + Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), layout_exp_sums); + Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), layout_max_logits); + Tensor sinks = make_tensor(make_gmem_ptr(const_cast(p.sm_sink)), layout_sink); + + // O accumulator types + FragA tArA; + FragARow tA_max, tA_sum; + + // Main loop + int l_coord = is_var_len ? 
0 : idx_b; + + int start_blk = kv_split_offset; + int end_blk = kv_split_offset + num_effective_kv_blocks; + + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + mainloop( + Q(_, _, head, l_coord), + K(_, _, head, l_coord), + V(_, _, head, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + idx_b, + start_blk, + end_blk, + k_blocks, + thr_id, + seq_len, + full_tile_offset, + discard_seq_coord); + + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + if constexpr (Sink) { + auto sinks_per_kv = sinks(head, _); + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks_per_kv, + num_kv_splits); + } else { + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks, + num_kv_splits); + } + } + } +}; + } // namespace cutlass::fmha::kernel From a732778cfa1410092904d8625592fab8c28b73b9 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Fri, 13 Mar 2026 20:28:36 +0800 Subject: [PATCH 08/23] const tensor for Q --- include/sgl_flash_kernel_ops.h | 2 +- src/sycl/flash_attention.cpp | 2 +- src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp | 2 +- src/torch_extension_sycl.cc | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index 75d9a18f..9b1f357f 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -43,7 +43,7 @@ limitations under the License. * From flash-attention */ std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 729f0d65..89553dc2 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -409,7 +409,7 @@ std::vector mha_fwd( } // namespace decode std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. 
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, diff --git a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp index 523f1d1a..52c37a6d 100644 --- a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp +++ b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp @@ -431,7 +431,7 @@ inline int round_up_headdim(int head_size) { } std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 4e561fef..4d52a1ab 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -93,7 +93,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { * From cutlass attention */ m.def( - "fwd(Tensor! q," + "fwd(Tensor q," " Tensor k," " Tensor v," " Tensor? q_v," From 151122d219874f675f9dfe2c9857eef760b7c9af Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Tue, 17 Mar 2026 13:04:13 +0800 Subject: [PATCH 09/23] add split kernel --- include/sgl_flash_kernel_ops.h | 4 +- python/sgl_kernel/flash_attn.py | 4 +- src/sycl/flash_attention.cpp | 8 +- .../chunk_prefill/chunk_prefill_runner.hpp | 6 +- .../collective/xe_fmha_fwd_epilogue.hpp | 323 ++++++++++++- .../collective/xe_fmha_fwd_mainloop.hpp | 450 ++++++++++++++++++ .../kernel/xe_fhma_fwd_kernel.hpp | 332 ------------- ...reduce_split_k.h => xe_reduce_split_k.hpp} | 1 - .../kernel/xe_tile_scheduler.hpp | 124 ++++- .../xe_fmha_fwd_decode_runner.hpp | 366 +++++++++++++- src/torch_extension_sycl.cc | 1 - tests/test_flash_attention.py | 6 +- 12 files changed, 1264 insertions(+), 361 deletions(-) rename src/sycl/kernels/flash_attention_v2/kernel/{xe_reduce_split_k.h => xe_reduce_split_k.hpp} (99%) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index 9b1f357f..a1b57703 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -43,7 +43,7 @@ limitations under the License. * From flash-attention */ std::vector mha_fwd( - const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. 
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -70,6 +70,6 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + // int num_kv_splits, std::optional pack_gqa_, int const sm_margin); diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index c645248d..a6aa4198 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -281,7 +281,7 @@ def flash_attn_with_kvcache( softcap, rotary_interleaved, scheduler_metadata, - num_splits, + # num_splits, pack_gqa, sm_margin, ) @@ -354,7 +354,7 @@ def flash_attn_varlen_func( softcap, False, # rotary_interleaved None, # scheduler_metadata - num_splits, + # num_splits, pack_gqa, sm_margin, ) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 89553dc2..e087df0e 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -409,7 +409,7 @@ std::vector mha_fwd( } // namespace decode std::vector mha_fwd( - const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -436,7 +436,7 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + // int num_kv_splits, std::optional pack_gqa_, int const sm_margin) { if (max_seqlen_q == 1 && page_table.has_value()) { @@ -466,7 +466,7 @@ std::vector mha_fwd( softcap, is_rotary_interleaved, scheduler_metadata_, - num_splits, + // num_kv_splits, pack_gqa_, sm_margin); } else { @@ -496,7 +496,7 @@ std::vector mha_fwd( softcap, is_rotary_interleaved, scheduler_metadata_, - num_splits, + // num_kv_splits, pack_gqa_, sm_margin); } diff --git a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp index 52c37a6d..191bb202 100644 --- a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp +++ b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp @@ -142,7 +142,7 @@ struct Flash_fwd_params { bool is_rotary_interleaved; - int num_splits; // For split-KV version + int num_kv_splits; // For split-KV version bool pack_gqa; int* __restrict__ tile_count_semaphore; @@ -431,7 +431,7 @@ inline int round_up_headdim(int head_size) { } std::vector mha_fwd( - const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. 
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -458,7 +458,7 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + // int num_kv_splits, std::optional pack_gqa_, int const sm_margin) { // TODO: check GPU support diff --git a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index a7dffec9..44681d39 100644 --- a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -139,7 +139,6 @@ class FMHAFwdEpilogue { FragARow& tA_sum, // Softmax row-wise sum accumulator QVCoord blk_qv, // WG tile indices: (q,v) int thr_id) { // Work-item ID - using namespace cute; using ElementA = typename FragA::element_type; @@ -282,4 +281,326 @@ class FMHAFwdEpilogue { } }; +template < + class CollectiveMainloop, // Attention mainloop + class TileShapeO_, // Shape of output tile, may be larger than P*V GEMM + class TensorO_, // 2D slice of global output tensor + class TensorLSE_ = void, // Optional tensor for storing intermediate exp + // sums and max logits + class TiledCopyO_ = void, // Optional TiledCopy for loading O + bool Sink_ = false> // Whether to sink softmax into epilogue +class DecodeFwdEpilogue { + public: + // + // Type Aliases + // + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + using TileShapeO = TileShapeO_; + using SGPerWG = decltype(product(take<1, 4>(shape(typename TiledMMAPV::ThrLayoutVMNK{})))); + + using TensorO = TensorO_; + using TensorO2D = decltype(TensorO_{}(append>(make_coord(_, _), 0))); + using ElementO = typename TensorO_::value_type; + + using TensorLSE = TensorLSE_; + using TensorLSE2D = conditional_t< + is_void_v, + void, + decltype(TensorLSE_{}(append>(make_coord(_, _), 0)))>; + using ElementLSE = conditional_t, void, typename TensorLSE_::value_type>; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + using ElementA = typename FragA::value_type; + + // softmax sink, same dtype + static constexpr bool Sink = Sink_; + using ElementSink = typename CollectiveMainloop::TensorQ::element_type; + + // Split k-reduced tiles between participating subgroups. + // Assumption: the A tile is contiguous. + using ReduceK = decltype(size<3>(typename TiledMMAPV::ThrLayoutVMNK{})); + + static auto reduce_sg_v_helper() { + constexpr auto v_total_sg = get<1>(SGTileShapeA{}) / intel::_SGSize{}; + constexpr auto v_avail_sg = ReduceK{} / ReduceSGQ{}; + return Int < (v_total_sg > v_avail_sg) ? 
cute::gcd(v_total_sg, v_avail_sg) : v_total_sg > {}; + } + + using SGTileShapeA = decltype(atuple_coshape(FragA{}.tv_layout())); + using ReduceSGQ = decltype(cute::gcd(get<0>(SGTileShapeA{}), ReduceK{})); + using ReduceSGV = decltype(reduce_sg_v_helper()); + using ReduceSGLayout = decltype(make_identity_layout(Shape{})); + + using SGTileShapeO = decltype(shape_div(take<0, 2>(SGTileShapeA{}), shape(ReduceSGLayout{}))); + + using ReduceFragA = + decltype(make_subgroup_tensor(make_layout(select<1, 0>(SGTileShapeO{}), Stride, E<0>>{}))); + using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + + static auto default_tiled_copy_O_helper() { + if constexpr (ReduceK{} == _1{}) + return make_block_2d_copy_D(TiledMMAPV{}, TensorO2D{}); + else + return make_block_2d_copy_D_subtiled(TiledMMAPV{}, ReduceFragA{}.tv_layout(), ReduceSGLayout{}, TensorO2D{}); + } + + using DefaultTiledCopyO = decltype(default_tiled_copy_O_helper()); + using TiledCopyO = conditional_t, DefaultTiledCopyO, TiledCopyO_>; + + // Stateless design -- no arguments or parameters. + struct Arguments {}; + struct Params {}; + + // Shared memory storage + // Note sum/max tiles are padded to 16 elements, due to limitations in CuTe + // block load infrastructure. + using AlignedSGTileA_Q = C<((size<0>(SGTileShapeA{}) + intel::sg_size - 1) / intel::sg_size) * intel::sg_size>; + + struct SharedStorageNone {}; + struct SharedStorageReduceK { + cute::array a_data; + cute::array a_sum_data, a_max_data; + }; + + using SharedStorage = conditional_t<(ReduceK{} > _1{}), SharedStorageReduceK, SharedStorageNone>; + + private: + SharedStorage& shared; + + public: + static constexpr Params to_underlying_arguments(Arguments const& args, void* /* workspace */) { + return {}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + CUTLASS_HOST_DEVICE + DecodeFwdEpilogue(Params const&, SharedStorage& shared_) : shared(shared_) {} + + template + CUTLASS_DEVICE void operator()( + TensorO2D const& O, // Global O tensor: (q,v) + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id) { // Work-item ID + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_max_unused, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. 
*/ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // splitK version + template + CUTLASS_DEVICE void operator()( + TensorO2D const& O, // Global O tensor: (q,v) + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id, // Work-item ID + const TensorLSE2D& exp_sums, // Global exp sum tensor + const TensorLSE2D& max_logits, // Global max logits tensor + int idx_kv_split, + int head_group_q, + TensorSink& tSink, // Sink for current head + int num_kv_splits) { + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + int sg_id = thr_id / intel::sg_size; + if constexpr (Sink) { + constexpr double kLog2e = 1.4426950408889634074; + if (idx_kv_split == 0 && sg_id == 0 && thr_id < head_group_q) { + tA_sum(0) += sycl::native::exp2(static_cast(tSink(thr_id) * kLog2e) - tA_max(0)); + } + } + + auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + // store exp sum and max logits for current KV split + // assume seq_len_qo == 1 + if (thr_id < head_group_q && num_kv_splits > 1) { + exp_sums(thr_id, idx_kv_split) = rA_sum(0); + max_logits(thr_id, idx_kv_split) = rA_max(0); + } + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) { + rA_sum(i) = ElementA(1) / rA_sum(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) { + rA(i) *= broadcast<0>(rA_sum, rA, i); + } + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // Reduce k-blocks of A and A_sum across WG, if needed. + // Note that each k block has its own scale factor based on A_max, + // so A/A_sum contributions need to be rescaled to match. + template + CUTLASS_DEVICE decltype(auto) reduce_A( + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + int thr_id) { // Work-item ID + + using namespace sycl::ext::oneapi::this_work_item; + + if constexpr (ReduceK{} == _1{}) { + ReduceFragARow rA_max; + return std::make_tuple(tArA, rA_max, tA_sum, true); + } else { + /* Identify A tile ID and k block for this subgroup. 
*/ + auto thr_vak = group<1, 3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); + auto a_tile = get<1>(thr_vak); + auto k_blk = get<2>(thr_vak); + + /* Set up SLM tensors and partition A tiles among participating subgroups + */ + auto shape_A = append(append(SGTileShapeA{}, ReduceK{}), SGPerWG{} / ReduceK{}); + auto shape_A_row = make_shape(get<0>(SGTileShapeO{}), shape(ReduceSGLayout{}), ReduceK{}, SGPerWG{} / ReduceK{}); + + auto sA_layout = group<2, 4>(flat_divide(make_ordered_layout(shape_A, Step<_1, _0, _2, _3>{}), SGTileShapeO{})); + auto sA_row_stride = + make_stride(_1{}, make_stride(get<0>(shape_A_row), _0{}), AlignedSGTileA_Q{}, AlignedSGTileA_Q{} * ReduceK{}); + auto sA_row_layout = make_layout(shape_A_row, sA_row_stride); + + auto basis2 = make_basis_like(SGTileShapeO{}); + auto sA_coords = make_layout( + append(SGTileShapeO{}, shape(ReduceSGLayout{})), append(basis2, product_each(zip(SGTileShapeO{}, basis2)))); + + auto sA = make_tensor(make_smem_ptr(&shared.a_data), + sA_layout); // (q,v,rblk_dst,rblk_src,a_tile) + auto sA_max = make_tensor( + make_smem_ptr(&shared.a_max_data), + sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + auto sA_sum = make_tensor( + make_smem_ptr(&shared.a_sum_data), + sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + + /* Write my contributions to SLM. */ + copy_block_r2s(tA_max, sA_max(_, _, k_blk, a_tile)); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + copy_block_r2s(tA_sum, sA_sum(_, _, k_blk, a_tile)); + copy_block_r2s(tArA, sA(_, _, _, k_blk, a_tile), sA_coords); + + bool active = (k_blk < size(ReduceSGLayout{})) || (ReduceK{} == size(ReduceSGLayout{})); // help compiler out + + /* Wait for maxima to be available, signal other data available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + + ReduceFragA rA; + ReduceFragARow rA_sum, rA_max, rA_kmax[ReduceK{}]; + + if (active) { + /* Read A_max back from SLM and reduce. */ + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + copy_block_s2r(sA_max(_, k_blk, kr, a_tile), rA_kmax[kr]); + } + + rA_max = rA_kmax[0]; + for (int kr = 1; kr < ReduceK{}; kr++) + cute::transform(rA_max, rA_kmax[kr], rA_max, cute::max_fn{}); + + /* Calculate scale factors for aligning per-block maxima. */ + for (int kr = 0; kr < ReduceK{}; kr++) { + cute::transform( + rA_max, rA_kmax[kr], rA_kmax[kr], [](auto gmax, auto kmax) { return sycl::native::exp2(kmax - gmax); }); + } + } + + /* Wait for A/A_sum data to be available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + + if (active) { + /* Read A/A_sum back from SLM, align scaling to new maxima, and reduce. 
+ */ + clear(rA_sum); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragARow rA_sum_read; + copy_block_s2r(sA_sum(_, k_blk, kr, a_tile), rA_sum_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum_read.size(); i++) { + rA_sum(i) += rA_sum_read(i) * rA_kmax[kr](i); + } + } + + clear(rA); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragA rA_read; + copy_block_s2r(sA(_, _, k_blk, kr, a_tile), sA_coords(_, _, 0), rA_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_read.size(); i++) { + rA(i) += rA_read(i) * broadcast<0>(rA_kmax[kr], rA, i); + } + } + } + return std::make_tuple(rA, rA_max, rA_sum, active); + } + } +}; + } // namespace cutlass::fmha::collective diff --git a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index 94ddddaf..daf85a7c 100644 --- a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -546,6 +546,456 @@ struct FMHAFwdMainloop< } }; +template < + class DispatchPolicy_, + bool PagedKV_, + bool CausalMask_, + class TiledMMAQK_, // Tiling for Q*K GEMM + class TiledMMAPV_, // Tiling for P*V GEMM + int VTiles_, // # of tiles in V dimension + class TensorQ_, // Global Q/K/V tensors + class TensorK_, + class TensorV_, + class TiledCopyQ_ = void, // Optional TiledCopy for loading Q + class TiledCopyK_ = void, // Optional TiledCopy for loading K + class TiledCopyV_ = void, // Optional TiledCopy for loading V + bool LocalMask_ = false> +struct DecodeFwdMainloop { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +template < + int Stages, + bool PagedKV_, + bool CausalMask_, + class TiledMMAQK_, + class TiledMMAPV_, + int VTiles_, + class TensorQ_, + class TensorK_, + class TensorV_, + class TiledCopyQ_, + class TiledCopyK_, + class TiledCopyV_, + bool LocalMask_> +struct DecodeFwdMainloop< + XeDefault, + PagedKV_, + CausalMask_, + TiledMMAQK_, + TiledMMAPV_, + VTiles_, + TensorQ_, + TensorK_, + TensorV_, + TiledCopyQ_, + TiledCopyK_, + TiledCopyV_, + LocalMask_> { + // + // Type Aliases + // + using TiledMMAQK = TiledMMAQK_; + using TiledMMAPV = TiledMMAPV_; + using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + static constexpr int VTiles = VTiles_; + using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk()); + using SGPerWG = decltype(product(take<1, 4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); + + using TensorQ = TensorQ_; + using TensorK = TensorK_; + using TensorV = TensorV_; + + using ElementQ = typename TensorQ::engine_type::value_type; + using ElementK = typename TensorK::engine_type::value_type; + + using TensorQ2D = decltype(TensorQ_{}(append>(make_coord(_, _), 0))); + using TensorK2D = decltype(TensorK_{}(append>(make_coord(_, _), 0))); + using TensorV2D = decltype(TensorV_{}(append>(make_coord(_, _), 0))); + + using TiledCopyQ = + conditional_t, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>; + using TiledCopyK = + conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>; + using TiledCopyV = + conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>; + + // TODO: static_asserts on TiledMMAPV here... 
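// Note on the copy defaults above: when TiledCopyQ_/K_/V_ are left as void,
// block-2D copies are derived directly from the tiled MMAs and the 2D global
// tensor views (make_block_2d_copy_A/B), so a custom TiledCopy only needs to
// be supplied when that default global-to-register load path is unsuitable.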
+ + // + // Accumulator types + // + // FragS: accumulator for Q*K MMA + // FragO: accumulator for P*V MMAs. + // Note: v mode may be split into multiple pieces + // to reduce register pressure. + // Frag*Row types are reductions of the corresponding Frag* types + // over rows. + // + template + using FragC = decltype(TiledMMA{}.get_slice(0).partition_sg_fragment_C( + make_identity_tensor(select<0, 1>(TiledMMA{}.tile_mnk())))); + + using FragS = FragC; + using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus{})); + using FragSCol = decltype(reduce<0>(FragS{}, sycl::plus{})); + using ElementS = typename TiledMMAQK::ValTypeD; + + using SingleFragA = FragC; // (atom val,q',v') + using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) + using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype + // mismatched"); + using ElementA = typename TiledMMAPV::ValTypeD; + + static constexpr bool PagedKV = PagedKV_; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool Fp8KV = is_any_of_v; + static constexpr bool LocalMask = LocalMask_; + + // User-facing arguments + struct Arguments { + ElementS const scale; + void* const scale_k; + void* const scale_v; + // Paged KV Cache + int const* ptr_page_table; + int page_size; + int max_pages_per_seq; + int total_seqlen_kv; + // Local Mask + int window_size_left; + int window_size_right; + }; + + // Kernel-facing parameters + using Params = Arguments; + + // SLM data + struct SharedStorage {}; + + Params params; + + // + // Methods + // + + DecodeFwdMainloop(Params const& params_, SharedStorage&) : params(params_) {} + + static constexpr Params to_underlying_arguments(Arguments const& args, void* /* workspace */) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) + ElementS val = args.scale * static_cast(kLog2e); + return Params{ + val, + args.scale_k, + args.scale_v, + args.ptr_page_table, + args.page_size, + args.max_pages_per_seq, + args.total_seqlen_kv, + args.window_size_left, + args.window_size_right}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + template + CUTLASS_DEVICE void operator()( + TensorQ2D const& Q_2D, // (q,d) + TensorK2D const& K_2D, // (k,d) + TensorV2D const& V_2D, // (d,k) + FragA& tArA, // Output accumulator (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (Q,V) + int const& idx_b, // WG tile indices: (B) + int blk_k0, // K block range: [K0,K1) + int blk_k1, + int total_blk, // Total # of K blocks + int thr_id, + int seq_len, + int full_tile_offset, + int discard_seq_coord) { + using namespace sycl::ext::oneapi::this_work_item; + + // Short dimension names: + // q = sequence len dimension for Q + // k = sequence len dimension for K + // d = head size dimension for K/Q + // v = head size dimension for V + // VV = MMA tile indices for V + // Capital letters (Q, K, ...) refer to WG block indices. + // Primed letters (q', k', ...) refer to atom block indices. 
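+ // Per K-block flow implemented below: load a K tile and compute the Q*K product (GEMM 1),
+ // apply the optional local-window and remainder-k masks, take one online-softmax step
+ // (the softmax scale was pre-multiplied by log2(e) in to_underlying_arguments, so the
+ // running maxima are stored pre-scaled and sycl::native::exp2 replaces exp), then
+ // accumulate A += P*V one V tile at a time (GEMM 2). With PagedKV, the logical K-block
+ // index is first remapped through the page table; e.g. (hypothetical numbers) with
+ // page_size = 128 and a 64-wide K tile, tiles_per_page = 2 and logical tile 5 maps to
+ // physical tile page_table[b_offset + 2] * 2 + 1.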
+ + auto tile_shape_v = make_shape(get<1>(TileShapePV{}) * C{}, get<2>(TileShapePV{})); + + /* Create proxy coordinate tensors for Q/K/P/V */ + Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d) + Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) + Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) + Tensor cP = make_identity_tensor(take<0, 2>(TileShapeQK{})); // (q,k) + + /* Partition global tensors into workgroup tiles */ + Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv, _), Step<_1, X, _1>{}); // (q,d,D) + Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_, _, _), Step{}); // (k,d,K,D) + Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv), _)); // (v,k,K) + Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_, _, 0), Step{}); // (v,k,VV,K) + + /* Create global -> register copies */ + TiledCopyQ copy_q{Q_2D}; + TiledCopyK copy_k{K_2D}; + TiledCopyV copy_v{V_2D}; + + /* Create MMAs */ + TiledMMAQK mma_qk{}; + TiledMMAPV mma_pv{}; + + auto copyQ = make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{}); + + /* Slice TiledCopy/TiledMMA operations down to to work-item level */ + auto thr_copy_q = copy_q.get_slice(thr_id); + auto thr_copy_k = copy_k.get_slice(thr_id); + auto thr_copy_v = copy_v.get_slice(thr_id); + auto thr_mma_qk = mma_qk.get_slice(thr_id); + auto thr_mma_pv = mma_pv.get_slice(thr_id); + + /* Partition coordinate tensors for copy */ + auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D) + auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) + auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + + /* Create register fragments for MMA and copies */ + auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_, _, 0)); + auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_, _, 0)); + + auto tKrK = thr_copy_k.partition_sg_fragment_D(gK(_, _, 0, 0)); + auto tSrK = thr_mma_qk.partition_sg_fragment_B(gK(_, _, 0, 0)); + + auto tSrS = thr_mma_qk.partition_sg_fragment_C(cP); + auto tArP = thr_mma_pv.partition_sg_fragment_A(cP); + + auto tVrV = thr_copy_v.partition_sg_fragment_D(gV_split(_, _, 0, 0)); + auto tArV = thr_mma_pv.partition_sg_fragment_B(gV_split(_, _, 0, 0)); + + /* Create TiledCopy objects for prefetches */ + auto prefetch_q = make_block_2d_prefetch(copy_q); + auto prefetch_k = make_block_2d_prefetch(copy_k); + auto prefetch_v = make_block_2d_prefetch(tile_shape_v, V_2D); + + /* Partition global tensors for prefetch */ + auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); + auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); + auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV); + + // ------ + // Kernel + // ------ + + // PagedKV + int tiles_per_page = params.page_size / get<1>(TileShapeQK{}); + int tile_idx = blk_k0; + int b_offset = idx_b * params.max_pages_per_seq; + if constexpr (PagedKV) { + int page_local_idx = tile_idx * get<1>(TileShapeQK{}) / params.page_size; + tile_idx = params.ptr_page_table[b_offset + page_local_idx] * tiles_per_page + tile_idx % tiles_per_page; + } + + /* Initialization steps for first block: Q/K prefetch, O init */ + /* TODO: limit D prefetch for large head size, and reorder K prefetches */ + for (int D = 0; D < size<3>(pQgQ); D++) { + prefetch(prefetch_q, pQgQ(_, _, _, D)); + } + + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, tile_idx, D)); + } + + clear(tArA); + fill(tA_max, cutlass::platform::numeric_limits::lowest()); + clear(tA_sum); + + /* Check if */ + bool check_remainder_k = (seq_len % 
get<1>(TileShapeQK{}) != 0); + + // FP8 KV Scale: Currently we only support per-tensor scale for KV + float scale_k = 1.f, scale_v = 1.f; + if constexpr (Fp8KV) { + scale_k = *static_cast(params.scale_k); + scale_v = *static_cast(params.scale_v); + } + + /* Main loop, blocked in k. */ + int next_tile_idx; + for (int K = blk_k0; K < blk_k1; K++) { + /* Split barrier to keep threads together */ + // barrier_arrive(ScopeWorkgroup); + + auto tKgK_cache = PagedKV ? tKgK(_, _, _, tile_idx, _) : tKgK(_, _, _, K, _); + auto tVgV_cache = PagedKV ? tVgV(_, _, _, _, tile_idx) : tVgV(_, _, _, _, K); + + /* GEMM 1: S = K * Q */ + clear(tSrS); /* TODO: fuse w/ initial gemm call */ + for (int D = 0; D < size<4>(tKgK); D++) { + copy(copy_q, tQgQ(_, _, _, D), tQrQ); + copy(copy_k, tKgK_cache(_, _, _, D), tKrK); + + reorder(tQrQ, tSrQ); + reorder(tKrK, tSrK); + if constexpr (Fp8KV) { + for (int i = 0; i < tSrK.size(); ++i) { + tSrK(i) = static_cast(scale_k * static_cast(tSrK(i))); + } + } + + cute::gemm(mma_qk, tSrQ, tSrK, tSrS); + } + /* V prefetch for GEMM 2 */ + prefetch(prefetch_v, pVgV(_, _, _, tile_idx)); + + /* Causal masking */ + // No Causal masking in decoding + // if constexpr (CausalMask) { + // if (K == blk_k1 - 1) { + // // Need to get global col and row indices to mask the elements + // Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + // Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), + // make_coord(get<0>(blk_qv), K)); auto cS_thread = + // thr_mma_qk.partition_C(gP); CUTLASS_PRAGMA_UNROLL for (int i = 0; i + // < tSrS.size(); ++i) { + // int row_idx = get<0>(cS_thread(i)); + // int col_idx = get<1>(cS_thread(i)); + // if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + // tSrS(i) = ElementS(-INFINITY); + // } + // } + // } + // } + + /* Local/sliding window masking */ + if constexpr (LocalMask) { + // For decode, all packed GQA heads share the same KV position + // (seq_len_kv - 1). Use a fixed decode row for all elements. + int decode_row = seq_len - 1 - full_tile_offset; + Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + Tensor gP = local_tile(cPgP, take<0, 2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K)); + auto cS_thread = thr_mma_qk.partition_C(gP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); ++i) { + int col_idx = get<1>(cS_thread(i)) - full_tile_offset; + bool left_mask = col_idx < decode_row - params.window_size_left; + bool right_mask = col_idx > decode_row + params.window_size_right; + if (left_mask || right_mask) { + tSrS(i) = ElementS(-INFINITY); + } + } + } + + /* k masking for remainder tiles */ + if (check_remainder_k && K == blk_k1 - 1) { + FragSCol k_rem_mask; + int k = get<0>(tKgK(0, 0, 0, K, 0)) + get_sub_group().get_local_id()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { + k_rem_mask(i) = (k < seq_len) ? 
ElementS(sycl::nan(0u)) : ElementS(-INFINITY); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); i++) { + tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i)); + } + } + + /* Apply softmax and scaling */ + softmax(K == 0, tSrS, tA_max, tA_sum, tArA); + reorder(tSrS, tArP); + + /* GEMM 2: A += P * V, split in v dimension */ + CUTLASS_PRAGMA_UNROLL + for (int VV = 0; VV < VTiles; VV++) { + copy(copy_v, tVgV_cache(_, _, _, VV), tVrV); + reorder(tVrV, tArV); + if constexpr (Fp8KV) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tArV.size(); ++i) { + tArV(i) = static_cast(scale_v * static_cast(tArV(i))); + } + } + cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); + } + + // barrier(); + + // next tile_idx + next_tile_idx = K + 1; + if constexpr (PagedKV) { + int next_page_local_idx = next_tile_idx * get<1>(TileShapeQK{}) / params.page_size; + if (next_page_local_idx < params.max_pages_per_seq) { + next_tile_idx = + params.ptr_page_table[b_offset + next_page_local_idx] * tiles_per_page + next_tile_idx % tiles_per_page; + } else { + // set to last page + next_tile_idx = params.max_pages_per_seq * tiles_per_page - 1; + } + } + tile_idx = next_tile_idx; + + /* K prefetch */ + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, tile_idx, D)); + } + + // barrier_wait(ScopeWorkgroup); + } + } + + // Single step of blocked softmax. + CUTLASS_DEVICE + void softmax( + bool first_block, // First softmax block? + FragS& tS, // Softmax src/dst block + FragSRow& tS_max, // Softmax row-wise max accumulator + FragSRow& tS_sum, // Softmax row-wise sum accumulator + FragA& tA) { // O accumulator (for rescaling) + + /* Compute row-wise maxima for this block */ + auto tS_bmax = reduce<1>(tS, sycl::maximum{}); + + /* Update (scaled) maxima */ + auto tS_prev_max = tS_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + tS_max(i) = sycl::max(tS_max(i), params.scale * tS_bmax(i)); + } + + /* Scale S and subtract maxima, then exponentiate */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS.size(); i++) + tS(i) = sycl::native::exp2(params.scale * tS(i) - broadcast<0>(tS_max, tS, i)); + + /* Rescale existing S sums and O accumulator */ + if (!first_block) { + FragSRow rescale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + rescale(i) = sycl::native::exp2(tS_prev_max(i) - tS_max(i)); + tS_sum(i) *= rescale(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tA.size(); i++) + tA(i) *= broadcast<0>(rescale, tA, i); + } + + /* Update sums */ + auto tS_bsum = reduce<1>(tS, sycl::plus{}); + for (int i = 0; i < tS_sum.size(); i++) + tS_sum(i) += tS_bsum(i); + } +}; + template CUTLASS_HOST_DEVICE constexpr auto get_sg_layout_pv(SGLayoutQK const&) { return make_layout(get<0>(SGLayoutQK{}), Layout<_1, _0>{}, get<1>(SGLayoutQK{})); diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index c83d3d19..2763a7e4 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -1096,336 +1096,4 @@ class XeFMHAFwdSplitKVKernel { } } }; - -template -class XeFMHAFwdSplitKVKernel { - public: - // - // Type Aliases - // - using ProblemShape = ProblemShape_; - using VariableLength = cutlass::fmha::collective::VariableLength; - static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; - using CollectiveMainloop = CollectiveMainloop_; - 
using MainloopArguments = typename CollectiveMainloop::Arguments; - using MainloopParams = typename CollectiveMainloop::Params; - - using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; - using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; - using TileShapeQK = typename CollectiveMainloop::TileShapeQK; - using TileShapePV = typename CollectiveMainloop::TileShapePV; - using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; - using ElementQ = typename CollectiveMainloop::TensorQ::element_type; - using ElementK = typename CollectiveMainloop::TensorK::element_type; - using ElementV = typename CollectiveMainloop::TensorV::element_type; - - using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); - using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); - using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); - - using SGPerWG = typename CollectiveMainloop::SGPerWG; - - using FragA = typename CollectiveMainloop::FragA; - using FragARow = typename CollectiveMainloop::FragARow; - - // Tile scheduler derived types - using TileScheduler = TileScheduler_; - using TileSchedulerParams = typename TileScheduler::Params; - - // Epilogue derived types - using CollectiveEpilogue = CollectiveEpilogue_; - using EpilogueArguments = typename CollectiveEpilogue::Arguments; - using EpilogueParams = typename CollectiveEpilogue::Params; - - using TileShapeO = typename CollectiveEpilogue::TileShapeO; - using ElementO = typename CollectiveEpilogue::TensorO::element_type; - using ElementLSE = typename CollectiveEpilogue::ElementLSE; - using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); - - // Kernel level shared memory storage - using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; - using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; - union SharedStorage { - MainloopSharedStorage mainloop; - EpilogueSharedStorage epilogue; - }; - - static constexpr int SharedStorageSize = is_empty_v ? 
size_t(0) : sizeof(SharedStorage); - - static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; - static constexpr int dpas_max_repeat_count = 8; - static constexpr bool Sink = CollectiveEpilogue::Sink; - using ElementSink = typename CollectiveEpilogue::ElementSink; - - // Device side arguments - struct KernelArguments { - ProblemShape shape; - const ElementQ* Q; - StrideQ dQ; - const ElementK* K; - StrideK dK; - const ElementV* V; - StrideV dV; - ElementO* Oaccum; - StrideO dOaccum; - ElementLSE* exp_sums; - StrideO dExp_sums; - ElementLSE* max_logits; - StrideO dMax_logits; - - const ElementSink* sm_sink; - }; - using KernelParams = KernelArguments; - - struct Arguments { - KernelArguments kernel{}; - MainloopArguments mainloop{}; - EpilogueArguments epilogue{}; - KernelHardwareInfo hw_info{}; - int num_kv_splits = -1; // no split by default - }; - - // Kernel entry point API - struct Params { - KernelParams kernel; - MainloopParams mainloop; - EpilogueParams epilogue; - TileSchedulerParams scheduler; - }; - - // - // Methods - // - - static Params to_underlying_arguments(Arguments const& args, void* workspace) { - return { - args.kernel, - CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), - TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; - } - - static bool can_implement(Arguments const& args) { - if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { - // decode only - return false; - } - - if (args.num_kv_splits > max_num_kv_splits) { - return false; - } - - return CollectiveMainloop::can_implement(args.mainloop) && CollectiveEpilogue::can_implement(args.epilogue); - } - - static int get_workspace_size(Arguments const& args) { - return 0; - } - - static cutlass::Status initialize_workspace( - Arguments const& args, - void* workspace = nullptr, - cudaStream_t stream = nullptr, - CudaHostAdapter* cuda_adapter = nullptr) { - return Status::kSuccess; - } - - static dim3 get_grid_shape(Params const& params) { - return TileScheduler::template get_grid_shape(params.scheduler); - } - - static dim3 get_block_shape() { - return dim3(SGPerWG::value * intel::sg_size, 1, 1); - } - - CUTLASS_DEVICE - Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { - if constexpr (is_var_len) { - auto q_len = - cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch); - return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]}; - } else { - return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; - } - } - - CUTLASS_DEVICE - void operator()(Params const& params, char* smem_buf) { - using namespace sycl::ext::oneapi::this_work_item; - - SharedStorage& shared_storage = *reinterpret_cast(smem_buf); - - auto& p = params.kernel; - ProblemShape const& s = p.shape; - int head_group_q = s.num_heads_q / s.num_heads_kv; - - int thr_id = int(ThreadIdxX()); - int sub_group_id = thr_id / intel::sg_size; - int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); - - auto cS = make_identity_tensor(take<0, 2>(TiledMMAQK{}.tile_mnk())); - auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); - auto q_offset_wi = get<0>(tScS(0)); - auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); - - TileScheduler tile_scheduler{params.scheduler}; - auto num_kv_splits = params.scheduler.num_kv_splits_; 
- - CUTLASS_PRAGMA_NO_UNROLL - for (; tile_scheduler.is_valid(); ++tile_scheduler) { - auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split) - auto blk_qv = make_coord(blk_q, blk_v); - int head_q_start = head * head_group_q; - - auto sequence_length_shape = get_sequence_length_shape(s, idx_b); - auto [seq_len_qo, seq_len_kv] = sequence_length_shape; - if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; - - auto offset = cute::min(seq_len_qo, seq_len_kv); - auto discard_seq_coord = seq_len_qo - offset; - auto full_tile_offset = seq_len_kv - offset; - int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); - - if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; - // For decode window_size_right doesn't have effect - const int seq_len = seq_len_kv; - // For decode, all packed GQA heads are at position seq_len_kv - 1. - // Use seq_len - 1 (= seq_len_kv - 1) as the decode position for - // k_block0 to match ReduceSplitK's computation. - const int k_block0 = CollectiveMainloop::LocalMask - ? cute::max(seq_len - 1 - params.mainloop.window_size_left, 0) / get<1>(TileShapeQK{}) - : 0; - const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); - const int windowed_k_blocks = k_blocks - k_block0; - - int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; - int offset_exp_sums = 0, offset_max_logits = 0; - if constexpr (is_var_len) { - auto qo_cumulative = s.seq_len_qo.cumulative_length; - - offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; - offset_o = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; - offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; - offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; - - // for gqa packing, seq_len_qo must be 1 - seq_len_qo = 1; - } - - // neglect seq_len_qo since it's always 1 for decode - auto batch_dim = is_var_len ? 
1 : s.batch; - auto shape_Q = make_shape(head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim); - // shape - auto total_seqlen_kv = params.mainloop.total_seqlen_kv; - auto shape_K = make_shape(total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim); - auto shape_V = make_shape(s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim); - - auto shape_O = make_shape(head_group_q, s.head_size_vo, s.num_heads_kv, num_kv_splits, batch_dim); - auto shape_exp_sums = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); - auto shape_max_logits = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim); - auto shape_sink = make_shape(s.num_heads_kv, head_group_q); - - int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits); - int kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split; - int num_effective_kv_blocks = - cute::min(windowed_k_blocks - idx_kv_split * num_blocks_per_split, num_blocks_per_split); - - if (num_effective_kv_blocks <= 0) { - // no need computation - continue; - } - - auto dcQ = const_cast(p.Q + offset_q); - auto dcK = const_cast(p.K); - auto dcV = const_cast(p.V); - auto ptrO = p.Oaccum + offset_o; - auto ptrExp_sums = p.exp_sums + offset_exp_sums; - auto ptrMax_logits = p.max_logits + offset_max_logits; - - auto layout_q = make_ordered_layout(shape_Q, Step<_1, _0, _2, _3>{}); - auto layout_k = make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{}); - auto layout_v = make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{}); - - auto layout_o = make_ordered_layout(shape_O, Step<_1, _0, _2, _3, _4>{}); - auto layout_exp_sums = make_ordered_layout(shape_exp_sums, Step<_1, _0, _2, _3>{}); - auto layout_max_logits = make_ordered_layout(shape_max_logits, Step<_1, _0, _2, _3>{}); - auto layout_sink = make_ordered_layout(shape_sink, Step<_1, _0>{}); - - Tensor Q = make_tensor(make_gmem_ptr(dcQ), layout_q); - Tensor K = make_tensor(make_gmem_ptr(dcK), layout_k); - Tensor V = make_tensor(make_gmem_ptr(dcV), layout_v); - Tensor O = make_tensor(make_gmem_ptr(ptrO), layout_o); - Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), layout_exp_sums); - Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), layout_max_logits); - Tensor sinks = make_tensor(make_gmem_ptr(const_cast(p.sm_sink)), layout_sink); - - // O accumulator types - FragA tArA; - FragARow tA_max, tA_sum; - - // Main loop - int l_coord = is_var_len ? 
0 : idx_b; - - int start_blk = kv_split_offset; - int end_blk = kv_split_offset + num_effective_kv_blocks; - - CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); - - mainloop( - Q(_, _, head, l_coord), - K(_, _, head, l_coord), - V(_, _, head, l_coord), - tArA, - tA_max, - tA_sum, - blk_qv, - idx_b, - start_blk, - end_blk, - k_blocks, - thr_id, - seq_len, - full_tile_offset, - discard_seq_coord); - - if constexpr (!is_empty_v && !is_empty_v) { - sycl::group_barrier(get_work_group<3>()); - } - - // Epilogue - CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; - if constexpr (Sink) { - auto sinks_per_kv = sinks(head, _); - epilogue( - O(_, _, head, idx_kv_split, l_coord), - tArA, - tA_max, - tA_sum, - blk_qv, - thr_id, - exp_sums(_, _, head, l_coord), - max_logits(_, _, head, l_coord), - idx_kv_split, - head_group_q, - sinks_per_kv, - num_kv_splits); - } else { - epilogue( - O(_, _, head, idx_kv_split, l_coord), - tArA, - tA_max, - tA_sum, - blk_qv, - thr_id, - exp_sums(_, _, head, l_coord), - max_logits(_, _, head, l_coord), - idx_kv_split, - head_group_q, - sinks, - num_kv_splits); - } - } - } -}; - } // namespace cutlass::fmha::kernel diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp similarity index 99% rename from src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h rename to src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp index 94452a2e..7b8ec38f 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.h +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp @@ -240,7 +240,6 @@ class ReduceSplitK { Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums)); Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits)); - int l_coord = is_var_len ? 
0 : idx_b; // Step 1: reduce max logits across different partitions diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp index edef3695..3a901d70 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -56,9 +56,9 @@ struct XeFHMAIndividualTileScheduler { using namespace cute; dim3 grid( - size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V - size(ceil_div((int)shape.seq_len_qo, get<0>(tile_shape))), // Q - size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later return Params{grid, {shape.num_heads_q}}; } @@ -157,4 +157,122 @@ struct XeFHMAIndividualPersistentTileScheduler { } }; +struct DecodeTileScheduler { + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + FastDivmod divmod_batch; + int num_kv_splits_ = -1; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + DecodeTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, + KernelHardwareInfo hw_info, + TileShape const& tile_shape, + const int& num_kv_splits = -1) { + using namespace cute; + + dim3 grid( + size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + int num_head = shape.num_heads_q; + if (num_kv_splits >= 1) { + // for splitKV, each wg handles group query heads + grid.z = size(shape.batch * shape.num_heads_kv); + grid.z *= num_kv_splits; + num_head = shape.num_heads_kv; + } + return Params{grid, {num_head}, {shape.batch * num_head}, num_kv_splits}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int idx_kv_split = BlockIdxZ(); + int head, idx_b; + + if (params.num_kv_splits_ >= 1) { + params.divmod_batch(idx_kv_split, idx_b, idx_kv_split); + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, idx_kv_split); + } + + idx_b = idx_kv_split; + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, (int)-1); + } + + CUTLASS_DEVICE + DecodeTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +struct XeReduceSplitKTileScheduler { + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + int num_kv_splits = -1; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, + KernelHardwareInfo hw_info, + TileShape const& tile_shape, + const int& num_kv_splits = -1) { + using namespace cute; + + dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); + return Params{grid, {shape.num_heads_q}, num_kv_splits}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + + return 
make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; } // namespace cutlass::fmha::kernel diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index 4c9cae96..25c6acbd 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -45,8 +45,8 @@ #include "sycl/comm/common.h" #include "sycl/kernels/flash_attention_v2/collective/fmha_fusion.hpp" #include "sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" +#include "sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp" #include "sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp" - using namespace cute; namespace decode { struct Arguments { @@ -57,6 +57,12 @@ struct Arguments { void* __restrict__ k_ptr; void* __restrict__ v_ptr; + void* __restrict__ k_scale_ptr = nullptr; + void* __restrict__ v_scale_ptr = nullptr; + + void* __restrict__ temp_out_ptr = nullptr; + void* __restrict__ exp_sums_ptr = nullptr; + void* __restrict__ max_logits_ptr = nullptr; // The stride between rows of the Q, K and V matrices. index_t q_batch_stride; index_t k_batch_stride; @@ -72,6 +78,7 @@ struct Arguments { // The number of heads. int h, h_k; int q_group_size = 1; + bool use_split_kv_decode = false; // The O matrix (output). void* __restrict__ o_ptr; @@ -144,7 +151,7 @@ struct Arguments { // The indices to index into the KV cache. int* __restrict__ kv_batch_idx; - // Paged KV cache + // PagedKV KV cache int* __restrict__ page_table; int max_num_pages_per_seq; index_t page_table_batch_stride; @@ -161,11 +168,13 @@ struct Arguments { // Scale factor of 1 / (1 - p_dropout). float rp_dropout; - // Local window size - int window_size_left, window_size_right; + // LocalMask window size + int window_size_left = -1; + int window_size_right = -1; // Pointer to the RNG seed (idx 0) and offset (idx 1). uint64_t* rng_state; + int num_kv_splits; // For split-KV version bool is_bf16; bool is_fp32; @@ -222,16 +231,21 @@ struct DecodeRunner { auto initialize_varlen(const Arguments& params, const ProblemShape& problem_size) { ProblemShape problem_size_for_init = problem_size; get<0>(problem_size_for_init) = 1; // concentrated batch - get<1>(problem_size_for_init) = params.h / params.q_group_size; - get<3>(problem_size_for_init) = params.total_q * params.q_group_size; + get<1>(problem_size_for_init) = params.use_split_kv_decode ? params.h : params.h_k; + get<3>(problem_size_for_init) = params.use_split_kv_decode ? params.total_q : params.total_q * params.q_group_size; get<4>(problem_size_for_init) = params.total_knew; get<5>(problem_size_for_init) = params.total_k; ProblemShapeType problem_size_for_launch{ .batch = get<0>(problem_size), - .num_heads_q = get<1>(problem_size) / params.q_group_size, + .num_heads_q = params.use_split_kv_decode ? get<1>(problem_size) : get<2>(problem_size), .num_heads_kv = get<2>(problem_size), - .seq_len_qo = {params.seqlen_q, params.total_q * params.q_group_size, nullptr, params.q_group_size}, + .seq_len_qo = + {params.use_split_kv_decode ? params.seqlen_q * params.q_group_size : params.seqlen_q, + params.use_split_kv_decode ? params.total_q : params.total_q * params.q_group_size, + nullptr, + params.use_split_kv_decode ? 
1 : params.q_group_size}, + .seq_len_kv = {params.seqlen_knew, params.total_knew}, .seq_len_kv_cache = {params.seqlen_k, params.total_k}, .head_size_qk = get<6>(problem_size), @@ -326,6 +340,243 @@ struct DecodeRunner { return cutlass::Status::kSuccess; } }; + +template +struct DecodeKernelLauncher { + using StrideQ = typename FMHAKernel::StrideQ; + using StrideK = typename FMHAKernel::StrideK; + using StrideV = typename FMHAKernel::StrideV; + using StrideO = typename FMHAKernel::StrideO; + + using ElementQ = typename FMHAKernel::ElementQ; + using ElementK = typename FMHAKernel::ElementK; + using ElementV = typename FMHAKernel::ElementV; + using ElementO = typename FMHAKernel::ElementO; + using ElementLSE = typename FMHAKernel::ElementLSE; + + using CollectiveMainloop = typename FMHAKernel::CollectiveMainloop; + using ElementS = typename CollectiveMainloop::ElementS; + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + using ProblemShapeTypeInit = cutlass::fmha::kernel::FMHAProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideO stride_Oaccum; + StrideO stride_exp_sums; + StrideO stride_max_logits; + + int num_kv_splits; + + ProblemShapeType initialize(const Arguments& params) { + ProblemShapeType shape; + ProblemShapeTypeInit shape_init; + auto batch = shape.batch = shape_init.batch = params.b; + auto num_heads_q = shape.num_heads_q = shape_init.num_heads_q = params.h; + auto num_heads_kv = shape.num_heads_kv = shape_init.num_heads_kv = params.h_k; + auto head_size_qk = shape.head_size_qk = shape_init.head_size_qk = params.d; + auto head_size_vo = shape.head_size_vo = shape_init.head_size_vo = params.d; + + if constexpr (isVarLen) { + batch = shape_init.batch = 1; + shape_init.seq_len_qo = params.total_q; + shape_init.seq_len_kv = params.total_k; + + shape.seq_len_qo = cutlass::fmha::collective::VariableLength{params.seqlen_q}; + shape.seq_len_qo.cumulative_length = reinterpret_cast(params.cu_seqlens_q); + shape.seq_len_kv = cutlass::fmha::collective::VariableLength{params.seqlen_k}; + shape.seq_len_kv.cumulative_length = reinterpret_cast(params.cu_seqlens_k); + } else { + shape.seq_len_qo = shape_init.seq_len_qo = params.seqlen_q; + shape.seq_len_kv = shape_init.seq_len_kv = params.seqlen_k; + } + + auto seq_len_qo = shape_init.seq_len_qo; + auto seq_len_kv = shape_init.seq_len_kv; + + num_kv_splits = params.num_kv_splits; + + stride_Q = + cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch)); + stride_K = + cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch)); + stride_V = + cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch)); + stride_O = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch)); + stride_Oaccum = cutlass::make_cute_packed_stride( + StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch)); + + stride_exp_sums = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + + stride_max_logits = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + + return shape; + } + + cutlass::Status run(const Arguments& params, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType shape = initialize(params); + + typename 
FMHAKernel::Arguments arguments{ + { + shape, + reinterpret_cast(params.q_ptr), + stride_Q, + reinterpret_cast(params.k_ptr), + stride_K, + reinterpret_cast(params.v_ptr), + stride_V, + reinterpret_cast(params.temp_out_ptr), + stride_Oaccum, + reinterpret_cast(params.exp_sums_ptr), + stride_exp_sums, + reinterpret_cast(params.max_logits_ptr), + stride_max_logits, + reinterpret_cast(params.softmax_sink_ptr), + }, + {params.softmax_scale, + params.k_scale_ptr, + params.v_scale_ptr, + static_cast(params.page_table), + params.page_size, + params.max_num_pages_per_seq, + params.total_k, + params.window_size_left, + params.window_size_right}, + {}, + hw_info, + params.num_kv_splits}; + + typename ReductionSplitKernel::Arguments reduce_arg{ + {shape, + reinterpret_cast(params.o_ptr), + stride_O, + reinterpret_cast(params.temp_out_ptr), + stride_Oaccum, + reinterpret_cast(params.exp_sums_ptr), + stride_exp_sums, + reinterpret_cast(params.max_logits_ptr), + stride_max_logits, + params.window_size_left}, + hw_info, + params.num_kv_splits}; + + // Define device-global scratch memory + size_t workspace_size = FMHAKernel::get_workspace_size(arguments); + size_t reduce_workspace_size = ReductionSplitKernel::get_workspace_size(reduce_arg); + cutlass::device_memory::allocation workspace(workspace_size + reduce_workspace_size); + + if (!FMHAKernel::can_implement(arguments)) { + // std::cout << "Invalid Problem Size: " << params.batch_size << 'x' + // << params.num_heads_q << 'x' << params.max_queries << 'x' + // << params.max_keys << 'x' << params.head_size << 'x' + // << params.head_size << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAKernel::initialize_workspace(arguments, workspace.get()); + + // Convert host-side arguments to device-side arguments to be passed to the + // kernel + auto kernel_params = FMHAKernel::to_underlying_arguments(arguments, workspace.get()); + auto reduce_params = ReductionSplitKernel::to_underlying_arguments(reduce_arg, workspace.get() + workspace_size); + + ReductionSplitKernel::initialize_workspace(reduce_arg, workspace.get() + workspace_size); + run(kernel_params, reduce_params, params.num_kv_splits > 1); + + return cutlass::Status::kSuccess; + } + + static void + run(typename FMHAKernel::Params params, typename ReductionSplitKernel::Params reduce_params, bool need_reduce) { + auto stream = at::xpu::getCurrentXPUStream(); + auto q = stream.queue(); + + namespace syclex = sycl::ext::oneapi::experimental; + namespace intelex = sycl::ext::intel::experimental; + + dim3 const block = FMHAKernel::get_block_shape(); + dim3 const grid = FMHAKernel::get_grid_shape(params); + + // cute::print("Launching FMHAKernel with grid: "); cute::print("%d x %d x + // %d ", grid.x, grid.y, grid.z); cute::print("and block: "); + // cute::print("%d x %d x %d\n", block.x, block.y, block.z); + + // configure smem size and carveout + int smem_size = FMHAKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + // Launch parameters depend on whether SYCL compiler supports work-group + // scratch memory extension + compat::experimental::launch_properties launch_props{ + syclex::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + syclex::sub_group_size, intelex::grf_size<256>}; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + + 
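+ // The submission below builds the kernel functor by hand (build_kernel_functor +
+ // cgh.parallel_for) instead of the commented-out compat::experimental::launch helper,
+ // while still applying the work-group scratch size (SLM) and the sub-group-size / GRF
+ // kernel properties configured above.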
sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties()); + auto cgf = [&](::sycl::handler& cgh) { + auto KernelFunctor = + compat::experimental::detail::build_kernel_functor>(cgh, policy, params); + sycl::ext::oneapi::experimental::detail:: + LaunchConfigAccess, decltype(policy.get_launch_properties())> + ConfigAccess(config); + cgh.parallel_for>(ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); + }; + + q.submit(cgf); + // auto event = + // compat::experimental::launch>( + // policy, queue, params); + // EventManager::getInstance().addEvent(event); + + // event.wait(); + + if (need_reduce) { + dim3 const reduce_grid = ReductionSplitKernel::get_grid_shape(reduce_params); + int reduce_smem_size = ReductionSplitKernel::SharedStorageSize; + const auto reduce_sycl_block = compat::dim3(block.x, block.y, block.z); + const auto reduce_sycl_grid = compat::dim3(reduce_grid.x, reduce_grid.y, reduce_grid.z); + compat::experimental::launch_properties launch_props_reduce{ + syclex::work_group_scratch_size(reduce_smem_size), + }; + compat::experimental::launch_policy reduce_policy{ + reduce_sycl_grid, reduce_sycl_block, launch_props_reduce, kernel_props}; + + sycl::ext::oneapi::experimental::launch_config reduce_config( + reduce_policy.get_range(), reduce_policy.get_launch_properties()); + auto cgf = [&](::sycl::handler& cgh) { + auto KernelFunctor = + compat::experimental::detail::build_kernel_functor>( + cgh, reduce_policy, reduce_params); + sycl::ext::oneapi::experimental::detail:: + LaunchConfigAccess, decltype(reduce_policy.get_launch_properties())> + ConfigAccess(reduce_config); + cgh.parallel_for>( + ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); + }; + q.submit(cgf); + + // auto reduce_event = compat::experimental::launch< + // cutlass::device_kernel>( + // reduce_policy, queue, reduce_params); + + // // reduce_event.wait(); + + // EventManager::getInstance().addEvent(reduce_event); + } + } +}; + template < bool Causal, bool LocalMask, @@ -350,7 +601,7 @@ template < typename GmemTiledCopyK = void, typename GmemTiledCopyV = void, typename GmemTiledCopyO = void> -struct FMHAConfig { +struct DecodeConfig { static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t< is_void_v, @@ -441,4 +692,101 @@ struct FMHAConfig { } } }; + +template < + bool Causal, + bool LocalMask, + bool Sink, + typename TileShapeQK, + typename TileShapePV, + typename TileShapeOutput, + typename SubgroupLayoutQK, + typename SubgroupLayoutPV_ = void /* void -> default */, + int PipelineStages = 1, + typename ElementQ = bfloat16_t, + typename ElementK = bfloat16_t, + typename ElementV = bfloat16_t, + typename ElementO = bfloat16_t, + typename MMAOperation_ = void, /* void -> default */ + typename StrideQ = Stride, + typename StrideK = Stride, + typename StrideV = Stride<_1, int, int, int>, + typename StrideO = Stride, + typename StrideOaccum = Stride, + typename GmemTiledCopyQ = void, /* void -> default block 2D */ + typename GmemTiledCopyK = void, + typename GmemTiledCopyV = void, + typename GmemTiledCopyO = void> +struct SplitDeodeConfig { + static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); + using MMAOperation = + cute::conditional_t, XE_DPAS_TT, MMAOperation_>; + using SubgroupLayoutPV = cute::conditional_t< + is_void_v, + decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), + 
SubgroupLayoutPV_>; + + template + static void run(const Arguments& params) { + // constexpr bool isVarLen = true; + // constexpr bool PagedKV = true; + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + + using TiledMMAQK = typename TiledMMAHelper, Layout, SubgroupLayoutQK>::TiledMMA; + using TiledMMAPV = typename TiledMMAHelper, Layout, SubgroupLayoutPV>::TiledMMA; + + static_assert( + get<0>(TileShapeOutput{}) == get<0>(TileShapePV{}), + "Output tile and P*V tile have different sizes in Q dimension"); + constexpr int VTiles = get<1>(TileShapeOutput{}) / get<1>(TileShapePV{}); + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); + }; + + using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); + using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); + using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); + using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideOaccum{})); + using TensorLSE = decltype(make_dummy_tensor(float{}, StrideO{})); + + // Mainloop + using MainloopDispatchPolicy = cutlass::fmha::XeDefault; + using CollectiveMainloop = cutlass::fmha::collective::DecodeFwdMainloop< + MainloopDispatchPolicy, + PagedKV, + Causal, + TiledMMAQK, + TiledMMAPV, + VTiles, + TensorQ, + TensorK, + TensorV, + GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV, + LocalMask>; + + // Epilogue + using CollectiveEpilogue = cutlass::fmha::collective:: + DecodeFwdEpilogue; + + using FMHAKernel = cutlass::fmha::kernel:: + XeFMHAFwdSplitKVKernel; + + using ReduceSplitKernel = cutlass::reduction::kernel:: + ReduceSplitK; + + DecodeKernelLauncher launcher; + + launcher.run(params, hw_info); + } + + static void kernel_dispatch(const Arguments& params) { + return run(params); + } +}; } // namespace decode diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 4d52a1ab..7d68caf2 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -118,7 +118,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " float softcap," " bool is_rotary_interleaved," " Tensor? scheduler_metadata," - " int num_splits," " bool? 
pack_gqa," " int sm_margin) -> Tensor[]"); m.impl("fwd", torch::kXPU, make_pytorch_shim(&mha_fwd)); diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 25acc6d2..7ec19d71 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1004,7 +1004,7 @@ def test_flash_attn_kvcache( @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [True]) -@pytest.mark.parametrize("d", [64, 128]) +@pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("seqlen_q", [1]) @pytest.mark.parametrize( "seqlen_k", @@ -1189,9 +1189,9 @@ def test_flash_attn_decode_kvcache( dtype_ref, ) cache_seqlens = torch.randint( - seqlen_q, - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough seqlen_k, + # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough + seqlen_k + 1, (batch_size,), dtype=torch.int32, device=device, From 2bb49aeabb4c0a4c3cb5913f68837a674b40e3f2 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Tue, 17 Mar 2026 13:36:34 +0800 Subject: [PATCH 10/23] save --- tests/test_flash_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 7ec19d71..7cc8178b 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1006,6 +1006,7 @@ def test_flash_attn_kvcache( @pytest.mark.parametrize("varlen_q", [True]) @pytest.mark.parametrize("d", [64]) @pytest.mark.parametrize("seqlen_q", [1]) +@pytest.mark.parametrize("batch_size", [1, 16, 32]) @pytest.mark.parametrize( "seqlen_k", [ @@ -1021,6 +1022,7 @@ def test_flash_attn_kvcache( ], ) def test_flash_attn_decode_kvcache( + batch_size, seqlen_q, seqlen_k, d, @@ -1051,7 +1053,6 @@ def test_flash_attn_decode_kvcache( pytest.skip() # set seed torch.random.manual_seed(0) - batch_size = 16 batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 nheads = 16 nheads_k = 4 # nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) From ea443ce4cafeb30487e2de5057f1a61ab52c3ed8 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Tue, 17 Mar 2026 13:39:53 +0800 Subject: [PATCH 11/23] cache_seqlens --- tests/test_flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 7cc8178b..5e7ce7ca 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1190,7 +1190,7 @@ def test_flash_attn_decode_kvcache( dtype_ref, ) cache_seqlens = torch.randint( - seqlen_k, + seqlen_q, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough seqlen_k + 1, (batch_size,), From e073c7a5dd7219e4a47adcb2c12d4f4e86e3b140 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Tue, 17 Mar 2026 16:57:18 +0800 Subject: [PATCH 12/23] head_dim =128 --- tests/test_flash_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 5e7ce7ca..563490f5 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -1004,7 +1004,7 @@ def test_flash_attn_kvcache( @pytest.mark.parametrize("has_leftpad", [False]) @pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("varlen_q", [True]) -@pytest.mark.parametrize("d", [64]) +@pytest.mark.parametrize("d", [128]) @pytest.mark.parametrize("seqlen_q", [1]) 
@pytest.mark.parametrize("batch_size", [1, 16, 32]) @pytest.mark.parametrize( From 294cdcff5478d19ed37771197e9e422017eed27c Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Wed, 18 Mar 2026 15:55:44 +0800 Subject: [PATCH 13/23] 2026 --- include/sgl_flash_kernel_ops.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h index a1b57703..d0a31d76 100644 --- a/include/sgl_flash_kernel_ops.h +++ b/include/sgl_flash_kernel_ops.h @@ -1,4 +1,4 @@ -/* Copyright 2025 SGLang Team. All Rights Reserved. +/* Copyright 2025-2026 SGLang Team. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. From b4ec20bc6661b6d8181009a5a4618b021b496e7f Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Mon, 23 Mar 2026 13:17:05 +0800 Subject: [PATCH 14/23] test for mingxu --- python/sgl_kernel/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py index a6aa4198..6da7c4db 100644 --- a/python/sgl_kernel/flash_attn.py +++ b/python/sgl_kernel/flash_attn.py @@ -254,7 +254,7 @@ def flash_attn_with_kvcache( if cache_seqlens is not None: assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0) cu_seqlens_k = cache_seqlens - max_seqlen_k = int(cache_seqlens.max().item()) + # max_seqlen_k = int(cache_seqlens.max().item()) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, k_cache, @@ -263,7 +263,7 @@ def flash_attn_with_kvcache( cu_seqlens_q, cu_seqlens_k, max_seqlen_q, - max_seqlen_k, + 1, page_table, cache_batch_idx, cache_leftpad, From 40ba59ff4132cf35e0668fcbcbe8b143e5c5e779 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 07:17:34 +0000 Subject: [PATCH 15/23] Initial plan From 64e9a551f3589e1b5779aef8bf177800cd203abd Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 07:30:42 +0000 Subject: [PATCH 16/23] Rebase onto main: integrate split-KV changes into flash_attention.cpp and fix SplitDeodeConfig typo --- src/sycl/flash_attention.cpp | 228 +++++++++--------- .../xe_fmha_fwd_decode_runner.hpp | 2 +- 2 files changed, 114 insertions(+), 116 deletions(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index e087df0e..a1741270 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -38,107 +38,12 @@ #include #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" -#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" namespace decode { -namespace { - -using launch_fn_t = void (*)(bool use_sink, const Arguments& params); - -#define LAUNCH_FN_ENTRY(QG, HD, PS) &launch_fmha_decode_##QG##_##HD##_##PS - -launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size) { - // Dispatch table indexed by (qg_sz, head_dim, page_size). 
- // qg_sz index: {1->0, 2->1, 4->2, 8->3, 16->4, 32->5} - // head_dim index: {64->0, 96->1, 128->2, 192->3} - // page_size index: {32->0, 64->1, 128->2} - -#define PAGE_ENTRIES(QG, HD) \ - { LAUNCH_FN_ENTRY(QG, HD, 32), LAUNCH_FN_ENTRY(QG, HD, 64), LAUNCH_FN_ENTRY(QG, HD, 128) } - -#define HD_ENTRIES(QG) \ - { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192) } - - static const launch_fn_t table[6][4][3] = { - HD_ENTRIES(1), - HD_ENTRIES(2), - HD_ENTRIES(4), - HD_ENTRIES(8), - HD_ENTRIES(16), - HD_ENTRIES(32), - }; - -#undef HD_ENTRIES -#undef PAGE_ENTRIES - - int qg_idx = -1; - switch (qg_sz) { - case 1: - qg_idx = 0; - break; - case 2: - qg_idx = 1; - break; - case 4: - qg_idx = 2; - break; - case 8: - qg_idx = 3; - break; - case 16: - qg_idx = 4; - break; - case 32: - qg_idx = 5; - break; - default: - return nullptr; - } - - int hd_idx = -1; - switch (head_dim) { - case 64: - hd_idx = 0; - break; - case 96: - hd_idx = 1; - break; - case 128: - hd_idx = 2; - break; - case 192: - hd_idx = 3; - break; - default: - return nullptr; - } - - int ps_idx = -1; - switch (page_size) { - case 32: - ps_idx = 0; - break; - case 64: - ps_idx = 1; - break; - case 128: - ps_idx = 2; - break; - default: - return nullptr; - } - - return table[qg_idx][hd_idx][ps_idx]; -} - -#undef LAUNCH_FN_ENTRY - -} // namespace - std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -165,7 +70,7 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + // int num_kv_splits, std::optional pack_gqa_, int const sm_margin) { auto q_type = q.scalar_type(); @@ -202,10 +107,9 @@ std::vector mha_fwd( int const max_num_pages_per_seq = page_table.value().size(1); int const num_pages = k.size(0); int const page_size = k.size(1); - int const seqlen_k = max_num_pages_per_seq * page_size; + int const seqlen_k = page_table.has_value() ? 
max_num_pages_per_seq * page_size : max_seqlen_k; int const total_k = num_pages * page_size; int const num_heads_k = k.size(-2); - int q_group_size = num_heads / num_heads_k; int const batch_size_k = page_table.value().size(0); float softmax_scale = softmax_scale_; @@ -248,8 +152,43 @@ std::vector mha_fwd( auto opts = q.options(); at::Tensor out; + at::Tensor temp_out; // [batch, num_kv_splits, num_head_q, seq_q, head_size] + at::Tensor exp_sums; // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor max_logits; // [batch, num_head_q, seq_q, num_kv_splits] + int num_kv_splits = 1; out = torch::empty({total_q, num_heads, head_size_v}, opts); - + Arguments params; + params.use_split_kv_decode = true; + if (params.use_split_kv_decode) { + auto get_num_splits = [](int batch_size, int num_heads_kv, int max_seqlen_k, int block_size) { + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + auto device = queue.get_device(); + int num_xe_cores = device.get_info() * + device.get_info(); + int parallel_ = num_xe_cores; + int parallel_2 = num_xe_cores * 2; + int cur_parallel_d = batch_size * num_heads_kv; + int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d; + if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) { + num_splits = std::ceil(parallel_2 / static_cast(cur_parallel_d)) - 1; + } + + int max_splits = (max_seqlen_k + block_size - 1) / block_size; + max_splits = std::min(max_splits, parallel_); + return std::min(num_splits, max_splits); + }; + num_kv_splits = get_num_splits(batch_size, num_heads_k, seqlen_k, page_size); + temp_out = num_kv_splits == 1 + ? out + : torch::empty({total_q, num_kv_splits * num_heads, head_size_v}, q.options().device(q.device())); + + exp_sums = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); + max_logits = torch::empty({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); + params.temp_out_ptr = temp_out.data_ptr(); + params.exp_sums_ptr = exp_sums.data_ptr(); + params.max_logits_ptr = max_logits.data_ptr(); + } int const head_size_rounded = round_up_headdim(head_size); int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); @@ -261,7 +200,7 @@ std::vector mha_fwd( softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); // align with FA3 - Arguments params; + params.is_bf16 = q.dtype() == torch::kBFloat16; // Set the pointers and strides. 
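The temp_out, exp_sums and max_logits buffers allocated above hold the per-split partial outputs, row exponential sums and row maxima that ReduceSplitK later merges into the final output. As a rough illustration of the get_num_splits heuristic, the sketch below walks through it with made-up device and shape values (64 Xe-cores, batch 1, 4 KV heads, seqlen_k 1024, page_size 64); the helper name exists only for this example and is not part of the patch:

#include <algorithm>

// Hypothetical inputs: 64 Xe-cores, batch_size = 1, num_heads_kv = 4,
// seqlen_k = 1024, page_size = 64.
int example_num_kv_splits() {
  int parallel = 64;                                                  // Xe-core count
  int cur_parallel_d = 1 * 4;                                         // batch_size * num_heads_kv
  int num_splits = (parallel + cur_parallel_d - 1) / cur_parallel_d;  // 16
  // 4 * 16 == 64 is not greater than 64, so the parallel_2 fallback branch is skipped.
  int max_splits = std::min((1024 + 64 - 1) / 64, parallel);          // min(16, 64) = 16
  return std::min(num_splits, max_splits);                            // -> 16 KV splits
}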
@@ -282,6 +221,7 @@ std::vector mha_fwd( params.cu_seqlens_q = cu_seqlens_q.data_ptr(); params.cu_seqlens_k = cu_seqlens_k.data_ptr(); + params.num_kv_splits = num_kv_splits; // Softmax sum params.softmax_lse_ptr = softmax_lse.data_ptr(); @@ -291,7 +231,7 @@ std::vector mha_fwd( params.h = num_heads; params.h_k = num_heads_k; params.q_group_size = num_heads / num_heads_k; - params.seqlen_q = seqlen_q * q_group_size; + params.seqlen_q = seqlen_q; params.seqlen_k = seqlen_k; params.d = head_size; params.d_rounded = head_size_rounded; @@ -388,21 +328,79 @@ std::vector mha_fwd( at::Tensor out_accum, softmax_lse_accum; auto outaccum_type = at::ScalarType::Float; - int qg_sz = nextPowerOf2(max_seqlen_q); - TORCH_CHECK(qg_sz >= 1 && qg_sz <= 32, "Unsupported qgroup_size for decode attention: ", max_seqlen_q); - TORCH_CHECK( - params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192, - "Unsupported head size for decode attention: ", - params.d); - TORCH_CHECK( - params.page_size == 32 || params.page_size == 64 || params.page_size == 128, - "Unsupported page size for decode attention: ", - params.page_size); + constexpr bool Causal = false; // The decode kernel does not support causal mode. It must be set to false. + + auto launch_kernel = [&](auto _QG_SZ, auto _HEAD_DIM, auto _PAGE_SIZE, auto _NUM_SG) { + using TileShapeQK = cute::Shape; + using TileShapePV = cute::Shape; + using TileShapeOutput = cute::Shape; + using SubgroupLayoutQK = cute::Layout>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + SplitDecodeConfig:: + kernel_dispatch(params); + }); + }); + }; - auto fn = get_launch_fn(qg_sz, params.d, params.page_size); - TORCH_CHECK(fn != nullptr, "No FMHA decode kernel for qg=", qg_sz, " hd=", params.d, " ps=", params.page_size); - fn(use_sink, params); + auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { + switch (params.page_size) { + case 32: + launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); + break; + case 64: + launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); + break; + case 128: + launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{}); + break; + default: + TORCH_CHECK(false, "Unsupported page size for decode attention: ", params.page_size); + } + }; + auto dispatch_q_group = [&](auto _HEAD_DIM) { + switch (nextPowerOf2(params.q_group_size)) { + case 1: + dispatch_page_size(_1{}, _HEAD_DIM); + break; + case 2: + dispatch_page_size(_2{}, _HEAD_DIM); + break; + case 4: + dispatch_page_size(_4{}, _HEAD_DIM); + break; + case 8: + dispatch_page_size(_8{}, _HEAD_DIM); + break; + case 16: + dispatch_page_size(_16{}, _HEAD_DIM); + break; + default: + TORCH_CHECK(false, "Unsupported q_group_size for decode attention: ", params.q_group_size); + } + }; + + switch (params.d) { + case 64: + dispatch_q_group(_64{}); + break; + case 96: + dispatch_q_group(_96{}); + break; + case 128: + dispatch_q_group(_128{}); + break; + case 192: + dispatch_q_group(_192{}); + break; + case 256: + dispatch_q_group(_256{}); + break; + default: + TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); + } return {out, softmax_lse, out_accum, softmax_lse_accum}; } diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index 25c6acbd..afc1ee3b 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -717,7 +717,7 @@ template < typename 
GmemTiledCopyK = void, typename GmemTiledCopyV = void, typename GmemTiledCopyO = void> -struct SplitDeodeConfig { +struct SplitDecodeConfig { static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t, XE_DPAS_TT, MMAOperation_>; From 8ce31707abcd896004e9db554a17bc2ae1888b29 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 23 Mar 2026 09:13:09 +0000 Subject: [PATCH 17/23] Rebase onto updated split_kv_decode: fix FMHAConfig undefined, add conditional dispatch between DecodeConfig and SplitDecodeConfig, comment out page_size 32 Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/82c107d1-3f61-4ce0-9444-d7f19f27a292 --- src/sycl/flash_attention.cpp | 15 ++++++++++----- src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index a1741270..607eea6b 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -338,17 +338,22 @@ std::vector mha_fwd( AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - SplitDecodeConfig:: - kernel_dispatch(params); + if (params.use_split_kv_decode) { + SplitDecodeConfig:: + kernel_dispatch(params); + } else { + DecodeConfig::run( + params); + } }); }); }; auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { switch (params.page_size) { - case 32: - launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); - break; + // case 32: + // launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); + // break; case 64: launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); break; diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index c2ac6ba2..a667e1da 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -48,7 +48,7 @@ void launch_fmha_decode_@QG_SZ@_@HEAD_DIM@_@PAGE_SIZE@(bool use_sink, const Argu AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - FMHAConfig::run(params); + DecodeConfig::run(params); }); }); } From 6f57253b8a99de279f163dea4a48c92590863880 Mon Sep 17 00:00:00 2001 From: "jiwei1.sun" Date: Tue, 24 Mar 2026 10:50:57 +0800 Subject: [PATCH 18/23] bugfix --- src/FMHADecodeXe20.cmake | 6 +++--- src/sycl/flash_attention.cpp | 10 +++++----- .../xe_fmha_fwd_decode_dispatch.hpp | 11 +++++------ src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 3 ++- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/FMHADecodeXe20.cmake b/src/FMHADecodeXe20.cmake index fe996ee5..0c466946 100644 --- a/src/FMHADecodeXe20.cmake +++ b/src/FMHADecodeXe20.cmake @@ -2,9 +2,9 @@ # Each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination is compiled as a separate # library to parallelize and speed up compilation. 
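Each generated xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp builds its CuTe tile shapes from those three constants. As an illustration only - the alias spellings below are an assumption inferred from the kernel template, not copied from it - the QG_SZ=8, HEAD_DIM=128, PAGE_SIZE=64 combination would correspond to roughly:

#include <cute/tensor.hpp>

// Assumed per-kernel tile aliases for QG_SZ=8, HEAD_DIM=128, PAGE_SIZE=64.
using TileShapeQK      = cute::Shape<cute::Int<8>, cute::Int<64>, cute::_64>;
using TileShapePV      = cute::Shape<cute::Int<8>, cute::_32, cute::Int<64>>;
using TileShapeOutput  = cute::Shape<cute::Int<8>, cute::Int<128>>;
using SubgroupLayoutQK = cute::Layout<cute::Shape<cute::Int<64 / 16>, cute::_1>>;  // 4 subgroups for a 64-entry page

The 64 and 128 page sizes map to 4 and 8 subgroups respectively, i.e. the PAGE_SIZE / 16 relation that the earlier CMake NUM_SG computation encoded.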
-set(FMHA_DECODE_QG_SIZES 1 2 4 8 16 32) -set(FMHA_DECODE_HEAD_DIMS 64 96 128 192) -set(FMHA_DECODE_PAGE_SIZES 32 64 128) +set(FMHA_DECODE_QG_SIZES 1 2 4 8 16) +set(FMHA_DECODE_HEAD_DIMS 64 96 128 192 256) +set(FMHA_DECODE_PAGE_SIZES 64 128) set(FMHA_DECODE_TEMPLATE "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in") diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index e2c8f2ab..4d59e760 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -338,13 +338,13 @@ std::vector mha_fwd( AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - if (params.use_split_kv_decode) { + // if (params.use_split_kv_decode) { SplitDecodeConfig::run( params); - } else { - DecodeConfig::run( - params); - } + // } else { + // DecodeConfig::run( + // params); + // } }); }); }; diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp index 113b287b..95c91ba8 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -41,15 +41,14 @@ struct Arguments; // // Naming: launch_fmha_decode___ // Parameters: -// QG_SZ in {1, 2, 4, 8, 16, 32} -// HEAD_DIM in {64, 96, 128, 192} -// PAGE_SIZE in {32, 64, 128} (with NUM_SG = PAGE_SIZE / 16) +// QG_SZ in {1, 2, 4, 8, 16} +// HEAD_DIM in {64, 96, 128, 192, 256} +// PAGE_SIZE in {64, 128} (with NUM_SG = PAGE_SIZE / 16) #define DECLARE_LAUNCH_FMHA_DECODE(QG, HD, PS) \ void launch_fmha_decode_##QG##_##HD##_##PS(bool use_sink, const Arguments& params); #define DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 32) \ DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 64) \ DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 128) @@ -58,13 +57,13 @@ struct Arguments; DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(32, HD) + DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(64) DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(96) DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(128) DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(192) +DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(256) #undef DECLARE_LAUNCH_FMHA_DECODE #undef DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index a667e1da..bf1885a3 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -48,7 +48,8 @@ void launch_fmha_decode_@QG_SZ@_@HEAD_DIM@_@PAGE_SIZE@(bool use_sink, const Argu AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - DecodeConfig::run(params); + SplitDecodeConfig::run( + params); }); }); } From 1614a4c10db67d1c002fcf8b59f06e806bd21e94 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Mar 2026 03:09:31 +0000 Subject: [PATCH 19/23] Refactor dispatch to function pointer tables following GroupGemmXe20 pattern Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/5c058f18-d408-4752-aba6-2625a6e5e638 --- src/FMHADecodeXe20.cmake | 10 +- src/sycl/flash_attention.cpp | 203 
+++++++++++------- .../xe_fmha_fwd_decode_dispatch.hpp | 94 +++++--- .../xe_fmha_fwd_decode_runner.hpp | 36 ++++ src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 19 +- .../xe_fmha_fwd_split_decode_kernel.cpp.in | 42 ++++ 6 files changed, 274 insertions(+), 130 deletions(-) create mode 100644 src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in diff --git a/src/FMHADecodeXe20.cmake b/src/FMHADecodeXe20.cmake index 0c466946..f8e0dd8f 100644 --- a/src/FMHADecodeXe20.cmake +++ b/src/FMHADecodeXe20.cmake @@ -9,15 +9,21 @@ set(FMHA_DECODE_PAGE_SIZES 64 128) set(FMHA_DECODE_TEMPLATE "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in") +set(FMHA_SPLIT_DECODE_TEMPLATE + "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in") + foreach(QG_SZ ${FMHA_DECODE_QG_SIZES}) foreach(HEAD_DIM ${FMHA_DECODE_HEAD_DIMS}) foreach(PAGE_SIZE ${FMHA_DECODE_PAGE_SIZES}) - math(EXPR NUM_SG "${PAGE_SIZE} / 16") - set(GENERATED_FILE "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp") configure_file(${FMHA_DECODE_TEMPLATE} ${GENERATED_FILE} @ONLY) list(APPEND device_cpp_common ${GENERATED_FILE}) + + set(GENERATED_SPLIT_FILE + "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_split_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp") + configure_file(${FMHA_SPLIT_DECODE_TEMPLATE} ${GENERATED_SPLIT_FILE} @ONLY) + list(APPEND device_cpp_common ${GENERATED_SPLIT_FILE}) endforeach() endforeach() endforeach() diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 4d59e760..f34bb6b4 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -35,13 +35,122 @@ #include #include -#include - #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" +#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" namespace decode { +namespace { + +using launch_fn_t = void (*)(bool use_sink, const Arguments& params); + +#define LAUNCH_FN_ENTRY(QG, HD, PS) &launch_fmha_decode +#define LAUNCH_SPLIT_FN_ENTRY(QG, HD, PS) &launch_fmha_split_decode + +launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size, bool use_split) { + // Dispatch tables indexed by (qg_sz, head_dim, page_size). 
+ // qg_sz index: {1->0, 2->1, 4->2, 8->3, 16->4} + // head_dim index: {64->0, 96->1, 128->2, 192->3, 256->4} + // page_size index: {64->0, 128->1} + +#define PAGE_ENTRIES(QG, HD) \ + { LAUNCH_FN_ENTRY(QG, HD, 64), LAUNCH_FN_ENTRY(QG, HD, 128) } + +#define HD_ENTRIES(QG) \ + { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192), PAGE_ENTRIES(QG, 256) } + + static const launch_fn_t decode_table[5][5][2] = { + HD_ENTRIES(1), + HD_ENTRIES(2), + HD_ENTRIES(4), + HD_ENTRIES(8), + HD_ENTRIES(16), + }; + +#undef HD_ENTRIES +#undef PAGE_ENTRIES + +#define PAGE_ENTRIES(QG, HD) \ + { LAUNCH_SPLIT_FN_ENTRY(QG, HD, 64), LAUNCH_SPLIT_FN_ENTRY(QG, HD, 128) } + +#define HD_ENTRIES(QG) \ + { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192), PAGE_ENTRIES(QG, 256) } + + static const launch_fn_t split_decode_table[5][5][2] = { + HD_ENTRIES(1), + HD_ENTRIES(2), + HD_ENTRIES(4), + HD_ENTRIES(8), + HD_ENTRIES(16), + }; + +#undef HD_ENTRIES +#undef PAGE_ENTRIES + + int qg_idx = -1; + switch (qg_sz) { + case 1: + qg_idx = 0; + break; + case 2: + qg_idx = 1; + break; + case 4: + qg_idx = 2; + break; + case 8: + qg_idx = 3; + break; + case 16: + qg_idx = 4; + break; + default: + return nullptr; + } + + int hd_idx = -1; + switch (head_dim) { + case 64: + hd_idx = 0; + break; + case 96: + hd_idx = 1; + break; + case 128: + hd_idx = 2; + break; + case 192: + hd_idx = 3; + break; + case 256: + hd_idx = 4; + break; + default: + return nullptr; + } + + int ps_idx = -1; + switch (page_size) { + case 64: + ps_idx = 0; + break; + case 128: + ps_idx = 1; + break; + default: + return nullptr; + } + + const auto& table = use_split ? split_decode_table : decode_table; + return table[qg_idx][hd_idx][ps_idx]; +} + +#undef LAUNCH_FN_ENTRY +#undef LAUNCH_SPLIT_FN_ENTRY + +} // namespace + std::vector mha_fwd( const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, @@ -326,86 +435,22 @@ std::vector mha_fwd( params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - - constexpr bool Causal = false; // The decode kernel does not support causal mode. It must be set to false. 
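The replacement dispatch in this patch buckets params.q_group_size with nextPowerOf2 before selecting a kernel (see the following hunk). That helper is not defined anywhere in this series; a minimal sketch of its assumed behavior - round a positive integer up to the next power of two - for reference:

#include <cstdint>

// Assumed semantics: 1 -> 1, 3 -> 4, 8 -> 8, 9 -> 16 (valid for v >= 1).
inline int nextPowerOf2(int v) {
  uint32_t x = static_cast<uint32_t>(v) - 1u;
  x |= x >> 1;
  x |= x >> 2;
  x |= x >> 4;
  x |= x >> 8;
  x |= x >> 16;
  return static_cast<int>(x + 1u);
}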
- - auto launch_kernel = [&](auto _QG_SZ, auto _HEAD_DIM, auto _PAGE_SIZE, auto _NUM_SG) { - using TileShapeQK = cute::Shape; - using TileShapePV = cute::Shape; - using TileShapeOutput = cute::Shape; - using SubgroupLayoutQK = cute::Layout>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - // if (params.use_split_kv_decode) { - SplitDecodeConfig::run( - params); - // } else { - // DecodeConfig::run( - // params); - // } - }); - }); - }; - auto dispatch_page_size = [&](auto _QG_SZ, auto _HEAD_DIM) { - switch (params.page_size) { - // case 32: - // launch_kernel(_QG_SZ, _HEAD_DIM, _32{}, _2{}); - // break; - case 64: - launch_kernel(_QG_SZ, _HEAD_DIM, _64{}, _4{}); - break; - case 128: - launch_kernel(_QG_SZ, _HEAD_DIM, _128{}, _8{}); - break; - default: - TORCH_CHECK(false, "Unsupported page size for decode attention: ", params.page_size); - } - }; + int qg_sz = nextPowerOf2(params.q_group_size); + TORCH_CHECK(qg_sz >= 1 && qg_sz <= 16, "Unsupported q_group_size for decode attention: ", params.q_group_size); + TORCH_CHECK( + params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256, + "Unsupported head size for decode attention: ", + params.d); + TORCH_CHECK( + params.page_size == 64 || params.page_size == 128, + "Unsupported page size for decode attention: ", + params.page_size); - auto dispatch_q_group = [&](auto _HEAD_DIM) { - switch (nextPowerOf2(params.q_group_size)) { - case 1: - dispatch_page_size(_1{}, _HEAD_DIM); - break; - case 2: - dispatch_page_size(_2{}, _HEAD_DIM); - break; - case 4: - dispatch_page_size(_4{}, _HEAD_DIM); - break; - case 8: - dispatch_page_size(_8{}, _HEAD_DIM); - break; - case 16: - dispatch_page_size(_16{}, _HEAD_DIM); - break; - default: - TORCH_CHECK(false, "Unsupported q_group_size for decode attention: ", params.q_group_size); - } - }; + auto fn = get_launch_fn(qg_sz, params.d, params.page_size, params.use_split_kv_decode); + TORCH_CHECK(fn != nullptr, "No FMHA decode kernel for qg=", qg_sz, " hd=", params.d, " ps=", params.page_size); + fn(use_sink, params); - switch (params.d) { - case 64: - dispatch_q_group(_64{}); - break; - case 96: - dispatch_q_group(_96{}); - break; - case 128: - dispatch_q_group(_128{}); - break; - case 192: - dispatch_q_group(_192{}); - break; - case 256: - dispatch_q_group(_256{}); - break; - default: - TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); - } return {out, softmax_lse, out_accum, softmax_lse_accum}; } diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp index 95c91ba8..9920896e 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -35,38 +35,68 @@ namespace decode { struct Arguments; -// Declarations for generated FMHA decode kernel launch functions. -// Each function is defined in a separate generated .cpp file from -// xe_fmha_fwd_decode_kernel.cpp.in, compiled as its own library. +// Template function declarations for FMHA decode kernel launchers. +// Each template is explicitly instantiated in a separate generated .cpp file +// (from xe_fmha_fwd_decode_kernel.cpp.in / xe_fmha_fwd_split_decode_kernel.cpp.in). 
// -// Naming: launch_fmha_decode___ -// Parameters: -// QG_SZ in {1, 2, 4, 8, 16} -// HEAD_DIM in {64, 96, 128, 192, 256} -// PAGE_SIZE in {64, 128} (with NUM_SG = PAGE_SIZE / 16) - -#define DECLARE_LAUNCH_FMHA_DECODE(QG, HD, PS) \ - void launch_fmha_decode_##QG##_##HD##_##PS(bool use_sink, const Arguments& params); - -#define DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 64) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 128) - -#define DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(1, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) - -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(64) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(96) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(128) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(192) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(256) - -#undef DECLARE_LAUNCH_FMHA_DECODE -#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES -#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_QG +// QG_SZ in {1, 2, 4, 8, 16} +// HEAD_DIM in {64, 96, 128, 192, 256} +// PAGE_SIZE in {64, 128} + +template +void launch_fmha_decode(bool use_sink, const Arguments& params); + +template +void launch_fmha_split_decode(bool use_sink, const Arguments& params); + +// Explicit instantiation declarations — tell the compiler these are compiled +// in separate translation units (generated from the .cpp.in templates). + +#define EXTERN_LAUNCH_FMHA_DECODE(QG, HD, PS) \ + extern template void launch_fmha_decode(bool, const Arguments&); + +#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, PS) \ + extern template void launch_fmha_split_decode(bool, const Arguments&); + +#define EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_LAUNCH_FMHA_DECODE(QG, HD, 64) \ + EXTERN_LAUNCH_FMHA_DECODE(QG, HD, 128) + +#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, 64) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, 128) + +#define EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(HD) \ + EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(1, HD) \ + EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ + EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ + EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ + EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) + +#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(1, HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(2, HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(4, HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(8, HD) \ + EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(16, HD) + +EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(64) +EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(96) +EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(128) +EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(192) +EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(256) + +EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(64) +EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(96) +EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(128) +EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(192) +EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(256) + +#undef EXTERN_LAUNCH_FMHA_DECODE +#undef EXTERN_LAUNCH_FMHA_SPLIT_DECODE +#undef EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES +#undef EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES +#undef EXTERN_LAUNCH_FMHA_DECODE_ALL_QG +#undef EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG } // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp 
b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index 9d03ba0f..aa3bd568 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -789,4 +789,40 @@ struct SplitDecodeConfig { return run(params); } }; + +// Free-function templates for use with the function-pointer dispatch table. +// Each template is explicitly instantiated in a generated .cpp file so the +// compiler only emits code for the combinations that are actually needed. + +template +void launch_fmha_decode(bool use_sink, const Arguments& params) { + constexpr bool Causal = false; + using TileShapeQK = cute::Shape, cute::Int, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int>; + using TileShapeOutput = cute::Shape, cute::Int>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + DecodeConfig::run(params); + }); + }); +} + +template +void launch_fmha_split_decode(bool use_sink, const Arguments& params) { + constexpr bool Causal = false; + using TileShapeQK = cute::Shape, cute::Int, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int>; + using TileShapeOutput = cute::Shape, cute::Int>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + SplitDecodeConfig::run( + params); + }); + }); +} + } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index bf1885a3..259a152e 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -30,28 +30,13 @@ * **************************************************************************************************/ // Auto-generated from xe_fmha_fwd_decode_kernel.cpp.in -// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@, NUM_SG=@NUM_SG@ +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@ #define SYCL_INTEL_TARGET 20 #include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" namespace decode { -void launch_fmha_decode_@QG_SZ@_@HEAD_DIM@_@PAGE_SIZE@(bool use_sink, const Arguments& params) { - using namespace cute; - - constexpr bool Causal = false; - using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; - using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; - using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; - using SubgroupLayoutQK = cute::Layout, cute::_1>>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - SplitDecodeConfig::run( - params); - }); - }); -} +template void launch_fmha_decode<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>(bool, const Arguments&); } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in new file mode 100644 index 00000000..22e57582 --- /dev/null +++ b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// Auto-generated from xe_fmha_fwd_split_decode_kernel.cpp.in +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@ +#define SYCL_INTEL_TARGET 20 + +#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" + +namespace decode { + +template void launch_fmha_split_decode<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>(bool, const Arguments&); + +} // namespace decode From 25d22d1665278bec1cce7dce5bca68c8dc5562b7 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Mar 2026 03:19:56 +0000 Subject: [PATCH 20/23] Refactor decode dispatch to struct operator() following GroupGemmXe20 pattern Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/871c6171-1314-44d4-b4f4-007ddffc6ed4 --- src/sycl/flash_attention.cpp | 5 +- .../xe_fmha_fwd_decode_dispatch.hpp | 109 +++++++++--------- .../xe_fmha_fwd_decode_runner.hpp | 47 +++----- src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 17 ++- .../xe_fmha_fwd_split_decode_kernel.cpp.in | 18 ++- 5 files changed, 104 insertions(+), 92 deletions(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index f34bb6b4..eef878b7 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -37,7 +37,6 @@ #include "kernels/chunk_prefill/chunk_prefill_runner.hpp" #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp" -#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" namespace decode { @@ -45,8 +44,8 @@ namespace { using launch_fn_t = void (*)(bool use_sink, const Arguments& params); -#define LAUNCH_FN_ENTRY(QG, HD, PS) &launch_fmha_decode -#define LAUNCH_SPLIT_FN_ENTRY(QG, HD, PS) &launch_fmha_split_decode +#define LAUNCH_FN_ENTRY(QG, HD, PS) &FmhaDecodeRunner::call +#define 
LAUNCH_SPLIT_FN_ENTRY(QG, HD, PS) &FmhaSplitDecodeRunner::call launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size, bool use_split) { // Dispatch tables indexed by (qg_sz, head_dim, page_size). diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp index 9920896e..19d7f3cc 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -31,72 +31,67 @@ **************************************************************************************************/ #pragma once -namespace decode { +#include "xe_fmha_fwd_decode_runner.hpp" -struct Arguments; +namespace decode { -// Template function declarations for FMHA decode kernel launchers. -// Each template is explicitly instantiated in a separate generated .cpp file -// (from xe_fmha_fwd_decode_kernel.cpp.in / xe_fmha_fwd_split_decode_kernel.cpp.in). +// Struct functor declarations for FMHA decode kernel launchers. +// Each template specialization is explicitly instantiated in a separate +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in). // // QG_SZ in {1, 2, 4, 8, 16} // HEAD_DIM in {64, 96, 128, 192, 256} // PAGE_SIZE in {64, 128} -template -void launch_fmha_decode(bool use_sink, const Arguments& params); - -template -void launch_fmha_split_decode(bool use_sink, const Arguments& params); - // Explicit instantiation declarations — tell the compiler these are compiled // in separate translation units (generated from the .cpp.in templates). -#define EXTERN_LAUNCH_FMHA_DECODE(QG, HD, PS) \ - extern template void launch_fmha_decode(bool, const Arguments&); - -#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, PS) \ - extern template void launch_fmha_split_decode(bool, const Arguments&); - -#define EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ - EXTERN_LAUNCH_FMHA_DECODE(QG, HD, 64) \ - EXTERN_LAUNCH_FMHA_DECODE(QG, HD, 128) - -#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(QG, HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, 64) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE(QG, HD, 128) - -#define EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(HD) \ - EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(1, HD) \ - EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ - EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ - EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ - EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) - -#define EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(1, HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(2, HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(4, HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(8, HD) \ - EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES(16, HD) - -EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(64) -EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(96) -EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(128) -EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(192) -EXTERN_LAUNCH_FMHA_DECODE_ALL_QG(256) - -EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(64) -EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(96) -EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(128) -EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(192) -EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG(256) - -#undef EXTERN_LAUNCH_FMHA_DECODE -#undef EXTERN_LAUNCH_FMHA_SPLIT_DECODE -#undef EXTERN_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES -#undef EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_PAGE_SIZES -#undef EXTERN_LAUNCH_FMHA_DECODE_ALL_QG -#undef 
EXTERN_LAUNCH_FMHA_SPLIT_DECODE_ALL_QG +#define EXTERN_FMHA_DECODE_RUNNER(QG, HD, PS) \ + extern template struct FmhaDecodeRunner; + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, PS) \ + extern template struct FmhaSplitDecodeRunner; + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) + +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(128) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(256) + +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(128) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(256) + +#undef EXTERN_FMHA_DECODE_RUNNER +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_QG +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG } // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index aa3bd568..cb80521f 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -790,39 +790,26 @@ struct SplitDecodeConfig { } }; -// Free-function templates for use with the function-pointer dispatch table. -// Each template is explicitly instantiated in a generated .cpp file so the -// compiler only emits code for the combinations that are actually needed. +// Struct functors for use with the function-pointer dispatch table. +// operator() is declared here; each specialization's body is defined in a +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in) so the compiler only emits code +// for the combinations that are actually needed. 
template -void launch_fmha_decode(bool use_sink, const Arguments& params) { - constexpr bool Causal = false; - using TileShapeQK = cute::Shape, cute::Int, cute::_64>; - using TileShapePV = cute::Shape, cute::_32, cute::Int>; - using TileShapeOutput = cute::Shape, cute::Int>; - using SubgroupLayoutQK = cute::Layout, cute::_1>>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - DecodeConfig::run(params); - }); - }); -} +struct FmhaDecodeRunner { + void operator()(bool use_sink, const Arguments& params) const; + static void call(bool use_sink, const Arguments& params) { + FmhaDecodeRunner{}(use_sink, params); + } +}; template -void launch_fmha_split_decode(bool use_sink, const Arguments& params) { - constexpr bool Causal = false; - using TileShapeQK = cute::Shape, cute::Int, cute::_64>; - using TileShapePV = cute::Shape, cute::_32, cute::Int>; - using TileShapeOutput = cute::Shape, cute::Int>; - using SubgroupLayoutQK = cute::Layout, cute::_1>>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - SplitDecodeConfig::run( - params); - }); - }); -} +struct FmhaSplitDecodeRunner { + void operator()(bool use_sink, const Arguments& params) const; + static void call(bool use_sink, const Arguments& params) { + FmhaSplitDecodeRunner{}(use_sink, params); + } +}; } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index 259a152e..eac58cd2 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -37,6 +37,21 @@ namespace decode { -template void launch_fmha_decode<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>(bool, const Arguments&); +template <> +void FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(bool use_sink, const Arguments& params) const { + constexpr bool Causal = false; + using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; + using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + DecodeConfig::run(params); + }); + }); +} + +template struct FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>; } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in index 22e57582..7ed51235 100644 --- a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in @@ -37,6 +37,22 @@ namespace decode { -template void launch_fmha_split_decode<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>(bool, const Arguments&); +template <> +void FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(bool use_sink, const Arguments& params) const { + constexpr bool Causal = false; + using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; + using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; + using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; + + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + SplitDecodeConfig::run( + params); + }); + }); +} + +template struct FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>; } // namespace decode From d2157bb5fb1dbadc2060b3ddacddd2d9c9fb93e9 
Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Mar 2026 03:34:03 +0000 Subject: [PATCH 21/23] Replace function pointer table with direct struct operator() dispatch (GroupGemmXe20 pattern) Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/85877884-e91a-400f-adcc-5b9a55692df9 --- src/sycl/flash_attention.cpp | 189 +++++++----------- .../xe_fmha_fwd_decode_runner.hpp | 8 +- 2 files changed, 78 insertions(+), 119 deletions(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index eef878b7..5cef56c0 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -40,115 +40,77 @@ namespace decode { -namespace { - -using launch_fn_t = void (*)(bool use_sink, const Arguments& params); - -#define LAUNCH_FN_ENTRY(QG, HD, PS) &FmhaDecodeRunner::call -#define LAUNCH_SPLIT_FN_ENTRY(QG, HD, PS) &FmhaSplitDecodeRunner::call - -launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size, bool use_split) { - // Dispatch tables indexed by (qg_sz, head_dim, page_size). - // qg_sz index: {1->0, 2->1, 4->2, 8->3, 16->4} - // head_dim index: {64->0, 96->1, 128->2, 192->3, 256->4} - // page_size index: {64->0, 128->1} - -#define PAGE_ENTRIES(QG, HD) \ - { LAUNCH_FN_ENTRY(QG, HD, 64), LAUNCH_FN_ENTRY(QG, HD, 128) } - -#define HD_ENTRIES(QG) \ - { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192), PAGE_ENTRIES(QG, 256) } - - static const launch_fn_t decode_table[5][5][2] = { - HD_ENTRIES(1), - HD_ENTRIES(2), - HD_ENTRIES(4), - HD_ENTRIES(8), - HD_ENTRIES(16), - }; - -#undef HD_ENTRIES -#undef PAGE_ENTRIES - -#define PAGE_ENTRIES(QG, HD) \ - { LAUNCH_SPLIT_FN_ENTRY(QG, HD, 64), LAUNCH_SPLIT_FN_ENTRY(QG, HD, 128) } - -#define HD_ENTRIES(QG) \ - { PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192), PAGE_ENTRIES(QG, 256) } - - static const launch_fn_t split_decode_table[5][5][2] = { - HD_ENTRIES(1), - HD_ENTRIES(2), - HD_ENTRIES(4), - HD_ENTRIES(8), - HD_ENTRIES(16), - }; - -#undef HD_ENTRIES -#undef PAGE_ENTRIES - - int qg_idx = -1; - switch (qg_sz) { - case 1: - qg_idx = 0; - break; - case 2: - qg_idx = 1; - break; - case 4: - qg_idx = 2; - break; - case 8: - qg_idx = 3; - break; - case 16: - qg_idx = 4; - break; - default: - return nullptr; - } - - int hd_idx = -1; - switch (head_dim) { - case 64: - hd_idx = 0; - break; - case 96: - hd_idx = 1; - break; - case 128: - hd_idx = 2; - break; - case 192: - hd_idx = 3; - break; - case 256: - hd_idx = 4; - break; - default: - return nullptr; - } - - int ps_idx = -1; - switch (page_size) { - case 64: - ps_idx = 0; - break; - case 128: - ps_idx = 1; - break; - default: - return nullptr; - } - - const auto& table = use_split ? split_decode_table : decode_table; - return table[qg_idx][hd_idx][ps_idx]; -} - -#undef LAUNCH_FN_ENTRY -#undef LAUNCH_SPLIT_FN_ENTRY - -} // namespace +// Dispatch macros following the GroupGemmXe20.cpp pattern. +// Directly call struct operator() — no function pointers. 
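The runner structs referenced by these macros follow the pattern set up in the previous patch: operator() is declared in the header, defined only in a generated translation unit, and guarded by extern template so no other TU re-instantiates it. A self-contained sketch of that build-splitting technique, using generic placeholder names rather than the project's actual types:

// runner_decl.hpp - visible to every caller.
template <int QG, int HD, int PS>
struct Runner {
  void operator()(int x) const;               // body lives in a generated .cpp
};
extern template struct Runner<8, 128, 64>;    // do not instantiate here

// runner_8_128_64.cpp - one generated file per combination.
#include <cstdio>
template <>
void Runner<8, 128, 64>::operator()(int x) const {
  std::printf("qg=8 hd=128 ps=64, x=%d\n", x); // stands in for the kernel launch
}
template struct Runner<8, 128, 64>;            // emit this combination here

// caller.cpp
int main() {
  Runner<8, 128, 64>{}(42);                    // direct struct-functor dispatch
  return 0;
}

Because every (QG_SZ, HEAD_DIM, PAGE_SIZE) combination lives in its own generated .cpp, the decode and split-decode instantiations compile in parallel instead of inside one monolithic translation unit, while the dispatch macros above stay in ordinary host code.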
+ +#define DISPATCH_DECODE_KERNEL(QG, HD, PS) \ + do { \ + if (params.use_split_kv_decode) { \ + FmhaSplitDecodeRunner{}(use_sink, params); \ + } else { \ + FmhaDecodeRunner{}(use_sink, params); \ + } \ + } while (0) + +#define DISPATCH_DECODE_PAGE_SIZE(QG, HD) \ + do { \ + switch (params.page_size) { \ + case 64: \ + DISPATCH_DECODE_KERNEL(QG, HD, 64); \ + break; \ + case 128: \ + DISPATCH_DECODE_KERNEL(QG, HD, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported page_size for decode attention: ", params.page_size); \ + } \ + } while (0) + +#define DISPATCH_DECODE_HEAD_DIM(QG) \ + do { \ + switch (params.d) { \ + case 64: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 64); \ + break; \ + case 96: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 96); \ + break; \ + case 128: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 128); \ + break; \ + case 192: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 192); \ + break; \ + case 256: \ + DISPATCH_DECODE_PAGE_SIZE(QG, 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); \ + } \ + } while (0) + +#define DISPATCH_DECODE(qg_sz) \ + do { \ + switch (qg_sz) { \ + case 1: \ + DISPATCH_DECODE_HEAD_DIM(1); \ + break; \ + case 2: \ + DISPATCH_DECODE_HEAD_DIM(2); \ + break; \ + case 4: \ + DISPATCH_DECODE_HEAD_DIM(4); \ + break; \ + case 8: \ + DISPATCH_DECODE_HEAD_DIM(8); \ + break; \ + case 16: \ + DISPATCH_DECODE_HEAD_DIM(16); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported q_group_size for decode attention: ", params.q_group_size); \ + } \ + } while (0) std::vector mha_fwd( const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q @@ -446,13 +408,16 @@ std::vector mha_fwd( "Unsupported page size for decode attention: ", params.page_size); - auto fn = get_launch_fn(qg_sz, params.d, params.page_size, params.use_split_kv_decode); - TORCH_CHECK(fn != nullptr, "No FMHA decode kernel for qg=", qg_sz, " hd=", params.d, " ps=", params.page_size); - fn(use_sink, params); + DISPATCH_DECODE(qg_sz); return {out, softmax_lse, out_accum, softmax_lse_accum}; } +#undef DISPATCH_DECODE_KERNEL +#undef DISPATCH_DECODE_PAGE_SIZE +#undef DISPATCH_DECODE_HEAD_DIM +#undef DISPATCH_DECODE + } // namespace decode std::vector mha_fwd( diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index cb80521f..f644b00f 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -790,7 +790,7 @@ struct SplitDecodeConfig { } }; -// Struct functors for use with the function-pointer dispatch table. +// Struct functors for decode kernel dispatch. 
// operator() is declared here; each specialization's body is defined in a // generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / // xe_fmha_fwd_split_decode_kernel.cpp.in) so the compiler only emits code @@ -799,17 +799,11 @@ struct SplitDecodeConfig { template struct FmhaDecodeRunner { void operator()(bool use_sink, const Arguments& params) const; - static void call(bool use_sink, const Arguments& params) { - FmhaDecodeRunner{}(use_sink, params); - } }; template struct FmhaSplitDecodeRunner { void operator()(bool use_sink, const Arguments& params) const; - static void call(bool use_sink, const Arguments& params) { - FmhaSplitDecodeRunner{}(use_sink, params); - } }; } // namespace decode From 5c294805ec1324523aa6d448e2d98af615c5cb35 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Mar 2026 03:47:25 +0000 Subject: [PATCH 22/23] Add use_sink and use_causal_mask to Arguments; remove bool use_sink from operator() signature Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/f88d0fec-8f93-4f48-99b8-714d68fd14f4 --- src/sycl/flash_attention.cpp | 6 ++++-- .../xe_fmha_fwd_decode_runner.hpp | 7 +++++-- src/sycl/xe_fmha_fwd_decode_kernel.cpp.in | 11 ++++++----- src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in | 13 +++++++------ 4 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 5cef56c0..6340657f 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -46,9 +46,9 @@ namespace decode { #define DISPATCH_DECODE_KERNEL(QG, HD, PS) \ do { \ if (params.use_split_kv_decode) { \ - FmhaSplitDecodeRunner{}(use_sink, params); \ + FmhaSplitDecodeRunner{}(params); \ } else { \ - FmhaDecodeRunner{}(use_sink, params); \ + FmhaDecodeRunner{}(params); \ } \ } while (0) @@ -310,6 +310,7 @@ std::vector mha_fwd( params.softmax_scale = softmax_scale; bool use_sink = sinks_.has_value(); params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; + params.use_sink = use_sink; params.softcap = softcap; @@ -319,6 +320,7 @@ std::vector mha_fwd( // Causal is the special case where window_size_right == 0 and window_size_left < 0. // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
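Storing use_sink and use_causal_mask in Arguments lets the generated kernel bodies convert these runtime flags into template parameters through AT_DISPATCH_BOOL_NO_RETURN. The project's macro is not reproduced in this series; the sketch below shows the same runtime-bool-to-constexpr dispatch idea under a hypothetical DISPATCH_BOOL name:

#include <cstdio>

// Hypothetical stand-in for AT_DISPATCH_BOOL_NO_RETURN: binds a runtime flag
// to a constexpr NAME inside the body, so NAME can be used as a template argument.
#define DISPATCH_BOOL(FLAG, NAME, ...) \
  do {                                 \
    if (FLAG) {                        \
      constexpr bool NAME = true;      \
      __VA_ARGS__                      \
    } else {                           \
      constexpr bool NAME = false;     \
      __VA_ARGS__                      \
    }                                  \
  } while (0)

template <bool Causal, bool Sink>
void run_decode_kernel() {
  std::printf("Causal=%d Sink=%d\n", int(Causal), int(Sink));  // placeholder for the real launch
}

void dispatch(bool use_causal_mask, bool use_sink) {
  DISPATCH_BOOL(use_causal_mask, Causal, {
    DISPATCH_BOOL(use_sink, Sink, { run_decode_kernel<Causal, Sink>(); });
  });
}

Each additional boolean doubles the number of kernel bodies instantiated in every generated .cpp, so new flags widen the per-file compile time rather than the number of generated files.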
params.is_causal = window_size_left < 0 && window_size_right == 0; + params.use_causal_mask = params.is_causal; params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; // TODO: check this diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index f644b00f..f889e26a 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -182,6 +182,9 @@ struct Arguments { bool is_causal; bool is_local; + bool use_sink = false; + bool use_causal_mask = false; + bool is_rotary_interleaved; torch::TensorOptions tensor_opts; @@ -798,12 +801,12 @@ struct SplitDecodeConfig { template struct FmhaDecodeRunner { - void operator()(bool use_sink, const Arguments& params) const; + void operator()(const Arguments& params) const; }; template struct FmhaSplitDecodeRunner { - void operator()(bool use_sink, const Arguments& params) const; + void operator()(const Arguments& params) const; }; } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index eac58cd2..d8414d2b 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -38,16 +38,17 @@ namespace decode { template <> -void FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(bool use_sink, const Arguments& params) const { - constexpr bool Causal = false; +void FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const { using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; using SubgroupLayoutQK = cute::Layout, cute::_1>>; - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - DecodeConfig::run(params); + AT_DISPATCH_BOOL_NO_RETURN(params.use_causal_mask, Causal, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + DecodeConfig::run(params); + }); }); }); } diff --git a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in index 7ed51235..9bfc40f3 100644 --- a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in @@ -38,17 +38,18 @@ namespace decode { template <> -void FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(bool use_sink, const Arguments& params) const { - constexpr bool Causal = false; +void FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const { using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; using SubgroupLayoutQK = cute::Layout, cute::_1>>; - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - SplitDecodeConfig::run( - params); + AT_DISPATCH_BOOL_NO_RETURN(params.use_causal_mask, Causal, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + SplitDecodeConfig::run( + params); + }); }); }); } From a7eaeeef8f537ef938c799f94d4462d3450ca4aa Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" 
<198982749+Copilot@users.noreply.github.com> Date: Tue, 24 Mar 2026 04:54:01 +0000 Subject: [PATCH 23/23] Replace non-ASCII em-dash in flash_attention.cpp comment with ASCII hyphen Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com> Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/8cc49274-a20f-42e3-aad5-39043ba2eefa --- src/sycl/flash_attention.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp index 6340657f..0fc8639c 100644 --- a/src/sycl/flash_attention.cpp +++ b/src/sycl/flash_attention.cpp @@ -41,7 +41,7 @@ namespace decode { // Dispatch macros following the GroupGemmXe20.cpp pattern. -// Directly call struct operator() — no function pointers. +// Directly call struct operator() - no function pointers. #define DISPATCH_DECODE_KERNEL(QG, HD, PS) \ do { \