
Split-KV decode, refactor prefill instantiation, and add flash_attn CI benchmarking#145

Merged
airMeng merged 37 commits into main from split_kv_decode on Apr 8, 2026

Conversation

@sunjiweiswift
Collaborator

@sunjiweiswift sunjiweiswift commented Mar 23, 2026

Small batch / large seq_kv decode suffers from insufficient workgroup parallelism. This PR adds split-KV partitioning to the decode path (2–4× speedup on decode, prefill unchanged) and refactors prefill instantiation to match the decode/split-decode pattern.

Split-KV decode

  • Partition the KV sequence across num_kv_splits workgroups via grid.z = batch * num_heads_kv * num_kv_splits (see the sketch after this list)
  • Added FmhaSplitDecodeRunner, ReduceSplitK reduction kernel, and DecodeTileScheduler with split-aware block coordinate decomposition
  • num_kv_splits determined automatically (matching vLLM's DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE) approach)
  • num_splits removed from sgl_kernel.fwd op schema; now computed internally
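
A minimal host-side sketch of the split-count and workgroup-coordinate math above. The partition size, the ordering of the three indices, and all names here are assumptions made for illustration, not the kernel's actual code:

```cpp
#include <cstdio>

// Assumed per-split KV chunk size; the real value is chosen by the kernel.
constexpr int kPartitionSize = 512;

// vLLM-style DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE).
int num_kv_splits(int max_seq_len) {
  return (max_seq_len + kPartitionSize - 1) / kPartitionSize;
}

struct BlockCoord {
  int batch, head_kv, kv_split;
};

// Split-aware decomposition of the flat grid.z index
// (the ordering of the three indices is an assumption).
BlockCoord decompose(int wg_z, int num_heads_kv, int splits) {
  return {
      /*batch=*/wg_z / (splits * num_heads_kv),
      /*head_kv=*/(wg_z / splits) % num_heads_kv,
      /*kv_split=*/wg_z % splits,
  };
}

int main() {
  const int batch = 2, num_heads_kv = 4, max_seq_len = 4096;
  const int splits = num_kv_splits(max_seq_len);     // 8 splits of 512 tokens
  const int grid_z = batch * num_heads_kv * splits;  // 64 workgroups instead of 8
  const BlockCoord last = decompose(grid_z - 1, num_heads_kv, splits);
  std::printf("splits=%d grid.z=%d last wg -> batch=%d head_kv=%d split=%d\n",
              splits, grid_z, last.batch, last.head_kv, last.kv_split);
  return 0;
}
```

Each workgroup then attends over only its own KV slice and writes a partial output plus its softmax statistics, which the reduction kernel combines into the final result.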

Prefill instantiation refactor

  • Added FmhaPrefillRunner<HEAD_DIM> struct functor (matching FmhaDecodeRunner/FmhaSplitDecodeRunner)
  • CMake template xe_fmha_fwd_prefill_kernel.cpp.in generates per-HEAD_DIM translation units
  • FMHAPrefillXe20.cmake uses per-dim variable sets matching FMHADecodeXe20.cmake structure
  • DISPATCH_PREFILL_KERNEL macro in flash_attention.cpp replaces the inline function table (sketched below)
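
A rough sketch of the struct-functor / dispatch-macro pattern this refactor converges on. The argument struct, the head_dim list, and the macro body are simplified assumptions, not the repository's actual definitions:

```cpp
#include <stdexcept>

struct FwdArguments {
  // q/k/v pointers, sequence lengths, softmax scale, ... (omitted)
};

// One specialization per HEAD_DIM is instantiated in its own generated
// translation unit (xe_fmha_fwd_prefill_kernel.cpp.in -> per-dim .cpp),
// mirroring FmhaDecodeRunner / FmhaSplitDecodeRunner.
template <int HEAD_DIM>
struct FmhaPrefillRunner {
  void operator()(const FwdArguments& args) const {
    (void)args;  // launch the HEAD_DIM-specialized prefill kernel here
  }
};

// flash_attention.cpp-style dispatch: pick the instantiation at runtime.
#define DISPATCH_PREFILL_KERNEL(HD, ARGS) \
  case HD:                                \
    FmhaPrefillRunner<HD>{}(ARGS);        \
    break;

void run_prefill(int head_dim, const FwdArguments& args) {
  switch (head_dim) {
    DISPATCH_PREFILL_KERNEL(64, args)
    DISPATCH_PREFILL_KERNEL(128, args)
    DISPATCH_PREFILL_KERNEL(256, args)
    default:
      throw std::invalid_argument("unsupported head_dim for prefill");
  }
}
```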

CI benchmark support for flash_attn

  • update_baseline_from_log.py extended with parse_flash_attn_log() parser; handles multiple benchmark types with flash_attn:-prefixed baseline keys
  • pr-test-xpu.yml copies flash.log alongside fused_moe.log
  • bench_flash_attn.py print order fixed to emit table after "Benchmark finished!" marker

Benchmark results (decode, kv_seq_length=4096, head_dim=128)

| batch | num_heads_kv | main (ms) | split-KV (ms) | speedup |
| --- | --- | --- | --- | --- |
| 1 | 2 | 0.057 | 0.022 | 2.6× |
| 2 | 2 | 0.064 | 0.031 | 2.0× |
| 4 | 2 | 0.129 | 0.052 | 2.5× |
| 8 | 2 | 0.200 | 0.092 | 2.2× |

Prefill (q_seq_length=128) is unchanged.

Copilot AI review requested due to automatic review settings March 23, 2026 07:01
Contributor

Copilot AI left a comment


Pull request overview

This PR introduces a split-KV implementation for FlashAttention decode on XPU, aiming to improve decode throughput by partitioning the KV sequence into multiple splits and reducing partial results into the final output.

Changes:

  • Adds split-KV decode kernel path (new split-KV FMHA kernel + ReduceSplitK reduction + new decode tile scheduler).
  • Updates the sgl_kernel.fwd operator schema to remove the num_splits argument and adjusts call sites accordingly.
  • Adjusts FlashAttention decode tests to vary batch size and tweak cache length sampling.

Reviewed changes

Copilot reviewed 12 out of 12 changed files in this pull request and generated 9 comments.

| File | Description |
| --- | --- |
| tests/test_flash_attention.py | Updates decode test parametrization (batch sizes, head dim coverage) and cache length sampling. |
| src/torch_extension_sycl.cc | Updates Torch op schema for fwd (removes num_splits, changes q mutability annotation). |
| src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp | Wires the split-KV decode runner path, allocates split buffers, and dispatches the new kernel config. |
| src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp | Adds DecodeTileScheduler and XeReduceSplitKTileScheduler. |
| src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp | Adds a new ReduceSplitK kernel implementation for combining split outputs (combine math sketched below). |
| src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | Adds XeFMHAFwdSplitKVKernel to compute per-split partial outputs. |
| src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | Adds the DecodeFwdMainloop specialization used by split-KV decode. |
| src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp | Adds DecodeFwdEpilogue with logic to store per-split exp sums / max logits. |
| src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp | Renames a params field (num_splits → num_kv_splits) and adjusts decode-related signatures. |
| src/sycl/flash_attention.cpp | Removes the num_splits argument forwarding to decode/prefill implementations. |
| python/sgl_kernel/flash_attn.py | Stops passing num_splits to the op; modifies max_seqlen_k handling in the KV-cache wrapper. |
| include/sgl_flash_kernel_ops.h | Updates the mha_fwd declaration to match the removed num_splits parameter and const correctness. |
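
For context on how the split results come back together: DecodeFwdEpilogue stores each split's max logit and exp sum precisely so that ReduceSplitK can merge the partial outputs exactly. A minimal sketch of that merge, assuming each split's partial output is already normalized by its own exp sum (the real kernel's data layout and convention may differ):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> reduce_split_k(
    const std::vector<std::vector<float>>& partial_out,  // [splits][head_dim]
    const std::vector<float>& max_logit,                  // [splits]
    const std::vector<float>& exp_sum) {                  // [splits]
  const size_t splits = partial_out.size();
  const size_t head_dim = partial_out[0].size();

  // Global max logit keeps the rescaling factors in a safe numeric range.
  float global_max = max_logit[0];
  for (size_t s = 1; s < splits; ++s)
    global_max = std::max(global_max, max_logit[s]);

  // Rescale each split's softmax denominator to the global max.
  std::vector<float> weight(splits);
  float total = 0.f;
  for (size_t s = 0; s < splits; ++s) {
    weight[s] = exp_sum[s] * std::exp(max_logit[s] - global_max);
    total += weight[s];
  }

  // Weighted average of the per-split outputs gives the full-KV result.
  std::vector<float> out(head_dim, 0.f);
  for (size_t s = 0; s < splits; ++s)
    for (size_t d = 0; d < head_dim; ++d)
      out[d] += partial_out[s][d] * (weight[s] / total);
  return out;
}
```

Because the rescaling uses the global max logit, the merged result matches a single-pass softmax over the full KV sequence regardless of how many splits were used.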

Comment thread python/sgl_kernel/flash_attn.py
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
Comment thread src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp Outdated
Comment thread tests/test_flash_attention.py Outdated
Comment thread python/sgl_kernel/flash_attn.py
Comment thread python/sgl_kernel/flash_attn.py
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
@sunjiweiswift
Collaborator Author

@copilot please rebase the main branch

Contributor

Copilot AI commented Mar 23, 2026

@sunjiweiswift I've opened a new pull request, #146, to work on those changes. Once the pull request is ready, I'll request review from you.

Copilot AI mentioned this pull request Mar 23, 2026
@airMeng
Collaborator

airMeng commented Mar 23, 2026

@copilot cherry-pick #143

Contributor

Copilot AI commented Mar 23, 2026

@airMeng I've opened a new pull request, #147, to work on those changes. Once the pull request is ready, I'll request review from you.

Collaborator

@airMeng airMeng left a comment


do not forget

- name: Auto PR for baseline.json update

Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
Comment thread src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp
Comment thread python/sgl_kernel/flash_attn.py Outdated
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
Comment thread src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp Outdated
* Split FMHA decode and GroupGemm template instantiations into per-kernel compilation units (#140)

Split the monolithic template instantiation of xe_fmha_fwd_decode_runner.hpp
into 72 separate .cpp files (one per QG_SZ × HEAD_DIM × PAGE_SIZE combination),
each compiled as its own library. This enables parallel compilation and
significantly speeds up build times.

Changes:
- Create xe_fmha_fwd_decode_kernel.cpp.in template for per-combination compilation
- Create xe_fmha_fwd_decode_dispatch.hpp with function declarations for all 72 kernels
- Move decode::mha_fwd() from header to flash_attention.cpp with dispatch table
- Update src/CMakeLists.txt to generate .cpp files via configure_file()
- Remove mha_fwd() definition from xe_fmha_fwd_decode_runner.hpp header

Co-authored-by: airMeng <39229107+airMeng@users.noreply.github.com>
Co-authored-by: jiwei1.sun <jiwei1.sun@intel.com>

* Fix noncontiguous input for rmsnorm (#117)

* fix norm with noncontiguous input

* remove comment out test

* support in kernel

* Add MXFP4 Per Token Group Quant kernel and tests (#106)

* Add MXFP4 Per Token Group Quant kernel and tests

Remove commented out fp8 blockwise group gemm registration

* Add benchmarking for per token group quant mxfp4

* Add test to run_suite.py

* Fix group size constraint for mxfp4; Add benchmark test to CI flow

* Remove reference provider from the benchmark script

- Add check for quantized and scale values separately
- Include eps value in ref quant function call

* Fix MXFP4 quantization to match OCP MX spec

- Replace ceil(log2(max/6.0)) scale computation with floor(log2(max)) -
  E2M1_EMAX per OCP MX spec
- Fix roundTiesToEven at midpoints in SYCL kernel (change <= to < at
  odd-mantissa boundaries)
- Replace naive argmin-based quantize_to_e2m1 reference with
  microxcaling _quantize_elemwise_core algorithm
- Normalize signed zeros (+0.0 vs -0.0) before packed byte comparison in
  tests and benchmark

* Fix lint issues

* Remove unsupported group sizes

* Apply formatting check

* Add TODO for quantize_to_e2m1

* Update tests/test_per_token_group_quant_mxfp4.py

Co-authored-by: Meng, Hengyu <airdldl@163.com>

* trigger CI

---------

Co-authored-by: Meng, Hengyu <airdldl@163.com>

* add page 64

* Initial plan

* add reduce.h

* add XeFMHAFwdSplitKVKernel

* const tensor for Q

* add split kernel

* save

* cache_seqlens

* head_dim =128

* 2026

* test for mingxu

* Initial plan

* Rebase onto main: integrate split-KV changes into flash_attention.cpp and fix SplitDeodeConfig typo

* Rebase onto updated split_kv_decode: fix FMHAConfig undefined, add conditional dispatch between DecodeConfig and SplitDecodeConfig, comment out page_size 32

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/82c107d1-3f61-4ce0-9444-d7f19f27a292

* bugfix

* Refactor dispatch to function pointer tables following GroupGemmXe20 pattern

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/5c058f18-d408-4752-aba6-2625a6e5e638

* Refactor decode dispatch to struct operator() following GroupGemmXe20 pattern

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/871c6171-1314-44d4-b4f4-007ddffc6ed4

* Replace function pointer table with direct struct operator() dispatch (GroupGemmXe20 pattern)

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/85877884-e91a-400f-adcc-5b9a55692df9

* Add use_sink and use_causal_mask to Arguments; remove bool use_sink from operator() signature

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/f88d0fec-8f93-4f48-99b8-714d68fd14f4

* Replace non-ASCII em-dash in flash_attention.cpp comment with ASCII hyphen

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/8cc49274-a20f-42e3-aad5-39043ba2eefa

---------

Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
Co-authored-by: airMeng <39229107+airMeng@users.noreply.github.com>
Co-authored-by: jiwei1.sun <jiwei1.sun@intel.com>
Co-authored-by: Yang Wang <yang3.wang@intel.com>
Co-authored-by: Suryaprakash Shanmugam <suryaprakash.shanmugam@intel.com>
Co-authored-by: Meng, Hengyu <airdldl@163.com>
Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Contributor

Copilot AI commented Apr 3, 2026

@copilot modify sgl-kernel-xpu/src/FMHAPrefillXe20.cmake ref sgl-kernel-xpu/src/FMHADecodeXe20.cmake

Done in commit 950da3d. FMHAPrefillXe20.cmake now follows the same structure as FMHADecodeXe20.cmake:

  • Declares FMHA_PREFILL_HEAD_DIMS list at the top
  • Sets per-HEAD_DIM tile shape variables (FMHA_PREFILL_TILED_Q_<HD>, FMHA_PREFILL_TILED_KV_<HD>, FMHA_PREFILL_NUM_SG_<HD>)
  • Uses a single foreach(HEAD_DIM ...) loop with configure_file(@ONLY) and list(APPEND device_cpp_common ...)

This also fixes the broken CMake list(GET CONFIG ...) usage in the previous implementation, which incorrectly treated tuple-strings as individual list elements.

@mingfeima
Collaborator

mingfeima commented Apr 8, 2026

do we have benchmark results to show the results?

@sunjiweiswift
Collaborator Author

do we have benchmark results to show the results?

Main branch

batch q_seq_length kv_seq_length num_heads_q num_heads_kv head_dim causal local use_sinks page_size provider tflops bandwidth ms
0 1 1 4096 16 2 128 False False False 128 flash_attn 0.593086 74.2805 0.056576
1 1 1 4096 16 4 128 False False False 128 flash_attn 0.562578 140.782 0.059644
2 1 1 4096 16 8 128 False False False 128 flash_attn 0.57409 287.185 0.058448
3 1 128 4096 16 2 128 False False False 128 flash_attn 26.704 32.5977 0.160836
4 1 128 4096 16 4 128 False False False 128 flash_attn 26.1378 57.4317 0.16432
5 1 128 4096 16 8 128 False False False 128 flash_attn 25.6987 106.66 0.167128
6 2 1 4096 16 2 128 False False False 128 flash_attn 1.05351 131.946 0.0637
7 2 1 4096 16 4 128 False False False 128 flash_attn 0.967433 242.094 0.069368
8 2 1 4096 16 8 128 False False False 128 flash_attn 0.749016 374.691 0.089596
9 2 128 4096 16 2 128 False False False 128 flash_attn 43.5745 53.1916 0.197132
10 2 128 4096 16 4 128 False False False 128 flash_attn 41.0005 90.089 0.209508
11 2 128 4096 16 8 128 False False False 128 flash_attn 39.5904 164.316 0.21697
12 4 1 4096 16 2 128 False False False 128 flash_attn 1.03951 130.193 0.129116
13 4 1 4096 16 4 128 False False False 128 flash_attn 1.15177 288.223 0.116532
14 4 1 4096 16 8 128 False False False 128 flash_attn 0.793211 396.799 0.169208
15 4 128 4096 16 2 128 False False False 128 flash_attn 44.0803 53.809 0.38974
16 4 128 4096 16 4 128 False False False 128 flash_attn 42.4383 93.2482 0.40482
17 4 128 4096 16 8 128 False False False 128 flash_attn 40.431 167.805 0.424918
18 8 1 4096 16 2 128 False False False 128 flash_attn 1.34503 168.457 0.199576
19 8 1 4096 16 4 128 False False False 128 flash_attn 1.28286 321.028 0.209248
20 8 1 4096 16 8 128 False False False 128 flash_attn 0.838704 419.557 0.32006
21 8 128 4096 16 2 128 False False False 128 flash_attn 44.9622 54.8855 0.764192
22 8 128 4096 16 4 128 False False False 128 flash_attn 43.4456 95.4615 0.790868
23 8 128 4096 16 8 128 False False False 128 flash_attn 41.9373 174.056 0.819312
Benchmark finished!

Split Kernel

batch q_seq_length kv_seq_length num_heads_q num_heads_kv head_dim causal local use_sinks page_size provider tflops bandwidth ms
0 1 1 4096 16 2 128 False False False 128 flash_attn 1.51474 189.712 0.022152
1 1 1 4096 16 4 128 False False False 128 flash_attn 1.07011 267.789 0.031356
2 1 1 4096 16 8 128 False False False 128 flash_attn 0.664549 332.437 0.050492
3 1 128 4096 16 2 128 False False False 128 flash_attn 26.7256 32.624 0.160706
4 1 128 4096 16 4 128 False False False 128 flash_attn 26.1792 57.5228 0.16406
5 1 128 4096 16 8 128 False False False 128 flash_attn 25.6987 106.66 0.167128
6 2 1 4096 16 2 128 False False False 128 flash_attn 2.13315 267.164 0.03146
7 2 1 4096 16 4 128 False False False 128 flash_attn 1.30623 326.876 0.051376
8 2 1 4096 16 8 128 False False False 128 flash_attn 0.742125 371.244 0.090428
9 2 128 4096 16 2 128 False False False 128 flash_attn 43.609 53.2337 0.196976
10 2 128 4096 16 4 128 False False False 128 flash_attn 41.0413 90.1785 0.2093
11 2 128 4096 16 8 128 False False False 128 flash_attn 39.5762 164.257 0.217048
12 4 1 4096 16 2 128 False False False 128 flash_attn 2.56063 320.703 0.052416
13 4 1 4096 16 4 128 False False False 128 flash_attn 1.47114 368.143 0.091234
14 4 1 4096 16 8 128 False False False 128 flash_attn 0.818101 409.25 0.16406
15 4 128 4096 16 2 128 False False False 128 flash_attn 44.0833 53.8126 0.389714
16 4 128 4096 16 4 128 False False False 128 flash_attn 42.5038 93.3922 0.404196
17 4 128 4096 16 8 128 False False False 128 flash_attn 40.5004 168.093 0.42419
18 8 1 4096 16 2 128 False False False 128 flash_attn 2.90502 363.837 0.092404
19 8 1 4096 16 4 128 False False False 128 flash_attn 1.63594 409.385 0.164086
20 8 1 4096 16 8 128 False False False 128 flash_attn 0.860944 430.682 0.311792
21 8 128 4096 16 2 128 False False False 128 flash_attn 44.8736 54.7774 0.7657
22 8 128 4096 16 4 128 False False False 128 flash_attn 43.4356 95.4396 0.79105
23 8 128 4096 16 8 128 False False False 128 flash_attn 41.8656 173.758 0.820716
Benchmark finished!

Collaborator

@mingfeima mingfeima left a comment


Try to update commit message with clear results:

  • decoding achieves 2-4x speedup with kv splits
  • prefill (seqlen_q 128) stays unchanged

@sunjiweiswift
Collaborator Author

Try to update commit message with clear results:

  • decoding achieves 2-4x speedup with kv splits
  • prefill (seqlen_q 128) stays unchanged

The split kernel is only applied to decode for now, not prefill, so prefill performance is identical to main.
I will extend it to prefill in a follow-up PR.

@sunjiweiswift
Collaborator Author

@copilot modify sgl-kernel-xpu/benchmark/update_baseline_from_log.py and pr-test-xpu.yml#L78 to support flash_attention's benchmark

…rom_log.py

Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/7877113f-361e-4c42-b0e6-96e7dfef2ad3

Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
Contributor

Copilot AI commented Apr 8, 2026

@copilot modify sgl-kernel-xpu/benchmark/update_baseline_from_log.py and pr-test-xpu.yml#L78 to support flash_attention's benchmark

Done in commit 9a03181. Changes:

  • benchmark/update_baseline_from_log.py: Added parse_flash_attn_log() parser, process_log() helper, and separate formatting sections for fused_moe and flash_attn benchmarks. Missing log files are skipped gracefully with a warning.
  • .github/workflows/pr-test-xpu.yml: Added docker cp for flash.log alongside fused_moe.log.
  • benchmark/bench_flash_attn.py: Moved print("Benchmark finished!") before the table output to match the fused_moe pattern (table printed after marker).

Copilot AI changed the title from "Split kv decode and refactor prefill instantiation" to "Split-KV decode, refactor prefill instantiation, and add flash_attn CI benchmarking" Apr 8, 2026
@sunjiweiswift sunjiweiswift added the perf (Performance-related; enables baseline recording) label Apr 8, 2026
@airMeng airMeng merged commit f7975d2 into main Apr 8, 2026
5 of 6 checks passed
@sunjiweiswift sunjiweiswift deleted the split_kv_decode branch April 8, 2026 05:39