|
1 | 1 | # Generate FMHA prefill kernel instantiation files. |
2 | 2 | # Each HEAD_DIM is compiled as a separate translation unit to parallelize |
3 | 3 | # and speed up compilation. |
4 | | -# |
5 | | -# Tile shape mapping (HEAD_DIM -> TILED_Q, TILED_KV, NUM_SG): |
6 | | -# 64 -> 128, 64, 8 |
7 | | -# 96 -> 128, 64, 8 |
8 | | -# 128 -> 256, 32, 16 |
9 | | -# 192 -> 256, 64, 32 |
10 | | -# 256 -> 256, 64, 32 |
11 | | -# 512 -> 256, 64, 32 |
| 4 | + |
| 5 | +set(FMHA_PREFILL_HEAD_DIMS 64 96 128 192 256 512) |
12 | 6 |
|
13 | 7 | set(FMHA_PREFILL_TEMPLATE |
14 | 8 | "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_prefill_kernel.cpp.in") |
15 | 9 |
|
16 | | -# Define the per-HEAD_DIM tile configurations |
17 | | -# Format: HEAD_DIM;TILED_Q;TILED_KV;NUM_SG |
18 | | -set(FMHA_PREFILL_CONFIGS |
19 | | - "64;128;64;8" |
20 | | - "96;128;64;8" |
21 | | - "128;256;32;16" |
22 | | - "192;256;64;32" |
23 | | - "256;256;64;32" |
24 | | - "512;256;64;32" |
25 | | -) |
26 | | - |
27 | | -foreach(CONFIG ${FMHA_PREFILL_CONFIGS}) |
28 | | - list(GET CONFIG 0 HEAD_DIM) |
29 | | - list(GET CONFIG 1 TILED_Q) |
30 | | - list(GET CONFIG 2 TILED_KV) |
31 | | - list(GET CONFIG 3 NUM_SG) |
| 10 | +# Per-HEAD_DIM tile shape parameters (TILED_Q, TILED_KV, NUM_SG) |
| 11 | +set(FMHA_PREFILL_TILED_Q_64 128) |
| 12 | +set(FMHA_PREFILL_TILED_KV_64 64) |
| 13 | +set(FMHA_PREFILL_NUM_SG_64 8) |
| 14 | + |
| 15 | +set(FMHA_PREFILL_TILED_Q_96 128) |
| 16 | +set(FMHA_PREFILL_TILED_KV_96 64) |
| 17 | +set(FMHA_PREFILL_NUM_SG_96 8) |
| 18 | + |
| 19 | +set(FMHA_PREFILL_TILED_Q_128 256) |
| 20 | +set(FMHA_PREFILL_TILED_KV_128 32) |
| 21 | +set(FMHA_PREFILL_NUM_SG_128 16) |
| 22 | + |
| 23 | +set(FMHA_PREFILL_TILED_Q_192 256) |
| 24 | +set(FMHA_PREFILL_TILED_KV_192 64) |
| 25 | +set(FMHA_PREFILL_NUM_SG_192 32) |
| 26 | + |
| 27 | +set(FMHA_PREFILL_TILED_Q_256 256) |
| 28 | +set(FMHA_PREFILL_TILED_KV_256 64) |
| 29 | +set(FMHA_PREFILL_NUM_SG_256 32) |
| 30 | + |
| 31 | +set(FMHA_PREFILL_TILED_Q_512 256) |
| 32 | +set(FMHA_PREFILL_TILED_KV_512 64) |
| 33 | +set(FMHA_PREFILL_NUM_SG_512 32) |
| 34 | + |
| 35 | +foreach(HEAD_DIM ${FMHA_PREFILL_HEAD_DIMS}) |
| 36 | + set(TILED_Q ${FMHA_PREFILL_TILED_Q_${HEAD_DIM}}) |
| 37 | + set(TILED_KV ${FMHA_PREFILL_TILED_KV_${HEAD_DIM}}) |
| 38 | + set(NUM_SG ${FMHA_PREFILL_NUM_SG_${HEAD_DIM}}) |
32 | 39 |
|
33 | 40 | set(GENERATED_FILE |
34 | 41 | "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_prefill_kernel_${HEAD_DIM}.cpp") |
|
0 commit comments