Skip to content

Commit 950da3d

Browse files
Refactor FMHAPrefillXe20.cmake to match FMHADecodeXe20.cmake structure
Agent-Logs-Url: https://github.com/sgl-project/sgl-kernel-xpu/sessions/483ec9d0-5189-453e-8b80-7bafa0ec9939 Co-authored-by: sunjiweiswift <16934286+sunjiweiswift@users.noreply.github.com>
1 parent 1c3030b commit 950da3d

1 file changed

Lines changed: 31 additions & 24 deletions

File tree

src/FMHAPrefillXe20.cmake

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,41 @@
11
# Generate FMHA prefill kernel instantiation files.
22
# Each HEAD_DIM is compiled as a separate translation unit to parallelize
33
# and speed up compilation.
4-
#
5-
# Tile shape mapping (HEAD_DIM -> TILED_Q, TILED_KV, NUM_SG):
6-
# 64 -> 128, 64, 8
7-
# 96 -> 128, 64, 8
8-
# 128 -> 256, 32, 16
9-
# 192 -> 256, 64, 32
10-
# 256 -> 256, 64, 32
11-
# 512 -> 256, 64, 32
4+
5+
set(FMHA_PREFILL_HEAD_DIMS 64 96 128 192 256 512)
126

137
set(FMHA_PREFILL_TEMPLATE
148
"${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_prefill_kernel.cpp.in")
159

16-
# Define the per-HEAD_DIM tile configurations
17-
# Format: HEAD_DIM;TILED_Q;TILED_KV;NUM_SG
18-
set(FMHA_PREFILL_CONFIGS
19-
"64;128;64;8"
20-
"96;128;64;8"
21-
"128;256;32;16"
22-
"192;256;64;32"
23-
"256;256;64;32"
24-
"512;256;64;32"
25-
)
26-
27-
foreach(CONFIG ${FMHA_PREFILL_CONFIGS})
28-
list(GET CONFIG 0 HEAD_DIM)
29-
list(GET CONFIG 1 TILED_Q)
30-
list(GET CONFIG 2 TILED_KV)
31-
list(GET CONFIG 3 NUM_SG)
10+
# Per-HEAD_DIM tile shape parameters (TILED_Q, TILED_KV, NUM_SG)
11+
set(FMHA_PREFILL_TILED_Q_64 128)
12+
set(FMHA_PREFILL_TILED_KV_64 64)
13+
set(FMHA_PREFILL_NUM_SG_64 8)
14+
15+
set(FMHA_PREFILL_TILED_Q_96 128)
16+
set(FMHA_PREFILL_TILED_KV_96 64)
17+
set(FMHA_PREFILL_NUM_SG_96 8)
18+
19+
set(FMHA_PREFILL_TILED_Q_128 256)
20+
set(FMHA_PREFILL_TILED_KV_128 32)
21+
set(FMHA_PREFILL_NUM_SG_128 16)
22+
23+
set(FMHA_PREFILL_TILED_Q_192 256)
24+
set(FMHA_PREFILL_TILED_KV_192 64)
25+
set(FMHA_PREFILL_NUM_SG_192 32)
26+
27+
set(FMHA_PREFILL_TILED_Q_256 256)
28+
set(FMHA_PREFILL_TILED_KV_256 64)
29+
set(FMHA_PREFILL_NUM_SG_256 32)
30+
31+
set(FMHA_PREFILL_TILED_Q_512 256)
32+
set(FMHA_PREFILL_TILED_KV_512 64)
33+
set(FMHA_PREFILL_NUM_SG_512 32)
34+
35+
foreach(HEAD_DIM ${FMHA_PREFILL_HEAD_DIMS})
36+
set(TILED_Q ${FMHA_PREFILL_TILED_Q_${HEAD_DIM}})
37+
set(TILED_KV ${FMHA_PREFILL_TILED_KV_${HEAD_DIM}})
38+
set(NUM_SG ${FMHA_PREFILL_NUM_SG_${HEAD_DIM}})
3239

3340
set(GENERATED_FILE
3441
"${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_prefill_kernel_${HEAD_DIM}.cpp")

0 commit comments

Comments
 (0)