Merged
37 commits
a20df4d
add reduce.h
sunjiweiswift Mar 13, 2026
5ac747d
add XeFMHAFwdSplitKVKernel
sunjiweiswift Mar 13, 2026
902a29c
const tensor for Q
sunjiweiswift Mar 13, 2026
23bfd0c
add split kernel
sunjiweiswift Mar 17, 2026
b7dab66
save
sunjiweiswift Mar 17, 2026
09f27ce
cache_seqlens
sunjiweiswift Mar 17, 2026
121b736
head_dim =128
sunjiweiswift Mar 17, 2026
f47daf2
2026
sunjiweiswift Mar 18, 2026
61c3ddc
test for mingxu
sunjiweiswift Mar 23, 2026
6f76a31
add seqlen_k
sunjiweiswift Mar 23, 2026
25a95a3
add page 64
sunjiweiswift Mar 23, 2026
b6257f2
add
sunjiweiswift Mar 23, 2026
48bcf2e
Split kv decode (#146)
Copilot Mar 24, 2026
10d7dfe
lint and bench
sunjiweiswift Mar 24, 2026
cb20d54
lint
sunjiweiswift Mar 24, 2026
368ff64
Merge branch 'main' into split_kv_decode
sunjiweiswift Mar 24, 2026
d002cf0
fix
sunjiweiswift Mar 24, 2026
b1e152e
add HD 256
sunjiweiswift Mar 24, 2026
056dd99
code opt
sunjiweiswift Mar 31, 2026
38e2efa
lint
sunjiweiswift Mar 31, 2026
b0b8f6b
Merge branch 'main' into split_kv_decode
sunjiweiswift Apr 2, 2026
153b637
Merge branch 'main' into split_kv_decode
sunjiweiswift Apr 3, 2026
4971500
add 512
sunjiweiswift Apr 3, 2026
84f43de
Merge branch 'main' into split_kv_decode
sunjiweiswift Apr 3, 2026
bca059a
lint
sunjiweiswift Apr 3, 2026
63165db
add 256 and 512 prefill
sunjiweiswift Apr 3, 2026
1c3030b
Refactor prefill instantiation to match decode/splitdecode pattern
Copilot Apr 3, 2026
0831deb
Refactor FMHAPrefillXe20.cmake to match FMHADecodeXe20.cmake structure
sunjiweiswift Apr 7, 2026
952184d
delete template page 32 and QZ 32
sunjiweiswift Apr 7, 2026
af91f16
Merge branch 'main' into split_kv_decode
sunjiweiswift Apr 7, 2026
c392fc6
fix oom
sunjiweiswift Apr 7, 2026
9ad6f4d
lint
sunjiweiswift Apr 8, 2026
b23cc66
delete some case
sunjiweiswift Apr 8, 2026
9130e35
fix ci benchmark scripts for mla_decode
pralay-das Apr 8, 2026
9a03181
Add flash_attn benchmark support to CI pipeline and update_baseline_f…
Copilot Apr 8, 2026
81498c1
update baseline
sunjiweiswift Apr 8, 2026
b540bc4
add baseline and line
sunjiweiswift Apr 8, 2026
3 changes: 2 additions & 1 deletion .github/workflows/pr-test-xpu.yml
@@ -60,7 +60,7 @@ jobs:
/miniforge3/envs/py3.10/bin/python3 -m pip install tabulate && \
cd /root/sglang/sgl-kernel-xpu/benchmark && \
python3 bench_flash_attn.py 2>&1 | tee flash.log && \
python3 bench_flash_mla_decode.py.py 2>&1 | tee mla.log && \
python3 bench_flash_mla_decode.py 2>&1 | tee mla.log && \
python3 bench_moe_topk_softmax.py 2>&1 | tee moe.log && \
python3 bench_fused_moe.py 2>&1 | tee fused_moe.log && \
python3 bench_moe_sum_reduce.py 2>&1 | tee moe_sum_reduce.log \
@@ -77,6 +77,7 @@ jobs:
timeout-minutes: 20
run: |
docker cp ci_sglang_xpu:/root/sglang/sgl-kernel-xpu/benchmark/fused_moe.log ./fused_moe.log
docker cp ci_sglang_xpu:/root/sglang/sgl-kernel-xpu/benchmark/flash.log ./flash.log
python3 benchmark/update_baseline_from_log.py

- name: Install GitHub CLI
744 changes: 716 additions & 28 deletions benchmark/baseline.json

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions benchmark/bench_flash_attn.py
@@ -59,18 +59,19 @@ def flash_attn_baseline(
causal = [True, False]
local = [True, False]
use_sinks = [True, False]
batch_size = [16, 32]
batch_size = [1, 8, 16]
q_seq_length_range = [1, 128]
head_dim = [64, 128, 256, 512]
num_heads_q = [16]
num_heads_kv = [2, 4, 8]
kv_seq_length_range = [4096, 16384]
num_heads_kv = [4, 8]
kv_seq_length_range = [1024, 4096]
page_size_range = [0, 128]
configs = list(
filter(
lambda cfg: not (cfg[0] and cfg[1])
and (cfg[4] != 1 or (not cfg[0] and not cfg[1] and not cfg[2]))
and (cfg[6] % cfg[7] == 0),
and (cfg[6] % cfg[7] == 0)
and (cfg[8] >= cfg[9]),
product(
causal,
local,
@@ -240,8 +241,9 @@ def benchmark(

if __name__ == "__main__":
benchmark.run(print_data=False)
print("Benchmark finished!")

import pandas as pd

df = pd.DataFrame(all_results)
print(df.to_markdown())
print("Benchmark finished!")
158 changes: 132 additions & 26 deletions benchmark/update_baseline_from_log.py
@@ -1,8 +1,9 @@
import json
import os
import re


def parse_benchmark_log(log_text: str) -> dict:
def parse_fused_moe_log(log_text: str) -> dict:
lines = log_text.splitlines()
start_idx = None
for i, line in enumerate(lines):
@@ -11,7 +12,7 @@ def parse_benchmark_log(log_text: str) -> dict:
break

if start_idx is None:
raise ValueError("Benchmark finished! not found")
raise ValueError("Benchmark finished! not found in fused_moe log")

result = {}

@@ -34,22 +35,72 @@ def parse_benchmark_log(log_text: str) -> dict:
shard_intermediate_size = cols[5]
ms = float(cols[-1])

key = f"fused_moe:{num_tokens}-{num_experts}-{topk}-{hidden_size}-{shard_intermediate_size}"
result[key] = ms

return result


def parse_flash_attn_log(log_text: str) -> dict:
lines = log_text.splitlines()
start_idx = None
for i, line in enumerate(lines):
if "Benchmark finished!" in line:
start_idx = i
break

if start_idx is None:
raise ValueError("Benchmark finished! not found in flash_attn log")

result = {}

for line in lines[start_idx + 1 :]:
line = line.strip()

if not line.startswith("|"):
continue
if re.match(r"\|\s*-+", line):
continue
if "batch" in line:
continue

cols = [c.strip() for c in line.strip("|").split("|")]

batch = cols[1]
q_seq_length = cols[2]
kv_seq_length = cols[3]
num_heads_q = cols[4]
num_heads_kv = cols[5]
head_dim = cols[6]
causal = cols[7]
local = cols[8]
use_sinks = cols[9]
page_size = cols[10]
ms = float(cols[-1])

key = (
f"{num_tokens}-{num_experts}-{topk}-{hidden_size}-{shard_intermediate_size}"
f"flash_attn:{batch}-{q_seq_length}-{kv_seq_length}"
f"-{num_heads_q}-{num_heads_kv}-{head_dim}"
f"-{causal}-{local}-{use_sinks}-{page_size}"
)
result[key] = ms

return result


def format_section(title, data):
def format_section(title, data, benchmark_type="fused_moe"):
if not data:
return f"### {title}\n\nNone\n"

if benchmark_type == "flash_attn":
header = "| config | log | baseline | ratio |"
else:
header = "| num_tokens - num_experts - topk - hidden_size - shard_intermediate_size | log | baseline | ratio |"

lines = [
f"### {title}",
"",
"| num_tokens - num_experts - topk - hidden_size - shard_intermediate_size | log | baseline | ratio |",
header,
"|---|---:|---:|---:|",
]
for k, (l, b) in sorted(data.items()):
@@ -81,18 +132,18 @@ def compare(log_data: dict, baseline: dict):
return lower, higher, equal


def main():
def process_log(log_file, parser, benchmark_type, baseline):
if not os.path.exists(log_file):
print(f"Warning: {log_file} not found, skipping {benchmark_type} benchmark")
return {}, {}, {}

with open("fused_moe.log") as f:
with open(log_file) as f:
log_text = f.read()

data = parse_benchmark_log(log_text)

with open("benchmark/baseline.json") as f:
baseline = json.load(f)

data = parser(log_text)
lower, higher, equal = compare(data, baseline)

print(f"\n=== {benchmark_type} ===")
print("=== LOWER (log < baseline) ===")
for k, (l, b) in lower.items():
ratio = l / b
@@ -114,20 +165,75 @@ def main():
print("Collected benchmark data:")
print(data)

pr_body = "\n".join(
[
"## Benchmark Comparison",
"",
"_Ratio = log / baseline (lower is better)_",
"",
format_section("LOWER (log < baseline)", lower),
format_section("HIGHER (log > baseline)", higher),
format_section("EQUAL", equal),
]
)

if lower:
for k, (l, _) in lower.items():
return lower, higher, equal


def main():
with open("benchmark/baseline.json") as f:
baseline = json.load(f)

benchmarks = [
("fused_moe.log", parse_fused_moe_log, "fused_moe"),
("flash.log", parse_flash_attn_log, "flash_attn"),
]

all_lower = {}
all_higher = {}
all_equal = {}

for log_file, parser, benchmark_type in benchmarks:
lower, higher, equal = process_log(log_file, parser, benchmark_type, baseline)
all_lower.update(lower)
all_higher.update(higher)
all_equal.update(equal)

# Separate results by type for formatting
fused_moe_lower = {
k: v for k, v in all_lower.items() if not k.startswith("flash_attn:")
}
fused_moe_higher = {
k: v for k, v in all_higher.items() if not k.startswith("flash_attn:")
}
fused_moe_equal = {
k: v for k, v in all_equal.items() if not k.startswith("flash_attn:")
}
flash_attn_lower = {
k: v for k, v in all_lower.items() if k.startswith("flash_attn:")
}
flash_attn_higher = {
k: v for k, v in all_higher.items() if k.startswith("flash_attn:")
}
flash_attn_equal = {
k: v for k, v in all_equal.items() if k.startswith("flash_attn:")
}

sections = []
if fused_moe_lower or fused_moe_higher or fused_moe_equal:
sections.append("## Fused MoE Benchmark Comparison\n")
sections.append("_Ratio = log / baseline (lower is better)_\n")
sections.append(
format_section("LOWER (log < baseline)", fused_moe_lower, "fused_moe")
)
sections.append(
format_section("HIGHER (log > baseline)", fused_moe_higher, "fused_moe")
)
sections.append(format_section("EQUAL", fused_moe_equal, "fused_moe"))

if flash_attn_lower or flash_attn_higher or flash_attn_equal:
sections.append("## Flash Attention Benchmark Comparison\n")
sections.append("_Ratio = log / baseline (lower is better)_\n")
sections.append(
format_section("LOWER (log < baseline)", flash_attn_lower, "flash_attn")
)
sections.append(
format_section("HIGHER (log > baseline)", flash_attn_higher, "flash_attn")
)
sections.append(format_section("EQUAL", flash_attn_equal, "flash_attn"))

pr_body = "\n".join(sections) if sections else "## Benchmark Comparison\n\nNo data."

if all_lower:
for k, (l, _) in all_lower.items():
baseline[k] = l
with open("benchmark/baseline.json", "w") as f:
json.dump(baseline, f, indent=4)
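To make the new parsing contract concrete, here is a hypothetical `flash.log` table row (the real layout comes from `df.to_markdown()` in `bench_flash_attn.py`, which puts the DataFrame index in column 0) and the baseline key it yields:

```python
line = "| 0 | 1 | 1 | 1024 | 16 | 4 | 64 | True | False | False | 0 | 0.0123 |"
cols = [c.strip() for c in line.strip("|").split("|")]
# cols[0] is the DataFrame index; the benchmark columns follow.
key = (
    f"flash_attn:{cols[1]}-{cols[2]}-{cols[3]}"
    f"-{cols[4]}-{cols[5]}-{cols[6]}"
    f"-{cols[7]}-{cols[8]}-{cols[9]}-{cols[10]}"
)
assert key == "flash_attn:1-1-1024-16-4-64-True-False-False-0"
assert float(cols[-1]) == 0.0123
```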
6 changes: 3 additions & 3 deletions include/sgl_flash_kernel_ops.h
@@ -1,4 +1,4 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
/* Copyright 2025-2026 SGLang Team. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -43,7 +43,7 @@ limitations under the License.
* From flash-attention
*/
std::vector<at::Tensor> mha_fwd(
at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
// h_k, d) if there is page_table.
const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
@@ -70,7 +70,7 @@ std::vector<at::Tensor> mha_fwd(
float const softcap,
bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
std::optional<at::Tensor>& scheduler_metadata_, // (b + 1)
int num_splits,
int num_kv_splits,
std::optional<bool> pack_gqa_,
int const sm_margin);

2 changes: 1 addition & 1 deletion python/sgl_kernel/flash_attn.py
@@ -262,7 +262,7 @@ def flash_attn_with_kvcache(
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
1,
page_table,
cache_batch_idx,
cache_leftpad,
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -20,6 +20,7 @@ foreach(file ${device_cpp})
endforeach()

include(FMHADecodeXe20.cmake)
include(FMHAPrefillXe20.cmake)
include(MlaDecodeXe20.cmake)

message(STATUS "BMG files: ${device_cpp_xe20}")
14 changes: 10 additions & 4 deletions src/FMHADecodeXe20.cmake
@@ -2,22 +2,28 @@
# Each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination is compiled as a separate
# library to parallelize and speed up compilation.

set(FMHA_DECODE_QG_SIZES 1 2 4 8 16 32)
set(FMHA_DECODE_QG_SIZES 1 2 4 8 16)
set(FMHA_DECODE_HEAD_DIMS 64 96 128 192 256 512)
set(FMHA_DECODE_PAGE_SIZES 32 64 128)
set(FMHA_DECODE_PAGE_SIZES 64 128)

set(FMHA_DECODE_TEMPLATE
"${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in")

set(FMHA_SPLIT_DECODE_TEMPLATE
"${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in")

foreach(QG_SZ ${FMHA_DECODE_QG_SIZES})
foreach(HEAD_DIM ${FMHA_DECODE_HEAD_DIMS})
foreach(PAGE_SIZE ${FMHA_DECODE_PAGE_SIZES})
math(EXPR NUM_SG "${PAGE_SIZE} / 16")

set(GENERATED_FILE
"${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp")
configure_file(${FMHA_DECODE_TEMPLATE} ${GENERATED_FILE} @ONLY)
list(APPEND device_cpp_common ${GENERATED_FILE})

set(GENERATED_SPLIT_FILE
"${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_split_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp")
configure_file(${FMHA_SPLIT_DECODE_TEMPLATE} ${GENERATED_SPLIT_FILE} @ONLY)
list(APPEND device_cpp_common ${GENERATED_SPLIT_FILE})
endforeach()
endforeach()
endforeach()
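With the lists above, each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination now yields a decode and a split-decode translation unit: 5 × 6 × 2 = 60 combinations, 120 generated files. A sketch of the naming scheme the loop produces:

```python
from itertools import product

for qg, hd, pg in product([1, 2, 4, 8, 16], [64, 96, 128, 192, 256, 512], [64, 128]):
    num_sg = pg // 16  # mirrors math(EXPR NUM_SG "${PAGE_SIZE} / 16")
    print(f"xe_fmha_fwd_decode_kernel_{qg}_{hd}_{pg}.cpp")
    print(f"xe_fmha_fwd_split_decode_kernel_{qg}_{hd}_{pg}.cpp")
```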
44 changes: 44 additions & 0 deletions src/FMHAPrefillXe20.cmake
@@ -0,0 +1,44 @@
# Generate FMHA prefill kernel instantiation files.
# Each HEAD_DIM is compiled as a separate translation unit to parallelize
# and speed up compilation.

set(FMHA_PREFILL_HEAD_DIMS 64 96 128 192 256 512)

set(FMHA_PREFILL_TEMPLATE
"${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_prefill_kernel.cpp.in")

# Per-HEAD_DIM tile shape parameters (TILED_Q, TILED_KV, NUM_SG)
set(FMHA_PREFILL_TILED_Q_64 128)
set(FMHA_PREFILL_TILED_KV_64 64)
set(FMHA_PREFILL_NUM_SG_64 8)

set(FMHA_PREFILL_TILED_Q_96 128)
set(FMHA_PREFILL_TILED_KV_96 64)
set(FMHA_PREFILL_NUM_SG_96 8)

set(FMHA_PREFILL_TILED_Q_128 256)
set(FMHA_PREFILL_TILED_KV_128 32)
set(FMHA_PREFILL_NUM_SG_128 16)

set(FMHA_PREFILL_TILED_Q_192 256)
set(FMHA_PREFILL_TILED_KV_192 64)
set(FMHA_PREFILL_NUM_SG_192 32)

set(FMHA_PREFILL_TILED_Q_256 256)
set(FMHA_PREFILL_TILED_KV_256 64)
set(FMHA_PREFILL_NUM_SG_256 32)

set(FMHA_PREFILL_TILED_Q_512 256)
set(FMHA_PREFILL_TILED_KV_512 64)
set(FMHA_PREFILL_NUM_SG_512 32)

foreach(HEAD_DIM ${FMHA_PREFILL_HEAD_DIMS})
set(TILED_Q ${FMHA_PREFILL_TILED_Q_${HEAD_DIM}})
set(TILED_KV ${FMHA_PREFILL_TILED_KV_${HEAD_DIM}})
set(NUM_SG ${FMHA_PREFILL_NUM_SG_${HEAD_DIM}})

set(GENERATED_FILE
"${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_prefill_kernel_${HEAD_DIM}.cpp")
configure_file(${FMHA_PREFILL_TEMPLATE} ${GENERATED_FILE} @ONLY)
list(APPEND device_cpp_common ${GENERATED_FILE})
endforeach()
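For quick reference, the per-HEAD_DIM tile parameters defined above, restated as a Python mapping (values copied from the CMake variables; illustrative only, the build reads the CMake definitions):

```python
# head_dim: (TILED_Q, TILED_KV, NUM_SG), as set by the FMHA_PREFILL_* variables
PREFILL_TILE_SHAPES = {
    64: (128, 64, 8),
    96: (128, 64, 8),
    128: (256, 32, 16),
    192: (256, 64, 32),
    256: (256, 64, 32),
    512: (256, 64, 32),
}

for head_dim, (tiled_q, tiled_kv, num_sg) in sorted(PREFILL_TILE_SHAPES.items()):
    print(f"xe_fmha_fwd_prefill_kernel_{head_dim}.cpp: "
          f"TILED_Q={tiled_q}, TILED_KV={tiled_kv}, NUM_SG={num_sg}")
```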