Split-KV decode, refactor prefill instantiation, and add flash_attn CI benchmarking #145
Conversation
Pull request overview
This PR introduces a split-KV implementation for FlashAttention decode on XPU, aiming to improve decode throughput by partitioning the KV sequence into multiple splits and reducing partial results into the final output.
Changes:
- Adds a split-KV decode kernel path (new split-KV FMHA kernel + `ReduceSplitK` reduction + new decode tile scheduler); the combine step is sketched after this list.
- Updates the `sgl_kernel.fwd` operator schema to remove the `num_splits` argument and adjusts call sites accordingly.
- Adjusts FlashAttention decode tests to vary batch size and tweak cache length sampling.
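For reference, a minimal PyTorch sketch of the combine step a `ReduceSplitK`-style reduction performs, assuming (as described in the file summary below) that each split writes an unnormalized partial output plus its max logit and exp-sum; the tensor names and shapes are illustrative, not the kernel's actual layout:

```python
import torch

def combine_splits(partial_out, max_logits, exp_sums):
    """Merge per-split partial attention outputs into the final output.

    partial_out: [num_splits, num_heads, head_dim] - unnormalized per-split outputs
    max_logits:  [num_splits, num_heads]           - per-split running max of QK^T logits
    exp_sums:    [num_splits, num_heads]           - per-split sum of exp(logit - max_logit)
    """
    # Global max over splits, used to rescale every split to a common base.
    global_max = max_logits.max(dim=0).values                  # [num_heads]
    scale = torch.exp(max_logits - global_max)                 # [num_splits, num_heads]
    # Rescale per-split exp-sums and accumulate the global softmax denominator.
    global_sum = (exp_sums * scale).sum(dim=0)                 # [num_heads]
    # Rescale partial outputs, accumulate, then normalize once at the end.
    out = (partial_out * scale.unsqueeze(-1)).sum(dim=0)       # [num_heads, head_dim]
    return out / global_sum.unsqueeze(-1)
```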
Reviewed changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated 9 comments.
Summary per file:

| File | Description |
|---|---|
| `tests/test_flash_attention.py` | Updates decode test parametrization (batch sizes, head dim coverage) and cache length sampling. |
| `src/torch_extension_sycl.cc` | Updates Torch op schema for `fwd` (removes `num_splits`, changes `q` mutability annotation). |
| `src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp` | Wires the split-KV decode runner path, allocates split buffers, and dispatches the new kernel config. |
| `src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp` | Adds `DecodeTileScheduler` and `XeReduceSplitKTileScheduler`. |
| `src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp` | Adds a new `ReduceSplitK` kernel implementation for combining split outputs. |
| `src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp` | Adds `XeFMHAFwdSplitKVKernel` to compute per-split partial outputs. |
| `src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp` | Adds the `DecodeFwdMainloop` specialization used by split-KV decode. |
| `src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp` | Adds `DecodeFwdEpilogue` with logic to store per-split exp sums / max logits. |
| `src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp` | Renames a params field (`num_splits` → `num_kv_splits`) and adjusts decode-related signatures. |
| `src/sycl/flash_attention.cpp` | Removes the `num_splits` argument forwarding to decode/prefill implementations. |
| `python/sgl_kernel/flash_attn.py` | Stops passing `num_splits` to the op; modifies `max_seqlen_k` handling in the KV-cache wrapper. |
| `include/sgl_flash_kernel_ops.h` | Updates the `mha_fwd` declaration to match the removed `num_splits` parameter and const correctness. |
@copilot please rebase the main branch

@sunjiweiswift I've opened a new pull request, #146, to work on those changes. Once the pull request is ready, I'll request review from you.
force-pushed from 9bdad0b to 25a95a3
force-pushed from 404bd3f to b6257f2
* Split FMHA decode and GroupGemm template instantiations into per-kernel compilation units (#140)
  Split the monolithic template instantiation of xe_fmha_fwd_decode_runner.hpp into 72 separate .cpp files (one per QG_SZ × HEAD_DIM × PAGE_SIZE combination), each compiled as its own library. This enables parallel compilation and significantly speeds up build times.
  Changes:
  - Create xe_fmha_fwd_decode_kernel.cpp.in template for per-combination compilation
  - Create xe_fmha_fwd_decode_dispatch.hpp with function declarations for all 72 kernels
  - Move decode::mha_fwd() from header to flash_attention.cpp with dispatch table
  - Update src/CMakeLists.txt to generate .cpp files via configure_file()
  - Remove mha_fwd() definition from xe_fmha_fwd_decode_runner.hpp header
  Co-authored-by: airMeng <39229107+airMeng@users.noreply.github.com>
  Co-authored-by: jiwei1.sun <jiwei1.sun@intel.com>
* Fix noncontiguous input for rmsnorm (#117)
  - fix norm with noncontiguous input
  - remove commented-out test
  - support in kernel
* Add MXFP4 Per Token Group Quant kernel and tests (#106)
  - Add MXFP4 Per Token Group Quant kernel and tests; remove commented-out fp8 blockwise group gemm registration
  - Add benchmarking for per token group quant mxfp4
  - Add test to run_suite.py
  - Fix group size constraint for mxfp4; add benchmark test to CI flow
  - Remove reference provider from the benchmark script; check quantized and scale values separately; include eps value in the ref quant function call
  - Fix MXFP4 quantization to match the OCP MX spec:
    - Replace the `ceil(log2(max/6.0))` scale computation with `floor(log2(max)) - E2M1_EMAX` per the OCP MX spec
    - Fix roundTiesToEven at midpoints in the SYCL kernel (change <= to < at odd-mantissa boundaries)
    - Replace the naive argmin-based quantize_to_e2m1 reference with the microxcaling _quantize_elemwise_core algorithm
    - Normalize signed zeros (+0.0 vs -0.0) before packed-byte comparison in tests and benchmark
  - Fix lint issues
  - Remove unsupported group sizes
  - Apply formatting check
  - Add TODO for quantize_to_e2m1
  - Update tests/test_per_token_group_quant_mxfp4.py
  - trigger CI
  Co-authored-by: Meng, Hengyu <airdldl@163.com>
* add page 64
* Initial plan
* add reduce.h
* add XeFMHAFwdSplitKVKernel
* const tensor for Q
* add split kernel
* save
* cache_seqlens
* head_dim = 128
* 2026
* test for mingxu
* Initial plan
* Rebase onto main: integrate split-KV changes into flash_attention.cpp and fix SplitDeodeConfig typo
* Rebase onto updated split_kv_decode: fix undefined FMHAConfig, add conditional dispatch between DecodeConfig and SplitDecodeConfig, comment out page_size 32
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/82c107d1-3f61-4ce0-9444-d7f19f27a292
* bugfix
* Refactor dispatch to function pointer tables following the GroupGemmXe20 pattern
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/5c058f18-d408-4752-aba6-2625a6e5e638
* Refactor decode dispatch to struct operator() following the GroupGemmXe20 pattern
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/871c6171-1314-44d4-b4f4-007ddffc6ed4
* Replace the function pointer table with direct struct operator() dispatch (GroupGemmXe20 pattern)
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/85877884-e91a-400f-adcc-5b9a55692df9
* Add use_sink and use_causal_mask to Arguments; remove bool use_sink from the operator() signature
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/f88d0fec-8f93-4f48-99b8-714d68fd14f4
* Replace non-ASCII em-dash in flash_attention.cpp comment with ASCII hyphen
  Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/8cc49274-a20f-42e3-aad5-39043ba2eefa

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: airMeng <39229107+airMeng@users.noreply.github.com>
Co-authored-by: jiwei1.sun <jiwei1.sun@intel.com>
Co-authored-by: Yang Wang <yang3.wang@intel.com>
Co-authored-by: Suryaprakash Shanmugam <suryaprakash.shanmugam@intel.com>
Co-authored-by: Meng, Hengyu <airdldl@163.com>
Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
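As a side note on the MXFP4 commit above, here is a minimal Python sketch of the OCP MX shared-scale rule it references (`floor(log2(max)) - E2M1_EMAX`); the constant values and function name are illustrative, not taken from the kernel:

```python
import math

E2M1_EMAX = 2  # exponent of the largest E2M1 normal, 6.0 = 1.5 * 2**2 (per the OCP MX spec)

def mx_shared_scale(group_abs_max: float) -> float:
    """Shared power-of-two scale for one MX group: 2**(floor(log2(max|x|)) - E2M1_EMAX)."""
    if group_abs_max == 0.0:
        return 1.0  # all-zero group: any scale works, every element quantizes to zero
    shared_exp = math.floor(math.log2(group_abs_max)) - E2M1_EMAX
    return 2.0 ** shared_exp

# Example: a group whose largest magnitude is 48.0 gets scale 2**(5 - 2) = 8.0,
# so 48.0 / 8.0 = 6.0 lands exactly on the E2M1 maximum.
```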
Done in commit
This also fixes the broken CMake.
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/483ec9d0-5189-453e-8b80-7bafa0ec9939 Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
force-pushed from 950da3d to 0831deb
force-pushed from 6d6e733 to 952184d
Do we have benchmark results to show?
[Benchmark result screenshots: Main branch vs. Split Kernel]
mingfeima left a comment
Try to update the commit message with clear results:
- decoding achieves 2-4x speedup with kv splits
- prefill (seqlen_q 128) stays unchanged
We don't apply the split kernel to prefill for now, only to decode, so prefill is the same as main.
force-pushed from 4451c17 to 5da6096
force-pushed from 5da6096 to b23cc66
@copilot modify sgl-kernel-xpu/benchmark/update_baseline_from_log.py and pr-test-xpu.yml#L78 to support flash_attention's benchmark
…rom_log.py Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/7877113f-361e-4c42-b0e6-96e7dfef2ad3 Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Done in commit
Small batch / large `seq_kv` decode suffers from insufficient workgroup parallelism. This PR adds split-KV partitioning to the decode path (2–4× speedup on decode, prefill unchanged) and refactors prefill instantiation to match the decode/split-decode pattern.

Split-KV decode

- The KV sequence is partitioned across `num_kv_splits` workgroups via `grid.z = batch * num_heads_kv * num_kv_splits`
- Adds `FmhaSplitDecodeRunner`, a `ReduceSplitK` reduction kernel, and a `DecodeTileScheduler` with split-aware block coordinate decomposition
- `num_kv_splits` is determined automatically, matching vLLM's `DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE)` approach (see the sketch after this list)
- `num_splits` removed from the `sgl_kernel.fwd` op schema; now computed internally
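A minimal Python sketch of how the split count and launch grid follow from the rules above; `PARTITION_SIZE` here is an assumed illustrative value, not necessarily what the kernel picks:

```python
def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

# Assumed partition size for illustration; the real value is a kernel tuning choice.
PARTITION_SIZE = 512

def decode_launch_shape(batch: int, num_heads_kv: int, max_seq_len: int):
    """Pick num_kv_splits as DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE) and
    spread the splits over grid.z = batch * num_heads_kv * num_kv_splits."""
    num_kv_splits = ceil_div(max_seq_len, PARTITION_SIZE)
    grid_z = batch * num_heads_kv * num_kv_splits
    return num_kv_splits, grid_z

# Example: batch=1, 8 KV heads, 4096 cached tokens -> 8 splits, grid.z = 64,
# instead of a single workgroup per (batch, KV head) pair.
```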
Prefill instantiation refactor

- Prefill is instantiated through an `FmhaPrefillRunner<HEAD_DIM>` struct functor (matching `FmhaDecodeRunner` / `FmhaSplitDecodeRunner`)
- `xe_fmha_fwd_prefill_kernel.cpp.in` generates per-HEAD_DIM translation units
- `FMHAPrefillXe20.cmake` uses per-dim variable sets matching the `FMHADecodeXe20.cmake` structure
- A `DISPATCH_PREFILL_KERNEL` macro in `flash_attention.cpp` replaces the inline function table
CI benchmark support for flash_attn

- `update_baseline_from_log.py` extended with a `parse_flash_attn_log()` parser; handles multiple benchmark types with `flash_attn:`-prefixed baseline keys (a rough sketch follows this list)
- `pr-test-xpu.yml` copies `flash.log` alongside `fused_moe.log`
- `bench_flash_attn.py` print order fixed to emit the table after the `"Benchmark finished!"` marker
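A minimal sketch of the kind of parser this adds, assuming a hypothetical log layout in which each result row after the `"Benchmark finished!"` marker is `<config-name> <latency>`; the real column names and format of `flash.log` may differ:

```python
def parse_flash_attn_log(path: str) -> dict[str, float]:
    """Collect flash_attn benchmark rows into baseline entries keyed with a
    'flash_attn:' prefix so they can live next to other benchmarks' keys."""
    baselines: dict[str, float] = {}
    in_table = False
    with open(path) as f:
        for line in f:
            line = line.strip()
            if "Benchmark finished!" in line:
                in_table = True        # results table starts after this marker
                continue
            if not in_table or not line:
                continue
            parts = line.split()
            if len(parts) < 2:
                continue
            name, value = parts[0], parts[-1]
            try:
                baselines[f"flash_attn:{name}"] = float(value)
            except ValueError:
                continue               # skip header or non-numeric rows
    return baselines
```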
Benchmark results (decode, `kv_seq_length=4096`, `head_dim=128`)

Prefill (`q_seq_length=128`) is unchanged.