diff --git a/.github/workflows/pr-test-xpu.yml b/.github/workflows/pr-test-xpu.yml
index db3e6fdc..861f856c 100644
--- a/.github/workflows/pr-test-xpu.yml
+++ b/.github/workflows/pr-test-xpu.yml
@@ -60,7 +60,7 @@ jobs:
           /miniforge3/envs/py3.10/bin/python3 -m pip install tabulate && \
           cd /root/sglang/sgl-kernel-xpu/benchmark && \
           python3 bench_flash_attn.py 2>&1 | tee flash.log && \
-          python3 bench_flash_mla_decode.py.py 2>&1 | tee mla.log && \
+          python3 bench_flash_mla_decode.py 2>&1 | tee mla.log && \
           python3 bench_moe_topk_softmax.py 2>&1 | tee moe.log && \
           python3 bench_fused_moe.py 2>&1 | tee fused_moe.log && \
           python3 bench_moe_sum_reduce.py 2>&1 | tee moe_sum_reduce.log \
@@ -77,6 +77,7 @@ jobs:
       timeout-minutes: 20
       run: |
         docker cp ci_sglang_xpu:/root/sglang/sgl-kernel-xpu/benchmark/fused_moe.log ./fused_moe.log
+        docker cp ci_sglang_xpu:/root/sglang/sgl-kernel-xpu/benchmark/flash.log ./flash.log
         python3 benchmark/update_baseline_from_log.py

     - name: Install GitHub CLI
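Annotation (not part of the diff): the regenerated `baseline.json` below switches from bare config keys to keys namespaced by benchmark (`fused_moe:`, `flash_attn:`), so one file can hold baselines for both suites. A minimal sketch of the key scheme; the helper names are hypothetical, only the key format comes from the diff:

```python
# Hypothetical helpers illustrating the namespaced key format in baseline.json.
def make_key(benchmark: str, *fields) -> str:
    # "fused_moe" + (1, 64, 8, 3584, 1280) -> "fused_moe:1-64-8-3584-1280"
    return f"{benchmark}:" + "-".join(str(f) for f in fields)

def split_key(key: str) -> tuple[str, list[str]]:
    benchmark, _, config = key.partition(":")
    return benchmark, config.split("-")

assert make_key("fused_moe", 1, 64, 8, 3584, 1280) == "fused_moe:1-64-8-3584-1280"
assert split_key("flash_attn:1-128-1024-16-4-64-True-False-True-0")[0] == "flash_attn"
```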
"fused_moe:512-32-4-2880-2880": 2.24455, + "fused_moe:1024-64-6-1280-1792": 1.44979, + "fused_moe:1024-32-4-2880-2880": 2.67636, + "fused_moe:2048-64-6-1280-1792": 1.8596, + "fused_moe:2048-32-4-2880-2880": 3.96713, + "fused_moe:4096-64-6-1280-1792": 3.25395, + "fused_moe:4096-32-4-2880-2880": 6.87757, + "fused_moe:8192-64-6-1280-1792": 5.86966, + "fused_moe:8192-32-4-2880-2880": 12.4467, + "flash_attn:1-128-1024-16-4-64-True-False-True-0": 0.058968, + "flash_attn:1-128-1024-16-4-64-True-False-True-128": 0.062868, + "flash_attn:1-128-4096-16-4-64-True-False-True-0": 0.211848, + "flash_attn:1-128-4096-16-4-64-True-False-True-128": 0.228072, + "flash_attn:1-128-1024-16-8-64-True-False-True-0": 0.059072, + "flash_attn:1-128-1024-16-8-64-True-False-True-128": 0.06266, + "flash_attn:1-128-4096-16-8-64-True-False-True-0": 0.211848, + "flash_attn:1-128-4096-16-8-64-True-False-True-128": 0.226668, + "flash_attn:1-128-1024-16-4-128-True-False-True-0": 0.049348, + "flash_attn:1-128-1024-16-4-128-True-False-True-128": 0.053196, + "flash_attn:1-128-4096-16-4-128-True-False-True-0": 0.171028, + "flash_attn:1-128-4096-16-4-128-True-False-True-128": 0.185432, + "flash_attn:1-128-1024-16-8-128-True-False-True-0": 0.050076, + "flash_attn:1-128-1024-16-8-128-True-False-True-128": 0.0533, + "flash_attn:1-128-4096-16-8-128-True-False-True-0": 0.174252, + "flash_attn:1-128-4096-16-8-128-True-False-True-128": 0.187252, + "flash_attn:1-128-1024-16-4-256-True-False-True-0": 0.110292, + "flash_attn:1-128-1024-16-4-256-True-False-True-128": 0.110214, + "flash_attn:1-128-4096-16-4-256-True-False-True-0": 0.409084, + "flash_attn:1-128-4096-16-4-256-True-False-True-128": 0.407628, + "flash_attn:1-128-1024-16-8-256-True-False-True-0": 0.112424, + "flash_attn:1-128-1024-16-8-256-True-False-True-128": 0.109044, + "flash_attn:1-128-4096-16-8-256-True-False-True-0": 0.413452, + "flash_attn:1-128-4096-16-8-256-True-False-True-128": 0.402246, + "flash_attn:1-128-1024-16-4-512-True-False-True-0": 2.22115, + "flash_attn:1-128-1024-16-4-512-True-False-True-128": 2.22583, + "flash_attn:1-128-4096-16-4-512-True-False-True-0": 9.1188, + "flash_attn:1-128-4096-16-4-512-True-False-True-128": 9.1014, + "flash_attn:1-128-1024-16-8-512-True-False-True-0": 2.23813, + "flash_attn:1-128-1024-16-8-512-True-False-True-128": 2.25009, + "flash_attn:1-128-4096-16-8-512-True-False-True-0": 9.13793, + "flash_attn:1-128-4096-16-8-512-True-False-True-128": 9.23302, + "flash_attn:8-128-1024-16-4-64-True-False-True-0": 0.130624, + "flash_attn:8-128-1024-16-4-64-True-False-True-128": 0.138216, + "flash_attn:8-128-4096-16-4-64-True-False-True-0": 0.48308, + "flash_attn:8-128-4096-16-4-64-True-False-True-128": 0.512928, + "flash_attn:8-128-1024-16-8-64-True-False-True-0": 0.132756, + "flash_attn:8-128-1024-16-8-64-True-False-True-128": 0.140036, + "flash_attn:8-128-4096-16-8-64-True-False-True-0": 0.486434, + "flash_attn:8-128-4096-16-8-64-True-False-True-128": 0.517712, + "flash_attn:8-128-1024-16-4-128-True-False-True-0": 0.21645, + "flash_attn:8-128-1024-16-4-128-True-False-True-128": 0.232232, + "flash_attn:8-128-4096-16-4-128-True-False-True-0": 0.830336, + "flash_attn:8-128-4096-16-4-128-True-False-True-128": 0.876928, + "flash_attn:8-128-1024-16-8-128-True-False-True-0": 0.22698, + "flash_attn:8-128-1024-16-8-128-True-False-True-128": 0.240058, + "flash_attn:8-128-4096-16-8-128-True-False-True-0": 0.849264, + "flash_attn:8-128-4096-16-8-128-True-False-True-128": 0.904878, + "flash_attn:8-128-1024-16-4-256-True-False-True-0": 0.732212, + 
"flash_attn:8-128-1024-16-4-256-True-False-True-128": 0.730548, + "flash_attn:8-128-4096-16-4-256-True-False-True-0": 2.80615, + "flash_attn:8-128-4096-16-4-256-True-False-True-128": 2.80574, + "flash_attn:8-128-1024-16-8-256-True-False-True-0": 0.73606, + "flash_attn:8-128-1024-16-8-256-True-False-True-128": 0.730184, + "flash_attn:8-128-4096-16-8-256-True-False-True-0": 2.82831, + "flash_attn:8-128-4096-16-8-256-True-False-True-128": 2.81094, + "flash_attn:8-128-1024-16-4-512-True-False-True-0": 17.0003, + "flash_attn:8-128-1024-16-4-512-True-False-True-128": 17.0292, + "flash_attn:8-128-4096-16-4-512-True-False-True-0": 70.581, + "flash_attn:8-128-4096-16-4-512-True-False-True-128": 71.2612, + "flash_attn:8-128-1024-16-8-512-True-False-True-0": 17.0133, + "flash_attn:8-128-1024-16-8-512-True-False-True-128": 17.2585, + "flash_attn:8-128-4096-16-8-512-True-False-True-0": 71.3855, + "flash_attn:8-128-4096-16-8-512-True-False-True-128": 71.4294, + "flash_attn:16-128-1024-16-4-64-True-False-True-0": 0.2509, + "flash_attn:16-128-1024-16-4-64-True-False-True-128": 0.264888, + "flash_attn:16-128-4096-16-4-64-True-False-True-0": 0.91767, + "flash_attn:16-128-4096-16-4-64-True-False-True-128": 0.98488, + "flash_attn:16-128-1024-16-8-64-True-False-True-0": 0.256932, + "flash_attn:16-128-1024-16-8-64-True-False-True-128": 0.272948, + "flash_attn:16-128-4096-16-8-64-True-False-True-0": 0.93587, + "flash_attn:16-128-4096-16-8-64-True-False-True-128": 1.00812, + "flash_attn:16-128-1024-16-4-128-True-False-True-0": 0.378118, + "flash_attn:16-128-1024-16-4-128-True-False-True-128": 0.403364, + "flash_attn:16-128-4096-16-4-128-True-False-True-0": 1.46411, + "flash_attn:16-128-4096-16-4-128-True-False-True-128": 1.55717, + "flash_attn:16-128-1024-16-8-128-True-False-True-0": 0.39793, + "flash_attn:16-128-1024-16-8-128-True-False-True-128": 0.417924, + "flash_attn:16-128-4096-16-8-128-True-False-True-0": 1.51528, + "flash_attn:16-128-4096-16-8-128-True-False-True-128": 1.61889, + "flash_attn:16-128-1024-16-4-256-True-False-True-0": 1.35439, + "flash_attn:16-128-1024-16-4-256-True-False-True-128": 1.35535, + "flash_attn:16-128-4096-16-4-256-True-False-True-0": 5.27277, + "flash_attn:16-128-4096-16-4-256-True-False-True-128": 5.24641, + "flash_attn:16-128-1024-16-8-256-True-False-True-0": 1.36726, + "flash_attn:16-128-1024-16-8-256-True-False-True-128": 1.34862, + "flash_attn:16-128-4096-16-8-256-True-False-True-0": 5.35813, + "flash_attn:16-128-4096-16-8-256-True-False-True-128": 5.26518, + "flash_attn:16-128-1024-16-4-512-True-False-True-0": 33.1104, + "flash_attn:16-128-1024-16-4-512-True-False-True-128": 33.3107, + "flash_attn:16-128-4096-16-4-512-True-False-True-0": 138.487, + "flash_attn:16-128-4096-16-4-512-True-False-True-128": 138.583, + "flash_attn:16-128-1024-16-8-512-True-False-True-0": 33.3434, + "flash_attn:16-128-1024-16-8-512-True-False-True-128": 33.5314, + "flash_attn:16-128-4096-16-8-512-True-False-True-0": 141.029, + "flash_attn:16-128-4096-16-8-512-True-False-True-128": 139.389, + "flash_attn:1-128-1024-16-4-64-True-False-False-0": 0.058448, + "flash_attn:1-128-1024-16-4-64-True-False-False-128": 0.061724, + "flash_attn:1-128-4096-16-4-64-True-False-False-0": 0.210392, + "flash_attn:1-128-4096-16-4-64-True-False-False-128": 0.224692, + "flash_attn:1-128-1024-16-8-64-True-False-False-0": 0.058396, + "flash_attn:1-128-1024-16-8-64-True-False-False-128": 0.061932, + "flash_attn:1-128-4096-16-8-64-True-False-False-0": 0.209144, + "flash_attn:1-128-4096-16-8-64-True-False-False-128": 
0.224042, + "flash_attn:1-128-1024-16-4-128-True-False-False-0": 0.04836, + "flash_attn:1-128-1024-16-4-128-True-False-False-128": 0.052546, + "flash_attn:1-128-4096-16-4-128-True-False-False-0": 0.169208, + "flash_attn:1-128-4096-16-4-128-True-False-False-128": 0.184236, + "flash_attn:1-128-1024-16-8-128-True-False-False-0": 0.04914, + "flash_attn:1-128-1024-16-8-128-True-False-False-128": 0.053092, + "flash_attn:1-128-4096-16-8-128-True-False-False-0": 0.172094, + "flash_attn:1-128-4096-16-8-128-True-False-False-128": 0.186524, + "flash_attn:1-128-1024-16-4-256-True-False-False-0": 0.106964, + "flash_attn:1-128-1024-16-4-256-True-False-False-128": 0.109876, + "flash_attn:1-128-4096-16-4-256-True-False-False-0": 0.396812, + "flash_attn:1-128-4096-16-4-256-True-False-False-128": 0.407004, + "flash_attn:1-128-1024-16-8-256-True-False-False-0": 0.108004, + "flash_attn:1-128-1024-16-8-256-True-False-False-128": 0.108576, + "flash_attn:1-128-4096-16-8-256-True-False-False-0": 0.395928, + "flash_attn:1-128-4096-16-8-256-True-False-False-128": 0.40209, + "flash_attn:1-128-1024-16-4-512-True-False-False-0": 2.2327, + "flash_attn:1-128-1024-16-4-512-True-False-False-128": 2.23917, + "flash_attn:1-128-4096-16-4-512-True-False-False-0": 9.09579, + "flash_attn:1-128-4096-16-4-512-True-False-False-128": 9.14326, + "flash_attn:1-128-1024-16-8-512-True-False-False-0": 2.24224, + "flash_attn:1-128-1024-16-8-512-True-False-False-128": 2.25295, + "flash_attn:1-128-4096-16-8-512-True-False-False-0": 9.12072, + "flash_attn:1-128-4096-16-8-512-True-False-False-128": 9.23268, + "flash_attn:8-128-1024-16-4-64-True-False-False-0": 0.129506, + "flash_attn:8-128-1024-16-4-64-True-False-False-128": 0.135928, + "flash_attn:8-128-4096-16-4-64-True-False-False-0": 0.478816, + "flash_attn:8-128-4096-16-4-64-True-False-False-128": 0.508014, + "flash_attn:8-128-1024-16-8-64-True-False-False-0": 0.131508, + "flash_attn:8-128-1024-16-8-64-True-False-False-128": 0.137748, + "flash_attn:8-128-4096-16-8-64-True-False-False-0": 0.483132, + "flash_attn:8-128-4096-16-8-64-True-False-False-128": 0.513734, + "flash_attn:8-128-1024-16-4-128-True-False-False-0": 0.216268, + "flash_attn:8-128-1024-16-4-128-True-False-False-128": 0.232388, + "flash_attn:8-128-4096-16-4-128-True-False-False-0": 0.807248, + "flash_attn:8-128-4096-16-4-128-True-False-False-128": 0.872092, + "flash_attn:8-128-1024-16-8-128-True-False-False-0": 0.22698, + "flash_attn:8-128-1024-16-8-128-True-False-False-128": 0.241228, + "flash_attn:8-128-4096-16-8-128-True-False-False-0": 0.841984, + "flash_attn:8-128-4096-16-8-128-True-False-False-128": 0.903292, + "flash_attn:8-128-1024-16-4-256-True-False-False-0": 0.713908, + "flash_attn:8-128-1024-16-4-256-True-False-False-128": 0.730002, + "flash_attn:8-128-4096-16-4-256-True-False-False-0": 2.71913, + "flash_attn:8-128-4096-16-4-256-True-False-False-128": 2.80067, + "flash_attn:8-128-1024-16-8-256-True-False-False-0": 0.710866, + "flash_attn:8-128-1024-16-8-256-True-False-False-128": 0.727376, + "flash_attn:8-128-4096-16-8-256-True-False-False-0": 2.72964, + "flash_attn:8-128-4096-16-8-256-True-False-False-128": 2.7963, + "flash_attn:8-128-1024-16-4-512-True-False-False-0": 16.9799, + "flash_attn:8-128-1024-16-4-512-True-False-False-128": 17.2656, + "flash_attn:8-128-4096-16-4-512-True-False-False-0": 71.4872, + "flash_attn:8-128-4096-16-4-512-True-False-False-128": 71.8024, + "flash_attn:8-128-1024-16-8-512-True-False-False-0": 17.175, + "flash_attn:8-128-1024-16-8-512-True-False-False-128": 17.3152, + 
"flash_attn:8-128-4096-16-8-512-True-False-False-0": 71.2, + "flash_attn:8-128-4096-16-8-512-True-False-False-128": 71.8972, + "flash_attn:16-128-1024-16-4-64-True-False-False-0": 0.24518, + "flash_attn:16-128-1024-16-4-64-True-False-False-128": 0.2626, + "flash_attn:16-128-4096-16-4-64-True-False-False-0": 0.908336, + "flash_attn:16-128-4096-16-4-64-True-False-False-128": 0.976508, + "flash_attn:16-128-1024-16-8-64-True-False-False-0": 0.251316, + "flash_attn:16-128-1024-16-8-64-True-False-False-128": 0.269282, + "flash_attn:16-128-4096-16-8-64-True-False-False-0": 0.931268, + "flash_attn:16-128-4096-16-8-64-True-False-False-128": 0.99775, + "flash_attn:16-128-1024-16-4-128-True-False-False-0": 0.378664, + "flash_attn:16-128-1024-16-4-128-True-False-False-128": 0.403416, + "flash_attn:16-128-4096-16-4-128-True-False-False-0": 1.4572, + "flash_attn:16-128-4096-16-4-128-True-False-False-128": 1.55938, + "flash_attn:16-128-1024-16-8-128-True-False-False-0": 0.398632, + "flash_attn:16-128-1024-16-8-128-True-False-False-128": 0.419016, + "flash_attn:16-128-4096-16-8-128-True-False-False-0": 1.5159, + "flash_attn:16-128-4096-16-8-128-True-False-False-128": 1.6185, + "flash_attn:16-128-1024-16-4-256-True-False-False-0": 1.32101, + "flash_attn:16-128-1024-16-4-256-True-False-False-128": 1.35143, + "flash_attn:16-128-4096-16-4-256-True-False-False-0": 5.11085, + "flash_attn:16-128-4096-16-4-256-True-False-False-128": 5.2469, + "flash_attn:16-128-1024-16-8-256-True-False-False-0": 1.31851, + "flash_attn:16-128-1024-16-8-256-True-False-False-128": 1.34217, + "flash_attn:16-128-4096-16-8-256-True-False-False-0": 5.14748, + "flash_attn:16-128-4096-16-8-256-True-False-False-128": 5.27839, + "flash_attn:16-128-1024-16-4-512-True-False-False-0": 33.1918, + "flash_attn:16-128-1024-16-4-512-True-False-False-128": 33.3258, + "flash_attn:16-128-4096-16-4-512-True-False-False-0": 139.11, + "flash_attn:16-128-4096-16-4-512-True-False-False-128": 139.44, + "flash_attn:16-128-1024-16-8-512-True-False-False-0": 33.4072, + "flash_attn:16-128-1024-16-8-512-True-False-False-128": 33.9094, + "flash_attn:16-128-4096-16-8-512-True-False-False-0": 140.38, + "flash_attn:16-128-4096-16-8-512-True-False-False-128": 139.696, + "flash_attn:1-128-1024-16-4-64-False-True-True-0": 0.073164, + "flash_attn:1-128-1024-16-4-64-False-True-True-128": 0.076804, + "flash_attn:1-128-4096-16-4-64-False-True-True-0": 0.268164, + "flash_attn:1-128-4096-16-4-64-False-True-True-128": 0.286312, + "flash_attn:1-128-1024-16-8-64-False-True-True-0": 0.073216, + "flash_attn:1-128-1024-16-8-64-False-True-True-128": 0.077584, + "flash_attn:1-128-4096-16-8-64-False-True-True-0": 0.266864, + "flash_attn:1-128-4096-16-8-64-False-True-True-128": 0.286156, + "flash_attn:1-128-1024-16-4-128-False-True-True-0": 0.055016, + "flash_attn:1-128-1024-16-4-128-False-True-True-128": 0.058448, + "flash_attn:1-128-4096-16-4-128-False-True-True-0": 0.1976, + "flash_attn:1-128-4096-16-4-128-False-True-True-128": 0.206336, + "flash_attn:1-128-1024-16-8-128-False-True-True-0": 0.057148, + "flash_attn:1-128-1024-16-8-128-False-True-True-128": 0.060268, + "flash_attn:1-128-4096-16-8-128-False-True-True-0": 0.201864, + "flash_attn:1-128-4096-16-8-128-False-True-True-128": 0.21476, + "flash_attn:1-128-1024-16-4-256-False-True-True-0": 0.112008, + "flash_attn:1-128-1024-16-4-256-False-True-True-128": 0.112944, + "flash_attn:1-128-4096-16-4-256-False-True-True-0": 0.416052, + "flash_attn:1-128-4096-16-4-256-False-True-True-128": 0.422396, + 
"flash_attn:1-128-1024-16-8-256-False-True-True-0": 0.111384, + "flash_attn:1-128-1024-16-8-256-False-True-True-128": 0.111904, + "flash_attn:1-128-4096-16-8-256-False-True-True-0": 0.410124, + "flash_attn:1-128-4096-16-8-256-False-True-True-128": 0.413816, + "flash_attn:1-128-1024-16-4-512-False-True-True-0": 2.17698, + "flash_attn:1-128-1024-16-4-512-False-True-True-128": 2.18059, + "flash_attn:1-128-4096-16-4-512-False-True-True-0": 8.55988, + "flash_attn:1-128-4096-16-4-512-False-True-True-128": 8.563, + "flash_attn:1-128-1024-16-8-512-False-True-True-0": 2.20077, + "flash_attn:1-128-1024-16-8-512-False-True-True-128": 2.21244, + "flash_attn:1-128-4096-16-8-512-False-True-True-0": 8.67906, + "flash_attn:1-128-4096-16-8-512-False-True-True-128": 8.64942, + "flash_attn:8-128-1024-16-4-64-False-True-True-0": 0.157716, + "flash_attn:8-128-1024-16-4-64-False-True-True-128": 0.167076, + "flash_attn:8-128-4096-16-4-64-False-True-True-0": 0.585052, + "flash_attn:8-128-4096-16-4-64-False-True-True-128": 0.612898, + "flash_attn:8-128-1024-16-8-64-False-True-True-0": 0.16276, + "flash_attn:8-128-1024-16-8-64-False-True-True-128": 0.169156, + "flash_attn:8-128-4096-16-8-64-False-True-True-0": 0.5941, + "flash_attn:8-128-4096-16-8-64-False-True-True-128": 0.622284, + "flash_attn:8-128-1024-16-4-128-False-True-True-0": 0.242138, + "flash_attn:8-128-1024-16-4-128-False-True-True-128": 0.252642, + "flash_attn:8-128-4096-16-4-128-False-True-True-0": 0.882102, + "flash_attn:8-128-4096-16-4-128-False-True-True-128": 0.92144, + "flash_attn:8-128-1024-16-8-128-False-True-True-0": 0.248404, + "flash_attn:8-128-1024-16-8-128-False-True-True-128": 0.262288, + "flash_attn:8-128-4096-16-8-128-False-True-True-0": 0.922376, + "flash_attn:8-128-4096-16-8-128-False-True-True-128": 0.973232, + "flash_attn:8-128-1024-16-4-256-False-True-True-0": 0.749944, + "flash_attn:8-128-1024-16-4-256-False-True-True-128": 0.753428, + "flash_attn:8-128-4096-16-4-256-False-True-True-0": 2.86793, + "flash_attn:8-128-4096-16-4-256-False-True-True-128": 2.89219, + "flash_attn:8-128-1024-16-8-256-False-True-True-0": 0.744796, + "flash_attn:8-128-1024-16-8-256-False-True-True-128": 0.746668, + "flash_attn:8-128-4096-16-8-256-False-True-True-0": 2.84011, + "flash_attn:8-128-4096-16-8-256-False-True-True-128": 2.85332, + "flash_attn:8-128-1024-16-4-512-False-True-True-0": 16.5562, + "flash_attn:8-128-1024-16-4-512-False-True-True-128": 16.6114, + "flash_attn:8-128-4096-16-4-512-False-True-True-0": 66.8192, + "flash_attn:8-128-4096-16-4-512-False-True-True-128": 66.7389, + "flash_attn:8-128-1024-16-8-512-False-True-True-0": 16.7078, + "flash_attn:8-128-1024-16-8-512-False-True-True-128": 16.6765, + "flash_attn:8-128-4096-16-8-512-False-True-True-0": 66.6175, + "flash_attn:8-128-4096-16-8-512-False-True-True-128": 67.839, + "flash_attn:16-128-1024-16-4-64-False-True-True-0": 0.312078, + "flash_attn:16-128-1024-16-4-64-False-True-True-128": 0.324376, + "flash_attn:16-128-4096-16-4-64-False-True-True-0": 1.16966, + "flash_attn:16-128-4096-16-4-64-False-True-True-128": 1.20806, + "flash_attn:16-128-1024-16-8-64-False-True-True-0": 0.314704, + "flash_attn:16-128-1024-16-8-64-False-True-True-128": 0.329992, + "flash_attn:16-128-4096-16-8-64-False-True-True-0": 1.18503, + "flash_attn:16-128-4096-16-8-64-False-True-True-128": 1.23079, + "flash_attn:16-128-1024-16-4-128-False-True-True-0": 0.41132, + "flash_attn:16-128-1024-16-4-128-False-True-True-128": 0.43446, + "flash_attn:16-128-4096-16-4-128-False-True-True-0": 1.58444, + 
"flash_attn:16-128-4096-16-4-128-False-True-True-128": 1.63836, + "flash_attn:16-128-1024-16-8-128-False-True-True-0": 0.433056, + "flash_attn:16-128-1024-16-8-128-False-True-True-128": 0.451308, + "flash_attn:16-128-4096-16-8-128-False-True-True-0": 1.65989, + "flash_attn:16-128-4096-16-8-128-False-True-True-128": 1.73207, + "flash_attn:16-128-1024-16-4-256-False-True-True-0": 1.38824, + "flash_attn:16-128-1024-16-4-256-False-True-True-128": 1.3974, + "flash_attn:16-128-4096-16-4-256-False-True-True-0": 5.38122, + "flash_attn:16-128-4096-16-4-256-False-True-True-128": 5.36734, + "flash_attn:16-128-1024-16-8-256-False-True-True-0": 1.38632, + "flash_attn:16-128-1024-16-8-256-False-True-True-128": 1.38902, + "flash_attn:16-128-4096-16-8-256-False-True-True-0": 5.33083, + "flash_attn:16-128-4096-16-8-256-False-True-True-128": 5.40368, + "flash_attn:16-128-1024-16-4-512-False-True-True-0": 32.6949, + "flash_attn:16-128-1024-16-4-512-False-True-True-128": 33.0126, + "flash_attn:16-128-4096-16-4-512-False-True-True-0": 130.693, + "flash_attn:16-128-4096-16-4-512-False-True-True-128": 131.444, + "flash_attn:16-128-1024-16-8-512-False-True-True-0": 32.7242, + "flash_attn:16-128-1024-16-8-512-False-True-True-128": 33.1765, + "flash_attn:16-128-4096-16-8-512-False-True-True-0": 132.025, + "flash_attn:16-128-4096-16-8-512-False-True-True-128": 131.914, + "flash_attn:1-128-1024-16-4-64-False-True-False-0": 0.072696, + "flash_attn:1-128-1024-16-4-64-False-True-False-128": 0.077012, + "flash_attn:1-128-4096-16-4-64-False-True-False-0": 0.268112, + "flash_attn:1-128-4096-16-4-64-False-True-False-128": 0.282152, + "flash_attn:1-128-1024-16-8-64-False-True-False-0": 0.072904, + "flash_attn:1-128-1024-16-8-64-False-True-False-128": 0.076856, + "flash_attn:1-128-4096-16-8-64-False-True-False-0": 0.267384, + "flash_attn:1-128-4096-16-8-64-False-True-False-128": 0.28444, + "flash_attn:1-128-1024-16-4-128-False-True-False-0": 0.054652, + "flash_attn:1-128-1024-16-4-128-False-True-False-128": 0.05798, + "flash_attn:1-128-4096-16-4-128-False-True-False-0": 0.194142, + "flash_attn:1-128-4096-16-4-128-False-True-False-128": 0.208572, + "flash_attn:1-128-1024-16-8-128-False-True-False-0": 0.055224, + "flash_attn:1-128-1024-16-8-128-False-True-False-128": 0.058344, + "flash_attn:1-128-4096-16-8-128-False-True-False-0": 0.194506, + "flash_attn:1-128-4096-16-8-128-False-True-False-128": 0.20696, + "flash_attn:1-128-1024-16-4-256-False-True-False-0": 0.110396, + "flash_attn:1-128-1024-16-4-256-False-True-False-128": 0.1131, + "flash_attn:1-128-4096-16-4-256-False-True-False-0": 0.40924, + "flash_attn:1-128-4096-16-4-256-False-True-False-128": 0.422994, + "flash_attn:1-128-1024-16-8-256-False-True-False-0": 0.109668, + "flash_attn:1-128-1024-16-8-256-False-True-False-128": 0.112008, + "flash_attn:1-128-4096-16-8-256-False-True-False-0": 0.405704, + "flash_attn:1-128-4096-16-8-256-False-True-False-128": 0.41392, + "flash_attn:1-128-1024-16-4-512-False-True-False-0": 2.1737, + "flash_attn:1-128-1024-16-4-512-False-True-False-128": 2.18834, + "flash_attn:1-128-4096-16-4-512-False-True-False-0": 8.53273, + "flash_attn:1-128-4096-16-4-512-False-True-False-128": 8.58515, + "flash_attn:1-128-1024-16-8-512-False-True-False-0": 2.1971, + "flash_attn:1-128-1024-16-8-512-False-True-False-128": 2.20155, + "flash_attn:1-128-4096-16-8-512-False-True-False-0": 8.67308, + "flash_attn:1-128-4096-16-8-512-False-True-False-128": 8.66268, + "flash_attn:8-128-1024-16-4-64-False-True-False-0": 0.158132, + 
"flash_attn:8-128-1024-16-4-64-False-True-False-128": 0.165308, + "flash_attn:8-128-4096-16-4-64-False-True-False-0": 0.579904, + "flash_attn:8-128-4096-16-4-64-False-True-False-128": 0.612794, + "flash_attn:8-128-1024-16-8-64-False-True-False-0": 0.161928, + "flash_attn:8-128-1024-16-8-64-False-True-False-128": 0.165776, + "flash_attn:8-128-4096-16-8-64-False-True-False-0": 0.593008, + "flash_attn:8-128-4096-16-8-64-False-True-False-128": 0.6162, + "flash_attn:8-128-1024-16-4-128-False-True-False-0": 0.238732, + "flash_attn:8-128-1024-16-4-128-False-True-False-128": 0.253344, + "flash_attn:8-128-4096-16-4-128-False-True-False-0": 0.860002, + "flash_attn:8-128-4096-16-4-128-False-True-False-128": 0.910624, + "flash_attn:8-128-1024-16-8-128-False-True-False-0": 0.250198, + "flash_attn:8-128-1024-16-8-128-False-True-False-128": 0.259428, + "flash_attn:8-128-4096-16-8-128-False-True-False-0": 0.921128, + "flash_attn:8-128-4096-16-8-128-False-True-False-128": 0.961038, + "flash_attn:8-128-1024-16-4-256-False-True-False-0": 0.737984, + "flash_attn:8-128-1024-16-4-256-False-True-False-128": 0.753532, + "flash_attn:8-128-4096-16-4-256-False-True-False-0": 2.80314, + "flash_attn:8-128-4096-16-4-256-False-True-False-128": 2.89687, + "flash_attn:8-128-1024-16-8-256-False-True-False-0": 0.732732, + "flash_attn:8-128-1024-16-8-256-False-True-False-128": 0.747188, + "flash_attn:8-128-4096-16-8-256-False-True-False-0": 2.7937, + "flash_attn:8-128-4096-16-8-256-False-True-False-128": 2.8554, + "flash_attn:8-128-1024-16-4-512-False-True-False-0": 16.5033, + "flash_attn:8-128-1024-16-4-512-False-True-False-128": 16.6317, + "flash_attn:8-128-4096-16-4-512-False-True-False-0": 66.3594, + "flash_attn:8-128-4096-16-4-512-False-True-False-128": 66.9647, + "flash_attn:8-128-1024-16-8-512-False-True-False-0": 16.7118, + "flash_attn:8-128-1024-16-8-512-False-True-False-128": 16.7786, + "flash_attn:8-128-4096-16-8-512-False-True-False-0": 66.3964, + "flash_attn:8-128-4096-16-8-512-False-True-False-128": 67.024, + "flash_attn:16-128-1024-16-4-64-False-True-False-0": 0.311948, + "flash_attn:16-128-1024-16-4-64-False-True-False-128": 0.327184, + "flash_attn:16-128-4096-16-4-64-False-True-False-0": 1.14088, + "flash_attn:16-128-4096-16-4-64-False-True-False-128": 1.20892, + "flash_attn:16-128-1024-16-8-64-False-True-False-0": 0.316316, + "flash_attn:16-128-1024-16-8-64-False-True-False-128": 0.327522, + "flash_attn:16-128-4096-16-8-64-False-True-False-0": 1.17754, + "flash_attn:16-128-4096-16-8-64-False-True-False-128": 1.22213, + "flash_attn:16-128-1024-16-4-128-False-True-False-0": 0.418002, + "flash_attn:16-128-1024-16-4-128-False-True-False-128": 0.433342, + "flash_attn:16-128-4096-16-4-128-False-True-False-0": 1.56832, + "flash_attn:16-128-4096-16-4-128-False-True-False-128": 1.63706, + "flash_attn:16-128-1024-16-8-128-False-True-False-0": 0.434304, + "flash_attn:16-128-1024-16-8-128-False-True-False-128": 0.450346, + "flash_attn:16-128-4096-16-8-128-False-True-False-0": 1.60371, + "flash_attn:16-128-4096-16-8-128-False-True-False-128": 1.71813, + "flash_attn:16-128-1024-16-4-256-False-True-False-0": 1.37249, + "flash_attn:16-128-1024-16-4-256-False-True-False-128": 1.39334, + "flash_attn:16-128-4096-16-4-256-False-True-False-0": 5.25457, + "flash_attn:16-128-4096-16-4-256-False-True-False-128": 5.36754, + "flash_attn:16-128-1024-16-8-256-False-True-False-0": 1.37023, + "flash_attn:16-128-1024-16-8-256-False-True-False-128": 1.38564, + "flash_attn:16-128-4096-16-8-256-False-True-False-0": 5.23965, + 
"flash_attn:16-128-4096-16-8-256-False-True-False-128": 5.33603, + "flash_attn:16-128-1024-16-4-512-False-True-False-0": 32.5828, + "flash_attn:16-128-1024-16-4-512-False-True-False-128": 32.6596, + "flash_attn:16-128-4096-16-4-512-False-True-False-0": 130.409, + "flash_attn:16-128-4096-16-4-512-False-True-False-128": 131.04, + "flash_attn:16-128-1024-16-8-512-False-True-False-0": 32.8279, + "flash_attn:16-128-1024-16-8-512-False-True-False-128": 32.9852, + "flash_attn:16-128-4096-16-8-512-False-True-False-0": 132.954, + "flash_attn:16-128-4096-16-8-512-False-True-False-128": 132.777, + "flash_attn:1-128-1024-16-4-64-False-False-True-0": 0.053196, + "flash_attn:1-128-1024-16-4-64-False-False-True-128": 0.056784, + "flash_attn:1-128-4096-16-4-64-False-False-True-0": 0.188864, + "flash_attn:1-128-4096-16-4-64-False-False-True-128": 0.204724, + "flash_attn:1-128-1024-16-8-64-False-False-True-0": 0.053404, + "flash_attn:1-128-1024-16-8-64-False-False-True-128": 0.056212, + "flash_attn:1-128-4096-16-8-64-False-False-True-0": 0.187824, + "flash_attn:1-128-4096-16-8-64-False-False-True-128": 0.200538, + "flash_attn:1-128-1024-16-4-128-False-False-True-0": 0.0442, + "flash_attn:1-128-1024-16-4-128-False-False-True-128": 0.047112, + "flash_attn:1-128-4096-16-4-128-False-False-True-0": 0.150696, + "flash_attn:1-128-4096-16-4-128-False-False-True-128": 0.162396, + "flash_attn:1-128-1024-16-8-128-False-False-True-0": 0.04706, + "flash_attn:1-128-1024-16-8-128-False-False-True-128": 0.047684, + "flash_attn:1-128-4096-16-8-128-False-False-True-0": 0.162292, + "flash_attn:1-128-4096-16-8-128-False-False-True-128": 0.165308, + "flash_attn:1-128-1024-16-4-256-False-False-True-0": 0.102804, + "flash_attn:1-128-1024-16-4-256-False-False-True-128": 0.103896, + "flash_attn:1-128-4096-16-4-256-False-False-True-0": 0.377416, + "flash_attn:1-128-4096-16-4-256-False-False-True-128": 0.3796, + "flash_attn:1-128-1024-16-8-256-False-False-True-0": 0.103194, + "flash_attn:1-128-1024-16-8-256-False-False-True-128": 0.102856, + "flash_attn:1-128-4096-16-8-256-False-False-True-0": 0.3757, + "flash_attn:1-128-4096-16-8-256-False-False-True-128": 0.375856, + "flash_attn:1-128-1024-16-4-512-False-False-True-0": 2.16557, + "flash_attn:1-128-1024-16-4-512-False-False-True-128": 2.17682, + "flash_attn:1-128-4096-16-4-512-False-False-True-0": 8.47855, + "flash_attn:1-128-4096-16-4-512-False-False-True-128": 8.49945, + "flash_attn:1-128-1024-16-8-512-False-False-True-0": 2.18312, + "flash_attn:1-128-1024-16-8-512-False-False-True-128": 2.20584, + "flash_attn:1-128-4096-16-8-512-False-False-True-0": 8.55473, + "flash_attn:1-128-4096-16-8-512-False-False-True-128": 8.65358, + "flash_attn:8-128-1024-16-4-64-False-False-True-0": 0.121056, + "flash_attn:8-128-1024-16-4-64-False-False-True-128": 0.12454, + "flash_attn:8-128-4096-16-4-64-False-False-True-0": 0.430456, + "flash_attn:8-128-4096-16-4-64-False-False-True-128": 0.452764, + "flash_attn:8-128-1024-16-8-64-False-False-True-0": 0.12298, + "flash_attn:8-128-1024-16-8-64-False-False-True-128": 0.126568, + "flash_attn:8-128-4096-16-8-64-False-False-True-0": 0.434044, + "flash_attn:8-128-4096-16-8-64-False-False-True-128": 0.455624, + "flash_attn:8-128-1024-16-4-128-False-False-True-0": 0.200096, + "flash_attn:8-128-1024-16-4-128-False-False-True-128": 0.21372, + "flash_attn:8-128-4096-16-4-128-False-False-True-0": 0.723996, + "flash_attn:8-128-4096-16-4-128-False-False-True-128": 0.788892, + "flash_attn:8-128-1024-16-8-128-False-False-True-0": 0.21138, + 
"flash_attn:8-128-1024-16-8-128-False-False-True-128": 0.221624, + "flash_attn:8-128-4096-16-8-128-False-False-True-0": 0.763308, + "flash_attn:8-128-4096-16-8-128-False-False-True-128": 0.821418, + "flash_attn:8-128-1024-16-4-256-False-False-True-0": 0.684216, + "flash_attn:8-128-1024-16-4-256-False-False-True-128": 0.68796, + "flash_attn:8-128-4096-16-4-256-False-False-True-0": 2.60026, + "flash_attn:8-128-4096-16-4-256-False-False-True-128": 2.62883, + "flash_attn:8-128-1024-16-8-256-False-False-True-0": 0.680784, + "flash_attn:8-128-1024-16-8-256-False-False-True-128": 0.681564, + "flash_attn:8-128-4096-16-8-256-False-False-True-0": 2.60458, + "flash_attn:8-128-4096-16-8-256-False-False-True-128": 2.60354, + "flash_attn:8-128-1024-16-4-512-False-False-True-0": 16.4901, + "flash_attn:8-128-1024-16-4-512-False-False-True-128": 16.5829, + "flash_attn:8-128-4096-16-4-512-False-False-True-0": 66.5631, + "flash_attn:8-128-4096-16-4-512-False-False-True-128": 67.5802, + "flash_attn:8-128-1024-16-8-512-False-False-True-0": 16.5468, + "flash_attn:8-128-1024-16-8-512-False-False-True-128": 16.7509, + "flash_attn:8-128-4096-16-8-512-False-False-True-0": 66.5583, + "flash_attn:8-128-4096-16-8-512-False-False-True-128": 66.7109, + "flash_attn:16-128-1024-16-4-64-False-False-True-0": 0.228644, + "flash_attn:16-128-1024-16-4-64-False-False-True-128": 0.243256, + "flash_attn:16-128-4096-16-4-64-False-False-True-0": 0.834886, + "flash_attn:16-128-4096-16-4-64-False-False-True-128": 0.882882, + "flash_attn:16-128-1024-16-8-64-False-False-True-0": 0.233636, + "flash_attn:16-128-1024-16-8-64-False-False-True-128": 0.250536, + "flash_attn:16-128-4096-16-8-64-False-False-True-0": 0.846508, + "flash_attn:16-128-4096-16-8-64-False-False-True-128": 0.904696, + "flash_attn:16-128-1024-16-4-128-False-False-True-0": 0.349388, + "flash_attn:16-128-1024-16-4-128-False-False-True-128": 0.372788, + "flash_attn:16-128-4096-16-4-128-False-False-True-0": 1.31612, + "flash_attn:16-128-4096-16-4-128-False-False-True-128": 1.4275, + "flash_attn:16-128-1024-16-8-128-False-False-True-0": 0.372112, + "flash_attn:16-128-1024-16-8-128-False-False-True-128": 0.388128, + "flash_attn:16-128-4096-16-8-128-False-False-True-0": 1.38939, + "flash_attn:16-128-4096-16-8-128-False-False-True-128": 1.48268, + "flash_attn:16-128-1024-16-4-256-False-False-True-0": 1.2655, + "flash_attn:16-128-1024-16-4-256-False-False-True-128": 1.27343, + "flash_attn:16-128-4096-16-4-256-False-False-True-0": 4.91618, + "flash_attn:16-128-4096-16-4-256-False-False-True-128": 4.95201, + "flash_attn:16-128-1024-16-8-256-False-False-True-0": 1.2754, + "flash_attn:16-128-1024-16-8-256-False-False-True-128": 1.27218, + "flash_attn:16-128-4096-16-8-256-False-False-True-0": 4.94005, + "flash_attn:16-128-4096-16-8-256-False-False-True-128": 4.9778, + "flash_attn:16-128-1024-16-4-512-False-False-True-0": 32.5958, + "flash_attn:16-128-1024-16-4-512-False-False-True-128": 32.712, + "flash_attn:16-128-4096-16-4-512-False-False-True-0": 130.326, + "flash_attn:16-128-4096-16-4-512-False-False-True-128": 131.142, + "flash_attn:16-128-1024-16-8-512-False-False-True-0": 32.9821, + "flash_attn:16-128-1024-16-8-512-False-False-True-128": 32.9599, + "flash_attn:16-128-4096-16-8-512-False-False-True-0": 132.752, + "flash_attn:16-128-4096-16-8-512-False-False-True-128": 130.963, + "flash_attn:1-1-1024-16-4-64-False-False-False-0": 0.051064, + "flash_attn:1-1-1024-16-4-64-False-False-False-128": 0.020852, + "flash_attn:1-1-4096-16-4-64-False-False-False-0": 0.183066, + 
"flash_attn:1-1-4096-16-4-64-False-False-False-128": 0.053196, + "flash_attn:1-1-1024-16-8-64-False-False-False-0": 0.052208, + "flash_attn:1-1-1024-16-8-64-False-False-False-128": 0.022932, + "flash_attn:1-1-4096-16-8-64-False-False-False-0": 0.187148, + "flash_attn:1-1-4096-16-8-64-False-False-False-128": 0.060476, + "flash_attn:1-1-1024-16-4-128-False-False-False-0": 0.041912, + "flash_attn:1-1-1024-16-4-128-False-False-False-128": 0.02548, + "flash_attn:1-1-4096-16-4-128-False-False-False-0": 0.150644, + "flash_attn:1-1-4096-16-4-128-False-False-False-128": 0.069732, + "flash_attn:1-1-1024-16-8-128-False-False-False-0": 0.044512, + "flash_attn:1-1-1024-16-8-128-False-False-False-128": 0.027456, + "flash_attn:1-1-4096-16-8-128-False-False-False-0": 0.1599, + "flash_attn:1-1-4096-16-8-128-False-False-False-128": 0.079196, + "flash_attn:1-1-1024-16-4-256-False-False-False-0": 0.099788, + "flash_attn:1-1-1024-16-4-256-False-False-False-128": 0.034736, + "flash_attn:1-1-4096-16-4-256-False-False-False-0": 0.375648, + "flash_attn:1-1-4096-16-4-256-False-False-False-128": 0.108576, + "flash_attn:1-1-1024-16-8-256-False-False-False-0": 0.100984, + "flash_attn:1-1-1024-16-8-256-False-False-False-128": 0.043108, + "flash_attn:1-1-4096-16-8-256-False-False-False-0": 0.377962, + "flash_attn:1-1-4096-16-8-256-False-False-False-128": 0.133952, + "flash_attn:1-1-1024-16-4-512-False-False-False-0": 2.13161, + "flash_attn:1-1-1024-16-4-512-False-False-False-128": 0.060476, + "flash_attn:1-1-4096-16-4-512-False-False-False-0": 8.48344, + "flash_attn:1-1-4096-16-4-512-False-False-False-128": 0.203086, + "flash_attn:1-1-1024-16-8-512-False-False-False-0": 2.15332, + "flash_attn:1-1-1024-16-8-512-False-False-False-128": 0.068172, + "flash_attn:1-1-4096-16-8-512-False-False-False-0": 8.44485, + "flash_attn:1-1-4096-16-8-512-False-False-False-128": 0.20982, + "flash_attn:1-128-1024-16-4-64-False-False-False-0": 0.053248, + "flash_attn:1-128-1024-16-4-64-False-False-False-128": 0.056706, + "flash_attn:1-128-4096-16-4-64-False-False-False-0": 0.190684, + "flash_attn:1-128-4096-16-4-64-False-False-False-128": 0.2054, + "flash_attn:1-128-1024-16-8-64-False-False-False-0": 0.053144, + "flash_attn:1-128-1024-16-8-64-False-False-False-128": 0.056004, + "flash_attn:1-128-4096-16-8-64-False-False-False-0": 0.188656, + "flash_attn:1-128-4096-16-8-64-False-False-False-128": 0.200408, + "flash_attn:1-128-1024-16-4-128-False-False-False-0": 0.044148, + "flash_attn:1-128-1024-16-4-128-False-False-False-128": 0.047164, + "flash_attn:1-128-4096-16-4-128-False-False-False-0": 0.152048, + "flash_attn:1-128-4096-16-4-128-False-False-False-128": 0.16289, + "flash_attn:1-128-1024-16-8-128-False-False-False-0": 0.046436, + "flash_attn:1-128-1024-16-8-128-False-False-False-128": 0.047892, + "flash_attn:1-128-4096-16-8-128-False-False-False-0": 0.163228, + "flash_attn:1-128-4096-16-8-128-False-False-False-128": 0.16536, + "flash_attn:1-128-1024-16-4-256-False-False-False-0": 0.101036, + "flash_attn:1-128-1024-16-4-256-False-False-False-128": 0.103116, + "flash_attn:1-128-4096-16-4-256-False-False-False-0": 0.371956, + "flash_attn:1-128-4096-16-4-256-False-False-False-128": 0.379496, + "flash_attn:1-128-1024-16-8-256-False-False-False-0": 0.102128, + "flash_attn:1-128-1024-16-8-256-False-False-False-128": 0.102648, + "flash_attn:1-128-4096-16-8-256-False-False-False-0": 0.37336, + "flash_attn:1-128-4096-16-8-256-False-False-False-128": 0.375934, + "flash_attn:1-128-1024-16-4-512-False-False-False-0": 2.16377, + 
"flash_attn:1-128-1024-16-4-512-False-False-False-128": 2.17565, + "flash_attn:1-128-4096-16-4-512-False-False-False-0": 8.60387, + "flash_attn:1-128-4096-16-4-512-False-False-False-128": 8.53216, + "flash_attn:1-128-1024-16-8-512-False-False-False-0": 2.18686, + "flash_attn:1-128-1024-16-8-512-False-False-False-128": 2.18876, + "flash_attn:1-128-4096-16-8-512-False-False-False-0": 8.69586, + "flash_attn:1-128-4096-16-8-512-False-False-False-128": 8.66497, + "flash_attn:8-1-1024-16-4-64-False-False-False-0": 0.111488, + "flash_attn:8-1-1024-16-4-64-False-False-False-128": 0.028912, + "flash_attn:8-1-4096-16-4-64-False-False-False-0": 0.413712, + "flash_attn:8-1-4096-16-4-64-False-False-False-128": 0.090428, + "flash_attn:8-1-1024-16-8-64-False-False-False-0": 0.113776, + "flash_attn:8-1-1024-16-8-64-False-False-False-128": 0.047632, + "flash_attn:8-1-4096-16-8-64-False-False-False-0": 0.420004, + "flash_attn:8-1-4096-16-8-64-False-False-False-128": 0.162084, + "flash_attn:8-1-1024-16-4-128-False-False-False-0": 0.190632, + "flash_attn:8-1-1024-16-4-128-False-False-False-128": 0.04888, + "flash_attn:8-1-4096-16-4-128-False-False-False-0": 0.724984, + "flash_attn:8-1-4096-16-4-128-False-False-False-128": 0.164736, + "flash_attn:8-1-1024-16-8-128-False-False-False-0": 0.20228, + "flash_attn:8-1-1024-16-8-128-False-False-False-128": 0.087256, + "flash_attn:8-1-4096-16-8-128-False-False-False-0": 0.77116, + "flash_attn:8-1-4096-16-8-128-False-False-False-128": 0.311116, + "flash_attn:8-1-1024-16-4-256-False-False-False-0": 0.663364, + "flash_attn:8-1-1024-16-4-256-False-False-False-128": 0.088504, + "flash_attn:8-1-4096-16-4-256-False-False-False-0": 2.5858, + "flash_attn:8-1-4096-16-4-256-False-False-False-128": 0.31356, + "flash_attn:8-1-1024-16-8-256-False-False-False-0": 0.659438, + "flash_attn:8-1-1024-16-8-256-False-False-False-128": 0.16666, + "flash_attn:8-1-4096-16-8-256-False-False-False-0": 2.58206, + "flash_attn:8-1-4096-16-8-256-False-False-False-128": 0.612976, + "flash_attn:8-1-1024-16-4-512-False-False-False-0": 16.1952, + "flash_attn:8-1-1024-16-4-512-False-False-False-128": 0.17394, + "flash_attn:8-1-4096-16-4-512-False-False-False-0": 65.1957, + "flash_attn:8-1-4096-16-4-512-False-False-False-128": 0.623584, + "flash_attn:8-1-1024-16-8-512-False-False-False-0": 16.392, + "flash_attn:8-1-1024-16-8-512-False-False-False-128": 0.328588, + "flash_attn:8-1-4096-16-8-512-False-False-False-0": 65.042, + "flash_attn:8-1-4096-16-8-512-False-False-False-128": 1.25653, + "flash_attn:8-128-1024-16-4-64-False-False-False-0": 0.121004, + "flash_attn:8-128-1024-16-4-64-False-False-False-128": 0.124592, + "flash_attn:8-128-4096-16-4-64-False-False-False-0": 0.442468, + "flash_attn:8-128-4096-16-4-64-False-False-False-128": 0.45344, + "flash_attn:8-128-1024-16-8-64-False-False-False-0": 0.12298, + "flash_attn:8-128-1024-16-8-64-False-False-False-128": 0.126152, + "flash_attn:8-128-4096-16-8-64-False-False-False-0": 0.440804, + "flash_attn:8-128-4096-16-8-64-False-False-False-128": 0.458276, + "flash_attn:8-128-1024-16-4-128-False-False-False-0": 0.198744, + "flash_attn:8-128-1024-16-4-128-False-False-False-128": 0.213616, + "flash_attn:8-128-4096-16-4-128-False-False-False-0": 0.723372, + "flash_attn:8-128-4096-16-4-128-False-False-False-128": 0.793468, + "flash_attn:8-128-1024-16-8-128-False-False-False-0": 0.210496, + "flash_attn:8-128-1024-16-8-128-False-False-False-128": 0.222716, + "flash_attn:8-128-4096-16-8-128-False-False-False-0": 0.764998, + 
"flash_attn:8-128-4096-16-8-128-False-False-False-128": 0.825188, + "flash_attn:8-128-1024-16-4-256-False-False-False-0": 0.67379, + "flash_attn:8-128-1024-16-4-256-False-False-False-128": 0.686192, + "flash_attn:8-128-4096-16-4-256-False-False-False-0": 2.57234, + "flash_attn:8-128-4096-16-4-256-False-False-False-128": 2.61591, + "flash_attn:8-128-1024-16-8-256-False-False-False-0": 0.67145, + "flash_attn:8-128-1024-16-8-256-False-False-False-128": 0.680316, + "flash_attn:8-128-4096-16-8-256-False-False-False-0": 2.57546, + "flash_attn:8-128-4096-16-8-256-False-False-False-128": 2.60525, + "flash_attn:8-128-1024-16-4-512-False-False-False-0": 16.5026, + "flash_attn:8-128-1024-16-4-512-False-False-False-128": 16.5364, + "flash_attn:8-128-4096-16-4-512-False-False-False-0": 66.0768, + "flash_attn:8-128-4096-16-4-512-False-False-False-128": 67.0548, + "flash_attn:8-128-1024-16-8-512-False-False-False-0": 16.561, + "flash_attn:8-128-1024-16-8-512-False-False-False-128": 16.5769, + "flash_attn:8-128-4096-16-8-512-False-False-False-0": 66.4396, + "flash_attn:8-128-4096-16-8-512-False-False-False-128": 66.6831, + "flash_attn:16-1-1024-16-4-64-False-False-False-0": 0.214604, + "flash_attn:16-1-1024-16-4-64-False-False-False-128": 0.04758, + "flash_attn:16-1-4096-16-4-64-False-False-False-0": 0.81588, + "flash_attn:16-1-4096-16-4-64-False-False-False-128": 0.164164, + "flash_attn:16-1-1024-16-8-64-False-False-False-0": 0.219232, + "flash_attn:16-1-1024-16-8-64-False-False-False-128": 0.088036, + "flash_attn:16-1-4096-16-8-64-False-False-False-0": 0.831428, + "flash_attn:16-1-4096-16-8-64-False-False-False-128": 0.31226, + "flash_attn:16-1-1024-16-4-128-False-False-False-0": 0.335452, + "flash_attn:16-1-1024-16-4-128-False-False-False-128": 0.087776, + "flash_attn:16-1-4096-16-4-128-False-False-False-0": 1.2936, + "flash_attn:16-1-4096-16-4-128-False-False-False-128": 0.311532, + "flash_attn:16-1-1024-16-8-128-False-False-False-0": 0.352404, + "flash_attn:16-1-1024-16-8-128-False-False-False-128": 0.16354, + "flash_attn:16-1-4096-16-8-128-False-False-False-0": 1.36139, + "flash_attn:16-1-4096-16-8-128-False-False-False-128": 0.609024, + "flash_attn:16-1-1024-16-4-256-False-False-False-0": 1.22912, + "flash_attn:16-1-1024-16-4-256-False-False-False-128": 0.169, + "flash_attn:16-1-4096-16-4-256-False-False-False-0": 4.80152, + "flash_attn:16-1-4096-16-4-256-False-False-False-128": 0.627562, + "flash_attn:16-1-1024-16-8-256-False-False-False-0": 1.22078, + "flash_attn:16-1-1024-16-8-256-False-False-False-128": 0.319436, + "flash_attn:16-1-4096-16-8-256-False-False-False-0": 4.79996, + "flash_attn:16-1-4096-16-8-256-False-False-False-128": 1.2219, + "flash_attn:16-1-1024-16-4-512-False-False-False-0": 32.0168, + "flash_attn:16-1-1024-16-4-512-False-False-False-128": 0.33436, + "flash_attn:16-1-4096-16-4-512-False-False-False-0": 128.256, + "flash_attn:16-1-4096-16-4-512-False-False-False-128": 1.24114, + "flash_attn:16-1-1024-16-8-512-False-False-False-0": 32.2096, + "flash_attn:16-1-1024-16-8-512-False-False-False-128": 0.651404, + "flash_attn:16-1-4096-16-8-512-False-False-False-0": 130.912, + "flash_attn:16-1-4096-16-8-512-False-False-False-128": 2.54992, + "flash_attn:16-128-1024-16-4-64-False-False-False-0": 0.2288, + "flash_attn:16-128-1024-16-4-64-False-False-False-128": 0.242788, + "flash_attn:16-128-4096-16-4-64-False-False-False-0": 0.838344, + "flash_attn:16-128-4096-16-4-64-False-False-False-128": 0.883168, + "flash_attn:16-128-1024-16-8-64-False-False-False-0": 0.233948, + 
"flash_attn:16-128-1024-16-8-64-False-False-False-128": 0.248534, + "flash_attn:16-128-4096-16-8-64-False-False-False-0": 0.851604, + "flash_attn:16-128-4096-16-8-64-False-False-False-128": 0.904956, + "flash_attn:16-128-1024-16-4-128-False-False-False-0": 0.347698, + "flash_attn:16-128-1024-16-4-128-False-False-False-128": 0.376012, + "flash_attn:16-128-4096-16-4-128-False-False-False-0": 1.31674, + "flash_attn:16-128-4096-16-4-128-False-False-False-128": 1.43052, + "flash_attn:16-128-1024-16-8-128-False-False-False-0": 0.3705, + "flash_attn:16-128-1024-16-8-128-False-False-False-128": 0.391352, + "flash_attn:16-128-4096-16-8-128-False-False-False-0": 1.38679, + "flash_attn:16-128-4096-16-8-128-False-False-False-128": 1.49074, + "flash_attn:16-128-1024-16-4-256-False-False-False-0": 1.24987, + "flash_attn:16-128-1024-16-4-256-False-False-False-128": 1.27072, + "flash_attn:16-128-4096-16-4-256-False-False-False-0": 4.8405, + "flash_attn:16-128-4096-16-4-256-False-False-False-128": 4.96272, + "flash_attn:16-128-1024-16-8-256-False-False-False-0": 1.2648, + "flash_attn:16-128-1024-16-8-256-False-False-False-128": 1.26612, + "flash_attn:16-128-4096-16-8-256-False-False-False-0": 4.86148, + "flash_attn:16-128-4096-16-8-256-False-False-False-128": 4.96616, + "flash_attn:16-128-1024-16-4-512-False-False-False-0": 32.761, + "flash_attn:16-128-1024-16-4-512-False-False-False-128": 32.4472, + "flash_attn:16-128-4096-16-4-512-False-False-False-0": 131.23, + "flash_attn:16-128-4096-16-4-512-False-False-False-128": 130.807, + "flash_attn:16-128-1024-16-8-512-False-False-False-0": 32.9504, + "flash_attn:16-128-1024-16-8-512-False-False-False-128": 32.694, + "flash_attn:16-128-4096-16-8-512-False-False-False-0": 130.059, + "flash_attn:16-128-4096-16-8-512-False-False-False-128": 130.668 } diff --git a/benchmark/bench_flash_attn.py b/benchmark/bench_flash_attn.py index 49e26629..bc4a0007 100644 --- a/benchmark/bench_flash_attn.py +++ b/benchmark/bench_flash_attn.py @@ -59,18 +59,19 @@ def flash_attn_baseline( causal = [True, False] local = [True, False] use_sinks = [True, False] -batch_size = [16, 32] +batch_size = [1, 8, 16] q_seq_length_range = [1, 128] head_dim = [64, 128, 256, 512] num_heads_q = [16] -num_heads_kv = [2, 4, 8] -kv_seq_length_range = [4096, 16384] +num_heads_kv = [4, 8] +kv_seq_length_range = [1024, 4096] page_size_range = [0, 128] configs = list( filter( lambda cfg: not (cfg[0] and cfg[1]) and (cfg[4] != 1 or (not cfg[0] and not cfg[1] and not cfg[2])) - and (cfg[6] % cfg[7] == 0), + and (cfg[6] % cfg[7] == 0) + and (cfg[8] >= cfg[9]), product( causal, local, @@ -240,8 +241,9 @@ def benchmark( if __name__ == "__main__": benchmark.run(print_data=False) + print("Benchmark finished!") + import pandas as pd df = pd.DataFrame(all_results) print(df.to_markdown()) - print("Benchmark finished!") diff --git a/benchmark/update_baseline_from_log.py b/benchmark/update_baseline_from_log.py index 668268a6..a316d18a 100644 --- a/benchmark/update_baseline_from_log.py +++ b/benchmark/update_baseline_from_log.py @@ -1,8 +1,9 @@ import json +import os import re -def parse_benchmark_log(log_text: str) -> dict: +def parse_fused_moe_log(log_text: str) -> dict: lines = log_text.splitlines() start_idx = None for i, line in enumerate(lines): @@ -11,7 +12,7 @@ def parse_benchmark_log(log_text: str) -> dict: break if start_idx is None: - raise ValueError("Benchmark finished! not found") + raise ValueError("Benchmark finished! 
not found in fused_moe log") result = {} @@ -34,22 +35,72 @@ def parse_benchmark_log(log_text: str) -> dict: shard_intermediate_size = cols[5] ms = float(cols[-1]) + key = f"fused_moe:{num_tokens}-{num_experts}-{topk}-{hidden_size}-{shard_intermediate_size}" + result[key] = ms + + return result + + +def parse_flash_attn_log(log_text: str) -> dict: + lines = log_text.splitlines() + start_idx = None + for i, line in enumerate(lines): + if "Benchmark finished!" in line: + start_idx = i + break + + if start_idx is None: + raise ValueError("Benchmark finished! not found in flash_attn log") + + result = {} + + for line in lines[start_idx + 1 :]: + line = line.strip() + + if not line.startswith("|"): + continue + if re.match(r"\|\s*-+", line): + continue + if "batch" in line: + continue + + cols = [c.strip() for c in line.strip("|").split("|")] + + batch = cols[1] + q_seq_length = cols[2] + kv_seq_length = cols[3] + num_heads_q = cols[4] + num_heads_kv = cols[5] + head_dim = cols[6] + causal = cols[7] + local = cols[8] + use_sinks = cols[9] + page_size = cols[10] + ms = float(cols[-1]) + key = ( - f"{num_tokens}-{num_experts}-{topk}-{hidden_size}-{shard_intermediate_size}" + f"flash_attn:{batch}-{q_seq_length}-{kv_seq_length}" + f"-{num_heads_q}-{num_heads_kv}-{head_dim}" + f"-{causal}-{local}-{use_sinks}-{page_size}" ) result[key] = ms return result -def format_section(title, data): +def format_section(title, data, benchmark_type="fused_moe"): if not data: return f"### {title}\n\nNone\n" + if benchmark_type == "flash_attn": + header = "| config | log | baseline | ratio |" + else: + header = "| num_tokens - num_experts - topk - hidden_size - shard_intermediate_size | log | baseline | ratio |" + lines = [ f"### {title}", "", - "| num_tokens - num_experts - topk - hidden_size - shard_intermediate_size | log | baseline | ratio |", + header, "|---|---:|---:|---:|", ] for k, (l, b) in sorted(data.items()): @@ -81,18 +132,18 @@ def compare(log_data: dict, baseline: dict): return lower, higher, equal -def main(): +def process_log(log_file, parser, benchmark_type, baseline): + if not os.path.exists(log_file): + print(f"Warning: {log_file} not found, skipping {benchmark_type} benchmark") + return {}, {}, {} - with open("fused_moe.log") as f: + with open(log_file) as f: log_text = f.read() - data = parse_benchmark_log(log_text) - - with open("benchmark/baseline.json") as f: - baseline = json.load(f) - + data = parser(log_text) lower, higher, equal = compare(data, baseline) + print(f"\n=== {benchmark_type} ===") print("=== LOWER (log < baseline) ===") for k, (l, b) in lower.items(): ratio = l / b @@ -114,20 +165,75 @@ def main(): print("Collected benchmark data:") print(data) - pr_body = "\n".join( - [ - "## Benchmark Comparison", - "", - "_Ratio = log / baseline (lower is better)_", - "", - format_section("LOWER (log < baseline)", lower), - format_section("HIGHER (log > baseline)", higher), - format_section("EQUAL", equal), - ] - ) - - if lower: - for k, (l, _) in lower.items(): + return lower, higher, equal + + +def main(): + with open("benchmark/baseline.json") as f: + baseline = json.load(f) + + benchmarks = [ + ("fused_moe.log", parse_fused_moe_log, "fused_moe"), + ("flash.log", parse_flash_attn_log, "flash_attn"), + ] + + all_lower = {} + all_higher = {} + all_equal = {} + + for log_file, parser, benchmark_type in benchmarks: + lower, higher, equal = process_log(log_file, parser, benchmark_type, baseline) + all_lower.update(lower) + all_higher.update(higher) + all_equal.update(equal) + + # Separate 
+    # Separate results by type for formatting
+    fused_moe_lower = {
+        k: v for k, v in all_lower.items() if not k.startswith("flash_attn:")
+    }
+    fused_moe_higher = {
+        k: v for k, v in all_higher.items() if not k.startswith("flash_attn:")
+    }
+    fused_moe_equal = {
+        k: v for k, v in all_equal.items() if not k.startswith("flash_attn:")
+    }
+    flash_attn_lower = {
+        k: v for k, v in all_lower.items() if k.startswith("flash_attn:")
+    }
+    flash_attn_higher = {
+        k: v for k, v in all_higher.items() if k.startswith("flash_attn:")
+    }
+    flash_attn_equal = {
+        k: v for k, v in all_equal.items() if k.startswith("flash_attn:")
+    }
+
+    sections = []
+    if fused_moe_lower or fused_moe_higher or fused_moe_equal:
+        sections.append("## Fused MoE Benchmark Comparison\n")
+        sections.append("_Ratio = log / baseline (lower is better)_\n")
+        sections.append(
+            format_section("LOWER (log < baseline)", fused_moe_lower, "fused_moe")
+        )
+        sections.append(
+            format_section("HIGHER (log > baseline)", fused_moe_higher, "fused_moe")
+        )
+        sections.append(format_section("EQUAL", fused_moe_equal, "fused_moe"))
+
+    if flash_attn_lower or flash_attn_higher or flash_attn_equal:
+        sections.append("## Flash Attention Benchmark Comparison\n")
+        sections.append("_Ratio = log / baseline (lower is better)_\n")
+        sections.append(
+            format_section("LOWER (log < baseline)", flash_attn_lower, "flash_attn")
+        )
+        sections.append(
+            format_section("HIGHER (log > baseline)", flash_attn_higher, "flash_attn")
+        )
+        sections.append(format_section("EQUAL", flash_attn_equal, "flash_attn"))
+
+    pr_body = "\n".join(sections) if sections else "## Benchmark Comparison\n\nNo data."
+
+    if all_lower:
+        for k, (l, _) in all_lower.items():
             baseline[k] = l
         with open("benchmark/baseline.json", "w") as f:
             json.dump(baseline, f, indent=4)
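Annotation (not part of the diff): note that the rewritten script only writes back entries collected in `all_lower`, so a baseline value can only ratchet downward when a run beats it; regressions are reported in the PR body but never relax the baseline. A self-contained sketch of that rule (the helper name is hypothetical):

```python
# Sketch of the baseline-update rule used above: only faster (lower-ms)
# results replace the stored baseline; slower results leave it untouched.
def update_baseline(baseline: dict, log_data: dict) -> dict:
    updated = dict(baseline)
    for key, ms in log_data.items():
        if key in baseline and ms < baseline[key]:
            updated[key] = ms
    return updated

baseline = {"fused_moe:1-64-8-3584-1280": 0.35152}
log_data = {"fused_moe:1-64-8-3584-1280": 0.325364}
assert update_baseline(baseline, log_data)["fused_moe:1-64-8-3584-1280"] == 0.325364
```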
diff --git a/include/sgl_flash_kernel_ops.h b/include/sgl_flash_kernel_ops.h
index 9c19853f..9c682245 100644
--- a/include/sgl_flash_kernel_ops.h
+++ b/include/sgl_flash_kernel_ops.h
@@ -1,4 +1,4 @@
-/* Copyright 2025 SGLang Team. All Rights Reserved.
+/* Copyright 2025-2026 SGLang Team. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ limitations under the License.
  * From flash-attention
  */
 std::vector<at::Tensor> mha_fwd(
-    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
     const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                           // h_k, d) if there is page_table.
     const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
@@ -70,7 +70,7 @@ std::vector<at::Tensor> mha_fwd(
     float const softcap,
     bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
     std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
-    int num_splits,
+    int num_kv_splits,
    std::optional<bool> pack_gqa_,
    int const sm_margin);
diff --git a/python/sgl_kernel/flash_attn.py b/python/sgl_kernel/flash_attn.py
index 7b7a9fbe..0880eb9c 100644
--- a/python/sgl_kernel/flash_attn.py
+++ b/python/sgl_kernel/flash_attn.py
@@ -262,7 +262,7 @@ def flash_attn_with_kvcache(
         cu_seqlens_q,
         cu_seqlens_k,
         max_seqlen_q,
-        max_seqlen_k,
+        1,
         page_table,
         cache_batch_idx,
         cache_leftpad,
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 7add37cb..42a9d835 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -20,6 +20,7 @@ foreach(file ${device_cpp})
 endforeach()

 include(FMHADecodeXe20.cmake)
+include(FMHAPrefillXe20.cmake)
 include(MlaDecodeXe20.cmake)

 message(STATUS "BMG files: ${device_cpp_xe20}")
diff --git a/src/FMHADecodeXe20.cmake b/src/FMHADecodeXe20.cmake
index 5f5dd41f..9dc5cf3d 100644
--- a/src/FMHADecodeXe20.cmake
+++ b/src/FMHADecodeXe20.cmake
@@ -2,22 +2,28 @@
 # Each (QG_SZ, HEAD_DIM, PAGE_SIZE) combination is compiled as a separate
 # library to parallelize and speed up compilation.

-set(FMHA_DECODE_QG_SIZES 1 2 4 8 16 32)
+set(FMHA_DECODE_QG_SIZES 1 2 4 8 16)
 set(FMHA_DECODE_HEAD_DIMS 64 96 128 192 256 512)
-set(FMHA_DECODE_PAGE_SIZES 32 64 128)
+set(FMHA_DECODE_PAGE_SIZES 64 128)

 set(FMHA_DECODE_TEMPLATE
     "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_decode_kernel.cpp.in")

+set(FMHA_SPLIT_DECODE_TEMPLATE
+    "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in")
+
 foreach(QG_SZ ${FMHA_DECODE_QG_SIZES})
   foreach(HEAD_DIM ${FMHA_DECODE_HEAD_DIMS})
     foreach(PAGE_SIZE ${FMHA_DECODE_PAGE_SIZES})
-      math(EXPR NUM_SG "${PAGE_SIZE} / 16")
-
       set(GENERATED_FILE
           "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp")
       configure_file(${FMHA_DECODE_TEMPLATE} ${GENERATED_FILE} @ONLY)
       list(APPEND device_cpp_common ${GENERATED_FILE})
+
+      set(GENERATED_SPLIT_FILE
+          "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_split_decode_kernel_${QG_SZ}_${HEAD_DIM}_${PAGE_SIZE}.cpp")
+      configure_file(${FMHA_SPLIT_DECODE_TEMPLATE} ${GENERATED_SPLIT_FILE} @ONLY)
+      list(APPEND device_cpp_common ${GENERATED_SPLIT_FILE})
     endforeach()
   endforeach()
 endforeach()
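Annotation (not part of the diff): the decode `.cmake` above now stamps out a split-KV variant next to each regular decode kernel. With the trimmed lists this is 5 QG sizes x 6 head dims x 2 page sizes = 60 combinations per template, 120 generated translation units total. A sketch of the grid; the filenames mirror the `configure_file` calls:

```python
# Enumerate the decode-kernel instantiation grid driven by FMHADecodeXe20.cmake.
from itertools import product

qg_sizes = [1, 2, 4, 8, 16]
head_dims = [64, 96, 128, 192, 256, 512]
page_sizes = [64, 128]

files = [
    f"xe_fmha_fwd_{kind}decode_kernel_{qg}_{hd}_{ps}.cpp"
    for kind, (qg, hd, ps) in product(
        ["", "split_"], product(qg_sizes, head_dims, page_sizes)
    )
]
assert len(files) == 120  # 60 plain + 60 split-KV translation units
```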
+
+set(FMHA_PREFILL_HEAD_DIMS 64 96 128 192 256 512)
+
+set(FMHA_PREFILL_TEMPLATE
+    "${CMAKE_CURRENT_SOURCE_DIR}/sycl/xe_fmha_fwd_prefill_kernel.cpp.in")
+
+# Per-HEAD_DIM tile shape parameters (TILED_Q, TILED_KV, NUM_SG)
+set(FMHA_PREFILL_TILED_Q_64 128)
+set(FMHA_PREFILL_TILED_KV_64 64)
+set(FMHA_PREFILL_NUM_SG_64 8)
+
+set(FMHA_PREFILL_TILED_Q_96 128)
+set(FMHA_PREFILL_TILED_KV_96 64)
+set(FMHA_PREFILL_NUM_SG_96 8)
+
+set(FMHA_PREFILL_TILED_Q_128 256)
+set(FMHA_PREFILL_TILED_KV_128 32)
+set(FMHA_PREFILL_NUM_SG_128 16)
+
+set(FMHA_PREFILL_TILED_Q_192 256)
+set(FMHA_PREFILL_TILED_KV_192 64)
+set(FMHA_PREFILL_NUM_SG_192 32)
+
+set(FMHA_PREFILL_TILED_Q_256 256)
+set(FMHA_PREFILL_TILED_KV_256 64)
+set(FMHA_PREFILL_NUM_SG_256 32)
+
+set(FMHA_PREFILL_TILED_Q_512 256)
+set(FMHA_PREFILL_TILED_KV_512 64)
+set(FMHA_PREFILL_NUM_SG_512 32)
+
+foreach(HEAD_DIM ${FMHA_PREFILL_HEAD_DIMS})
+  set(TILED_Q ${FMHA_PREFILL_TILED_Q_${HEAD_DIM}})
+  set(TILED_KV ${FMHA_PREFILL_TILED_KV_${HEAD_DIM}})
+  set(NUM_SG ${FMHA_PREFILL_NUM_SG_${HEAD_DIM}})
+
+  set(GENERATED_FILE
+      "${CMAKE_CURRENT_BINARY_DIR}/sycl/xe_fmha_fwd_prefill_kernel_${HEAD_DIM}.cpp")
+  configure_file(${FMHA_PREFILL_TEMPLATE} ${GENERATED_FILE} @ONLY)
+  list(APPEND device_cpp_common ${GENERATED_FILE})
+endforeach()
diff --git a/src/sycl/flash_attention.cpp b/src/sycl/flash_attention.cpp
index 7949234e..a2ce5d10 100644
--- a/src/sycl/flash_attention.cpp
+++ b/src/sycl/flash_attention.cpp
@@ -35,120 +35,417 @@
 #include
 #include
 
-#include
-
 #include "kernels/chunk_prefill/chunk_prefill_runner.hpp"
 #include "kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp"
-#include "kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp"
-#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp"
+#include "kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp"
 
 namespace decode {
 
-namespace {
+// Dispatch macros following the GroupGemmXe20.cpp pattern.
+// Directly call struct operator() - no function pointers.
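The macros that follow expand to nested switch statements mapping the runtime triple (q_group_size, head dim, page size) onto compile-time template parameters. A de-macroed sketch of the same pattern, with a stub runner and minimal types in place of the real ones from this file:

```cpp
// Stand-alone sketch of the switch-based dispatch the macros generate;
// Params and StubRunner are placeholders for the types in this file.
#include <stdexcept>

struct Params {
  int page_size;
};

template <int QG, int HD, int PS>
struct StubRunner {
  // In the real code this launches the (QG, HD, PS) kernel instantiation.
  void operator()(const Params&) const {}
};

template <int QG, int HD>
void dispatch_page_size(const Params& p) {
  switch (p.page_size) {
    case 64:
      StubRunner<QG, HD, 64>{}(p);  // runtime value becomes a template arg
      break;
    case 128:
      StubRunner<QG, HD, 128>{}(p);
      break;
    default:
      throw std::invalid_argument("unsupported page_size");
  }
}
```

Compared with the function-pointer table this replaces, the switch form keeps every call direct (inlinable) and lets the compiler diagnose any (QG, HD, PS) combination that was never instantiated.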
+
+#define DISPATCH_DECODE_KERNEL(QG, HD, PS)      \
+  do {                                          \
+    if (params.use_split_kv) {                  \
+      FmhaSplitDecodeRunner<QG, HD, PS>{}(params); \
+    } else {                                    \
+      FmhaDecodeRunner<QG, HD, PS>{}(params);   \
+    }                                           \
+  } while (0)
+
+#define DISPATCH_DECODE_PAGE_SIZE(QG, HD)       \
+  do {                                          \
+    switch (params.page_size) {                 \
+      case 64:                                  \
+        DISPATCH_DECODE_KERNEL(QG, HD, 64);     \
+        break;                                  \
+      case 128:                                 \
+        DISPATCH_DECODE_KERNEL(QG, HD, 128);    \
+        break;                                  \
+      default:                                  \
+        TORCH_CHECK(false, "Unsupported page_size for decode attention: ", params.page_size); \
+    }                                           \
+  } while (0)
+
+#define DISPATCH_DECODE_HEAD_DIM(QG)            \
+  do {                                          \
+    switch (params.d) {                         \
+      case 64:                                  \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 64);      \
+        break;                                  \
+      case 96:                                  \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 96);      \
+        break;                                  \
+      case 128:                                 \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 128);     \
+        break;                                  \
+      case 192:                                 \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 192);     \
+        break;                                  \
+      case 256:                                 \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 256);     \
+        break;                                  \
+      case 512:                                 \
+        DISPATCH_DECODE_PAGE_SIZE(QG, 512);     \
+        break;                                  \
+      default:                                  \
+        TORCH_CHECK(false, "Unsupported head size for decode attention: ", params.d); \
+    }                                           \
+  } while (0)
+
+#define DISPATCH_DECODE(qg_sz)                  \
+  do {                                          \
+    switch (qg_sz) {                            \
+      case 1:                                   \
+        DISPATCH_DECODE_HEAD_DIM(1);            \
+        break;                                  \
+      case 2:                                   \
+        DISPATCH_DECODE_HEAD_DIM(2);            \
+        break;                                  \
+      case 4:                                   \
+        DISPATCH_DECODE_HEAD_DIM(4);            \
+        break;                                  \
+      case 8:                                   \
+        DISPATCH_DECODE_HEAD_DIM(8);            \
+        break;                                  \
+      case 16:                                  \
+        DISPATCH_DECODE_HEAD_DIM(16);           \
+        break;                                  \
+      default:                                  \
+        TORCH_CHECK(false, "Unsupported q_group_size for decode attention: ", params.q_group_size); \
+    }                                           \
+  } while (0)
+
+std::vector<at::Tensor> mha_fwd(
+    const at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                          // h_k, d) if there is page_table.
+    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                          // page_size, h_k, dv) if there is page_table.
+    std::optional<at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    const at::Tensor& cu_seqlens_q,  // b+1
+    const at::Tensor& cu_seqlens_k,  // b+1
+    int max_seqlen_q,
+    int max_seqlen_k,
+    std::optional<at::Tensor>& page_table,  // (b_k, max_num_pages_per_seq)
+    std::optional<at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
+    std::optional<at::Tensor>& leftpad_k_,  // b
+    std::optional<at::Tensor>& rotary_cos_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor>& rotary_sin_,  // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor>& seqlens_rotary_,  // b
+    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
+    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
+    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
+    const float softmax_scale_,
+    std::optional<at::Tensor>& sinks_,
+    bool is_causal,
+    int window_size_left,
+    int window_size_right,
+    float const softcap,
+    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
+    int num_kv_splits,
+    std::optional<bool> pack_gqa_,
+    int const sm_margin) {
+  auto q_type = q.scalar_type();
+  TORCH_CHECK(
+      q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
+      "mha_fwd only supports Half and BFloat16, got ",
+      q_type);
+
+  TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
+  TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k);
+  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v);
-using launch_fn_t = void (*)(bool use_sink, const Arguments& params);
+  TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32");
+  TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension");
-#define LAUNCH_FN_ENTRY(QG, HD, PS) &launch_fmha_decode_##QG##_##HD##_##PS
+  TORCH_CHECK(q.dim() == 3, "query must be in ragged format");
+  CHECK_INPUT(cu_seqlens_q);
+  TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32");
-launch_fn_t get_launch_fn(int qg_sz, int head_dim, int page_size) {
-  // Dispatch table indexed by (qg_sz, head_dim, page_size).
-  // qg_sz index: {1->0, 2->1, 4->2, 8->3, 16->4, 32->5}
-  // head_dim index: {64->0, 96->1, 128->2, 192->3, 256->4, 512->5}
-  // page_size index: {32->0, 64->1, 128->2}
+  CHECK_INPUT(cu_seqlens_k);
+  TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32");
+  auto const sizes = q.sizes();
+  const int batch_size = cu_seqlens_q.size(0) - 1;
+  int seqlen_q = max_seqlen_q;
+  int total_q = q.size(0);
+  int num_heads = q.size(-2);
+  int const head_size = q.size(-1);
+  int const head_size_v = v.size(-1);
+  int const max_num_pages_per_seq = page_table.value().size(1);
+  int const num_pages = k.size(0);
+  int const page_size = k.size(1);
+  int const seqlen_k = page_table.has_value() ?
max_num_pages_per_seq * page_size : max_seqlen_k; + int const total_k = num_pages * page_size; + int const num_heads_k = k.size(-2); -#define HD_ENTRIES(QG) \ - { \ - PAGE_ENTRIES(QG, 64), PAGE_ENTRIES(QG, 96), PAGE_ENTRIES(QG, 128), PAGE_ENTRIES(QG, 192), PAGE_ENTRIES(QG, 256), \ - PAGE_ENTRIES(QG, 512) \ + int const batch_size_k = page_table.value().size(0); + float softmax_scale = softmax_scale_; + + if (!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); } - static const launch_fn_t table[6][6][3] = { - HD_ENTRIES(1), - HD_ENTRIES(2), - HD_ENTRIES(4), - HD_ENTRIES(8), - HD_ENTRIES(16), - HD_ENTRIES(32), - }; - -#undef HD_ENTRIES -#undef PAGE_ENTRIES - - int qg_idx = -1; - switch (qg_sz) { - case 1: - qg_idx = 0; - break; - case 2: - qg_idx = 1; - break; - case 4: - qg_idx = 2; - break; - case 8: - qg_idx = 3; - break; - case 16: - qg_idx = 4; - break; - case 32: - qg_idx = 5; - break; - default: - return nullptr; + // Currently only support head dims <= 512 + static constexpr int max_headdim = 512; + TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + window_size_right = min(window_size_right, seqlen_q); + // causal=true is the same as causal=false in this case + if (is_causal) { + window_size_right = 0; } - int hd_idx = -1; - switch (head_dim) { - case 64: - hd_idx = 0; - break; - case 96: - hd_idx = 1; - break; - case 128: - hd_idx = 2; - break; - case 192: - hd_idx = 3; - break; - case 256: - hd_idx = 4; - break; - case 512: - hd_idx = 5; - break; - default: - return nullptr; + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_INPUT(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); } - int ps_idx = -1; - switch (page_size) { - case 32: - ps_idx = 0; - break; - case 64: - ps_idx = 1; - break; - case 128: - ps_idx = 2; - break; - default: - return nullptr; + static constexpr int alignment = 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + at::Tensor out; + at::Tensor temp_out; // [batch, num_kv_splits, num_head_q, seq_q, head_size] + at::Tensor exp_sums; // [batch, num_head_q, seq_q, num_kv_splits] + at::Tensor max_logits; // [batch, num_head_q, seq_q, num_kv_splits] + num_kv_splits = -1; + out = torch::empty({total_q, num_heads, head_size_v}, opts); + Arguments params; + params.use_split_kv = true; + if (params.use_split_kv) { + auto get_num_splits = [](int batch_size, int num_heads_kv, int max_seqlen_k, int block_size) { + auto stream = at::xpu::getCurrentXPUStream(); + auto queue = stream.queue(); + auto device = queue.get_device(); + int num_xe_cores = device.get_info() * + device.get_info(); + int parallel_ = 
num_xe_cores; + int parallel_2 = num_xe_cores * 2; + int cur_parallel_d = batch_size * num_heads_kv; + int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d; + if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) { + num_splits = std::ceil(parallel_2 / static_cast(cur_parallel_d)) - 1; + } + + int max_splits = (max_seqlen_k + block_size - 1) / block_size; + max_splits = std::min(max_splits, parallel_); + return std::min(num_splits, max_splits); + }; + num_kv_splits = get_num_splits(batch_size, num_heads_k, seqlen_k, page_size); + temp_out = num_kv_splits == 1 + ? out + : torch::empty({total_q, num_kv_splits * num_heads, head_size_v}, q.options().device(q.device())); + + max_logits = torch::full( + {total_q, num_heads, num_kv_splits}, + -std::numeric_limits::infinity(), + q.options().dtype(at::kFloat).device(q.device())); + + exp_sums = torch::zeros({total_q, num_heads, num_kv_splits}, q.options().dtype(at::kFloat).device(q.device())); + + params.temp_out_ptr = temp_out.data_ptr(); + params.exp_sums_ptr = exp_sums.data_ptr(); + params.max_logits_ptr = max_logits.data_ptr(); + } + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + c10::DeviceGuard device_guard(q.device()); + + at::Tensor softmax_lse; + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + // align with FA3 + + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + + params.k_stride_page = k.stride(0); + params.k_stride_seq = k.stride(1); + params.k_stride_heads = k.stride(2); + params.v_stride_page = v.stride(0); + params.v_stride_seq = v.stride(1); + params.v_stride_heads = v.stride(2); + + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); + params.cu_seqlens_k = cu_seqlens_k.data_ptr(); + params.num_kv_splits = num_kv_splits; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.q_group_size = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.softmax_scale = softmax_scale; + params.use_sink = sinks_.has_value(); + params.softmax_sink_ptr = params.use_sink ? sinks_.value().data_ptr() : nullptr; + + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
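Under the convention the comment above describes (a negative bound means "unbounded on that side"), the two flags partition the window configurations as in this small sketch, which is illustrative and not code from the patch:

```cpp
// Window-flag classification used by this path (sketch): a key at
// position j is visible from query position i iff
//   i - window_size_left <= j <= i + window_size_right,
// with a negative bound meaning no limit on that side.
inline bool causal_window(int left, int right) { return left < 0 && right == 0; }
inline bool local_window(int left, int right) { return left >= 0 || right >= 0; }
```

Decode then forces `is_causal` to false below: the single query token already sits at the last key position, so the causal constraint is vacuous and only the local (sliding-window) case needs masking.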
+ params.is_causal = false; // Decode don't need causal mask since we only compute attention for the current token, but + // this kernel can also be used for local attention in the future + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.page_table = page_table.value().data_ptr(); + params.page_table_batch_stride = page_table.value().stride(0); + params.max_num_pages_per_seq = max_num_pages_per_seq; + params.page_size = page_size; + params.num_pages = num_pages; + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(false, "q_v is not supported yet"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. + params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + } + + if (rotary_cos_.has_value()) { + auto rotary_cos = rotary_cos_.value(); + CHECK_INPUT(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_INPUT(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_INPUT(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_INPUT(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); } - return table[qg_idx][hd_idx][ps_idx]; + params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); + + at::Tensor out_accum, softmax_lse_accum; + + int 
qg_sz = nextPowerOf2(params.q_group_size);
+  TORCH_CHECK(qg_sz >= 1 && qg_sz <= 16, "Unsupported q_group_size for decode attention: ", params.q_group_size);
+  TORCH_CHECK(
+      params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256 || params.d == 512,
+      "Unsupported head size for decode attention: ",
+      params.d);
+  TORCH_CHECK(
+      params.page_size == 64 || params.page_size == 128,
+      "Unsupported page size for decode attention: ",
+      params.page_size);
+
+  DISPATCH_DECODE(qg_sz);
+
+  return {out, softmax_lse, out_accum, softmax_lse_accum};
 }
 
-#undef LAUNCH_FN_ENTRY
+#undef DISPATCH_DECODE_KERNEL
+#undef DISPATCH_DECODE_PAGE_SIZE
+#undef DISPATCH_DECODE_HEAD_DIM
+#undef DISPATCH_DECODE
 
-}  // namespace
+}  // namespace decode
+
+namespace prefill {
+
+// Dispatch macro following the same pattern as decode.
+// Directly call struct operator() - no function pointers.
+
+#define DISPATCH_PREFILL_KERNEL(HD) FmhaPrefillRunner<HD>{}(params)
 
 std::vector<at::Tensor> mha_fwd(
-    at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    const at::Tensor& q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
     const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
                           // h_k, d) if there is page_table.
     const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
@@ -213,7 +510,6 @@ std::vector<at::Tensor> mha_fwd(
   int const seqlen_k = max_num_pages_per_seq * page_size;
   int const total_k = num_pages * page_size;
   int const num_heads_k = k.size(-2);
-  int q_group_size = num_heads / num_heads_k;
 
   int const batch_size_k = page_table.value().size(0);
   float softmax_scale = softmax_scale_;
@@ -225,7 +521,7 @@ std::vector<at::Tensor> mha_fwd(
   // Currently only support head dims <= 512
   static constexpr int max_headdim = 512;
   TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim);
-  TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
+  TORCH_CHECK(num_heads == num_heads_k, "Only support number of heads in key/value equal to number of heads in query");
 
   // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM
   // TODO: check this
@@ -298,16 +594,15 @@ std::vector<at::Tensor> mha_fwd(
   params.b = batch_size;
   params.h = num_heads;
   params.h_k = num_heads_k;
-  params.q_group_size = num_heads / num_heads_k;
-  params.seqlen_q = seqlen_q * q_group_size;
+  params.q_group_size = 1;
+  params.seqlen_q = seqlen_q;
   params.seqlen_k = seqlen_k;
   params.d = head_size;
   params.d_rounded = head_size_rounded;
 
   // Set the different scale values.
   params.softmax_scale = softmax_scale;
-  bool use_sink = sinks_.has_value();
-  params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr;
+  params.softmax_sink_ptr = sinks_.has_value() ?
sinks_.value().data_ptr() : nullptr; params.softcap = softcap; @@ -394,30 +689,44 @@ std::vector mha_fwd( params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - int qg_sz = nextPowerOf2(max_seqlen_q); - TORCH_CHECK(qg_sz >= 1 && qg_sz <= 32, "Unsupported qgroup_size for decode attention: ", max_seqlen_q); TORCH_CHECK( params.d == 64 || params.d == 96 || params.d == 128 || params.d == 192 || params.d == 256 || params.d == 512, - "Unsupported head size for decode attention: ", + "Unsupported head size for prefill attention: ", params.d); - TORCH_CHECK( - params.page_size == 32 || params.page_size == 64 || params.page_size == 128, - "Unsupported page size for decode attention: ", - params.page_size); - auto fn = get_launch_fn(qg_sz, params.d, params.page_size); - TORCH_CHECK(fn != nullptr, "No FMHA decode kernel for qg=", qg_sz, " hd=", params.d, " ps=", params.page_size); - fn(use_sink, params); + switch (params.d) { + case 64: + DISPATCH_PREFILL_KERNEL(64); + break; + case 96: + DISPATCH_PREFILL_KERNEL(96); + break; + case 128: + DISPATCH_PREFILL_KERNEL(128); + break; + case 192: + DISPATCH_PREFILL_KERNEL(192); + break; + case 256: + DISPATCH_PREFILL_KERNEL(256); + break; + case 512: + DISPATCH_PREFILL_KERNEL(512); + break; + default: + TORCH_CHECK(false, "Unsupported head size for prefill attention: ", params.d); + } return {out, softmax_lse, out_accum, softmax_lse_accum}; } -} // namespace decode +#undef DISPATCH_PREFILL_KERNEL + +} // namespace prefill std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -444,7 +753,7 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + int num_kv_splits, std::optional pack_gqa_, int const sm_margin) { TORCH_CHECK(cu_seqlens_k.data_ptr() != nullptr, "cu_seqlens_k is not valid."); @@ -477,7 +786,7 @@ std::vector mha_fwd( softcap, is_rotary_interleaved, scheduler_metadata_, - num_splits, + num_kv_splits, pack_gqa_, sm_margin); } else if ( @@ -511,7 +820,7 @@ std::vector mha_fwd( softcap, is_rotary_interleaved, scheduler_metadata_, - num_splits, + num_kv_splits, pack_gqa_, sm_margin); } else { @@ -542,7 +851,7 @@ std::vector mha_fwd( softcap, is_rotary_interleaved, scheduler_metadata_, - num_splits, + num_kv_splits, pack_gqa_, sm_margin); } diff --git a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp index c5469ade..b4c54e7c 100644 --- a/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp +++ b/src/sycl/kernels/chunk_prefill/chunk_prefill_runner.hpp @@ -37,6 +37,8 @@ struct Flash_fwd_params { // The number of heads. int h, h_k; + bool use_sink = false; + bool use_causal_mask = false; // The O matrix (output). 
void* __restrict__ o_ptr; @@ -140,7 +142,7 @@ struct Flash_fwd_params { bool is_rotary_interleaved; - int num_splits; // For split-KV version + int num_kv_splits; // For split-KV version bool pack_gqa; int* __restrict__ tile_count_semaphore; @@ -427,7 +429,7 @@ inline int round_up_headdim(int head_size) { } std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, // h_k, d) if there is page_table. const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, @@ -454,7 +456,7 @@ std::vector mha_fwd( float const softcap, bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 std::optional& scheduler_metadata_, // (b + 1) - int num_splits, + int num_kv_splits, std::optional pack_gqa_, int const sm_margin) { // TODO: check GPU support @@ -599,8 +601,8 @@ std::vector mha_fwd( // Set the different scale values. params.scale_softmax = softmax_scale; - bool use_sink = sinks_.has_value(); - params.sink_softmax = use_sink ? sinks_.value().data_ptr() : nullptr; + params.use_sink = sinks_.has_value(); + params.sink_softmax = params.use_sink ? sinks_.value().data_ptr() : nullptr; params.softcap = softcap; @@ -697,7 +699,7 @@ std::vector mha_fwd( constexpr int PipelineStages = 2; switch (params.d) { case 64: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape<_128, _64, _64>, @@ -725,7 +727,7 @@ std::vector mha_fwd( }) break; case 96: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape<_128, _64, _32>, @@ -754,7 +756,7 @@ std::vector mha_fwd( }) break; case 128: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape<_128, _64, _64>, @@ -782,7 +784,7 @@ std::vector mha_fwd( }) break; case 192: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape<_256, _64, _64>, @@ -810,7 +812,7 @@ std::vector mha_fwd( }) break; case 256: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape<_256, _64, _64>, @@ -838,7 +840,7 @@ std::vector mha_fwd( }) break; case 512: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { if (params.is_causal) { ChunkPrefillConfig< cute::Shape, _64, _64>, diff --git a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp index 26729f76..dba97ee0 100644 --- a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp +++ b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp @@ -139,7 +139,6 @@ class FMHAFwdEpilogue { FragARow& tA_sum, // Softmax row-wise sum accumulator QVCoord blk_qv, // WG tile indices: (q,v) int thr_id) { // Work-item ID - using namespace cute; using ElementA = typename FragA::element_type; @@ -282,4 +281,339 @@ class FMHAFwdEpilogue { } }; +template < + 
class CollectiveMainloop, // Attention mainloop + class TileShapeO_, // Shape of output tile, may be larger than P*V GEMM + class TensorO_, // 2D slice of global output tensor + class TensorLSE_ = void, // Optional tensor for storing intermediate exp + // sums and max logits + class TiledCopyO_ = void, // Optional TiledCopy for loading O + bool Sink_ = false> // Whether to sink softmax into epilogue +class DecodeFwdEpilogue { + public: + // + // Type Aliases + // + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + using TileShapeO = TileShapeO_; + using SGPerWG = decltype(product(take<1, 4>(shape(typename TiledMMAPV::ThrLayoutVMNK{})))); + + using TensorO = TensorO_; + using TensorO2D = decltype(TensorO_{}(append>(make_coord(_, _), 0))); + using ElementO = typename TensorO_::value_type; + + using TensorLSE = TensorLSE_; + using TensorLSE2D = conditional_t< + is_void_v, + void, + decltype(TensorLSE_{}(append>(make_coord(_, _), 0)))>; + using ElementLSE = conditional_t, void, typename TensorLSE_::value_type>; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + using ElementA = typename FragA::value_type; + + // softmax sink, same dtype + static constexpr bool Sink = Sink_; + using ElementSink = typename CollectiveMainloop::TensorQ::element_type; + + // Split k-reduced tiles between participating subgroups. + // Assumption: the A tile is contiguous. + using ReduceK = decltype(size<3>(typename TiledMMAPV::ThrLayoutVMNK{})); + + static auto reduce_sg_v_helper() { + constexpr auto v_total_sg = get<1>(SGTileShapeA{}) / intel::_SGSize{}; + constexpr auto v_avail_sg = ReduceK{} / ReduceSGQ{}; + return Int < (v_total_sg > v_avail_sg) ? cute::gcd(v_total_sg, v_avail_sg) : v_total_sg > {}; + } + + using SGTileShapeA = decltype(atuple_coshape(FragA{}.tv_layout())); + using ReduceSGQ = decltype(cute::gcd(get<0>(SGTileShapeA{}), ReduceK{})); + using ReduceSGV = decltype(reduce_sg_v_helper()); + using ReduceSGLayout = decltype(make_identity_layout(Shape{})); + + using SGTileShapeO = decltype(shape_div(take<0, 2>(SGTileShapeA{}), shape(ReduceSGLayout{}))); + + using ReduceFragA = + decltype(make_subgroup_tensor(make_layout(select<1, 0>(SGTileShapeO{}), Stride, E<0>>{}))); + using ReduceFragARow = decltype(reduce<1>(ReduceFragA{}, sycl::plus{})); + + static auto default_tiled_copy_O_helper() { + if constexpr (ReduceK{} == _1{}) + return make_block_2d_copy_D(TiledMMAPV{}, TensorO2D{}); + else + return make_block_2d_copy_D_subtiled(TiledMMAPV{}, ReduceFragA{}.tv_layout(), ReduceSGLayout{}, TensorO2D{}); + } + + using DefaultTiledCopyO = decltype(default_tiled_copy_O_helper()); + using TiledCopyO = conditional_t, DefaultTiledCopyO, TiledCopyO_>; + + // Stateless design -- no arguments or parameters. + struct Arguments {}; + struct Params {}; + + // Shared memory storage + // Note sum/max tiles are padded to 16 elements, due to limitations in CuTe + // block load infrastructure. 
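The alias right below encodes the usual round-up-to-multiple idiom; with the Xe subgroup size (typically 16), the padding works out as in this sketch with plain `constexpr` ints standing in for the CuTe integral constants:

```cpp
// Round-up arithmetic behind AlignedSGTileA_Q (illustrative values;
// sg_size stands in for intel::sg_size, which is 16 on Xe).
constexpr int round_up(int n, int m) { return ((n + m - 1) / m) * m; }
static_assert(round_up(8, 16) == 16, "an 8-row q tile pads to one sg chunk");
static_assert(round_up(24, 16) == 32, "a 24-row q tile pads to two sg chunks");
```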
+ using AlignedSGTileA_Q = C<((size<0>(SGTileShapeA{}) + intel::sg_size - 1) / intel::sg_size) * intel::sg_size>; + + struct SharedStorageNone {}; + struct SharedStorageReduceK { + cute::array a_data; + cute::array a_sum_data, a_max_data; + }; + + using SharedStorage = conditional_t<(ReduceK{} > _1{}), SharedStorageReduceK, SharedStorageNone>; + + private: + SharedStorage& shared; + + public: + static constexpr Params to_underlying_arguments(Arguments const& args, void* /* workspace */) { + return {}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + CUTLASS_HOST_DEVICE + DecodeFwdEpilogue(Params const&, SharedStorage& shared_) : shared(shared_) {} + + template + CUTLASS_DEVICE void operator()( + TensorO2D const& O, // Global O tensor: (q,v) + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id) { // Work-item ID + + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + auto [rA, rA_max_unused, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax, dividing out sums. */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) + rA_sum(i) = ElementA(1) / rA_sum(i); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) + rA(i) *= broadcast<0>(rA_sum, rA, i); + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // splitK version + template + CUTLASS_DEVICE void operator()( + TensorO2D const& O, // Global O tensor: (q,v) + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (q,v) + int thr_id, // Work-item ID + const TensorLSE2D& exp_sums, // Global exp sum tensor + const TensorLSE2D& max_logits, // Global max logits tensor + int idx_kv_split, + int head_group_q, + TensorSink& tSink, // Sink for current head + int num_kv_splits, + bool is_single_split) { + using namespace cute; + using ElementA = typename FragA::element_type; + + // Reduce k-blocks of A and A_sum across WG, if needed. + int sg_id = thr_id / intel::sg_size; + if constexpr (Sink) { + constexpr double kLog2e = 1.4426950408889634074; + if (idx_kv_split == 0 && sg_id == 0 && thr_id < head_group_q) { + tA_sum(0) += sycl::native::exp2(static_cast(tSink(thr_id) * kLog2e) - tA_max(0)); + } + } + + auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id); + + // Always store exp sum and max logits for current KV split. + // assume seq_len_qo == 1 + if (thr_id < head_group_q) { + if (is_single_split) { + // Sentinel values: make ReduceSplitK a pass-through copy. 
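Why (1, 0) is the right sentinel pair: assuming the follow-up ReduceSplitK kernel (not shown in this hunk) performs the standard log-sum-exp merge over splits,

$$M = \max_s m_s, \qquad w_s = e^{\,m_s - M}\,\sigma_s, \qquad O_{\mathrm{out}} = \frac{\sum_s w_s\,O_s}{\sum_s w_s},$$

a single active split that stores $\sigma_0 = 1$, $m_0 = 0$ gets weight $w_0 = 1$, while the skipped splits (left at their $\sigma = 0$, $m = -\infty$ initialization) contribute $w_s = 0$, so the already-normalized output passes through unchanged.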
+ exp_sums(thr_id, idx_kv_split) = ElementA(1); + max_logits(thr_id, idx_kv_split) = ElementA(0); + } else if (num_kv_splits > 1) { + exp_sums(thr_id, idx_kv_split) = rA_sum(0); + max_logits(thr_id, idx_kv_split) = rA_max(0); + } + } + + /* Some subgroups may not have any work to do; if so, quit early. */ + if (!active) return; + + /* Complete softmax: normalize output for single-split sequences + (so ReduceSplitK pass-through gives correct result). + For multi-split, store unnormalized to avoid divide-multiply + precision loss in the reduce roundtrip. */ + if (is_single_split || num_kv_splits <= 1) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum.size(); i++) { + rA_sum(i) = ElementA(1) / rA_sum(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA.size(); i++) { + rA(i) *= broadcast<0>(rA_sum, rA, i); + } + } + + /* Tile output */ + Tensor cO = make_identity_tensor(O.shape()); // (q,v) + Tensor gO = local_tile(cO, TileShapeO{}, blk_qv); // (q,v) + + /* Prepare slices */ + TiledCopyO copy_o{O}; + auto thr_copy_o = copy_o.get_slice(thr_id); + + auto tOrO = thr_copy_o.partition_sg_fragment_S(gO); + auto tOgO = thr_copy_o.partition_D(gO); + + /* Reorder tile and write out */ + reorder(rA, tOrO); + copy(copy_o, tOrO, tOgO); + } + + // Reduce k-blocks of A and A_sum across WG, if needed. + // Note that each k block has its own scale factor based on A_max, + // so A/A_sum contributions need to be rescaled to match. + template + CUTLASS_DEVICE decltype(auto) reduce_A( + FragA& tArA, // O accumulator: (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + int thr_id) { // Work-item ID + + using namespace sycl::ext::oneapi::this_work_item; + + if constexpr (ReduceK{} == _1{}) { + return std::make_tuple(tArA, tA_max, tA_sum, true); + } else { + /* Identify A tile ID and k block for this subgroup. */ + auto thr_vak = group<1, 3>(TiledMMAPV{}.get_thr_layout_vmnk()).get_flat_coord(assert_uniform(thr_id)); + auto a_tile = get<1>(thr_vak); + auto k_blk = get<2>(thr_vak); + + /* Set up SLM tensors and partition A tiles among participating subgroups + */ + auto shape_A = append(append(SGTileShapeA{}, ReduceK{}), SGPerWG{} / ReduceK{}); + auto shape_A_row = make_shape(get<0>(SGTileShapeO{}), shape(ReduceSGLayout{}), ReduceK{}, SGPerWG{} / ReduceK{}); + + /* Physical layouts, with sub_tile modes broken out */ + auto sA_layout = group<2, 4>(flat_divide(make_ordered_layout(shape_A, Step<_1, _0, _2, _3>{}), SGTileShapeO{})); + auto sA_row_stride = + make_stride(_1{}, make_stride(get<0>(shape_A_row), _0{}), AlignedSGTileA_Q{}, AlignedSGTileA_Q{} * ReduceK{}); + auto sA_row_layout = make_layout(shape_A_row, sA_row_stride); + + /* Coordinate layouts, with sub_tile modes broken out */ + auto basis2 = make_basis_like(SGTileShapeO{}); + auto sA_coords = make_layout( + append(SGTileShapeO{}, shape(ReduceSGLayout{})), append(basis2, product_each(zip(SGTileShapeO{}, basis2)))); + + auto sA = make_tensor(make_smem_ptr(&shared.a_data), + sA_layout); // (q,v,rblk_dst,rblk_src,a_tile) + auto sA_max = make_tensor( + make_smem_ptr(&shared.a_max_data), + sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + auto sA_sum = make_tensor( + make_smem_ptr(&shared.a_sum_data), + sA_row_layout); // (q,rblk_dst,rblk_src,a_tile) + + /* Write my contributions to SLM. 
*/ + copy_block_r2s(tA_max, sA_max(_, _, k_blk, a_tile)); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + copy_block_r2s(tA_sum, sA_sum(_, _, k_blk, a_tile)); + copy_block_r2s(tArA, sA(_, _, _, k_blk, a_tile), sA_coords); + + bool active = (k_blk < size(ReduceSGLayout{})) || (ReduceK{} == size(ReduceSGLayout{})); // help compiler out + + /* Wait for maxima to be available, signal other data available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + barrier_arrive(ScopeWorkgroup, SemanticsRelease | SemanticsWGMemory); + + ReduceFragA rA; + ReduceFragARow rA_sum, rA_max, rA_kmax[ReduceK{}]; + + if (active) { + /* Read A_max back from SLM and reduce. */ + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + copy_block_s2r(sA_max(_, k_blk, kr, a_tile), rA_kmax[kr]); + } + + rA_max = rA_kmax[0]; + for (int kr = 1; kr < ReduceK{}; kr++) + cute::transform(rA_max, rA_kmax[kr], rA_max, cute::max_fn{}); + + /* Calculate scale factors for aligning per-block maxima. */ + for (int kr = 0; kr < ReduceK{}; kr++) { + cute::transform( + rA_max, rA_kmax[kr], rA_kmax[kr], [](auto gmax, auto kmax) { return sycl::native::exp2(kmax - gmax); }); + } + } + + /* Wait for A/A_sum data to be available */ + barrier_wait(ScopeWorkgroup, SemanticsAcquire | SemanticsWGMemory); + + if (active) { + /* Read A/A_sum back from SLM, align scaling to new maxima, and reduce. + */ + clear(rA_sum); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragARow rA_sum_read; + copy_block_s2r(sA_sum(_, k_blk, kr, a_tile), rA_sum_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_sum_read.size(); i++) { + rA_sum(i) += rA_sum_read(i) * rA_kmax[kr](i); + } + } + + clear(rA); + + CUTLASS_PRAGMA_UNROLL + for (int kr = 0; kr < ReduceK{}; kr++) { + ReduceFragA rA_read; + copy_block_s2r(sA(_, _, k_blk, kr, a_tile), sA_coords(_, _, 0), rA_read); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < rA_read.size(); i++) { + rA(i) += rA_read(i) * broadcast<0>(rA_kmax[kr], rA, i); + } + } + } + return std::make_tuple(rA, rA_max, rA_sum, active); + } + } +}; } // namespace cutlass::fmha::collective diff --git a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp index c4f15491..6062e729 100644 --- a/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp +++ b/src/sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp @@ -550,6 +550,456 @@ struct FMHAFwdMainloop< } }; +template < + class DispatchPolicy_, + bool PagedKV_, + bool CausalMask_, + class TiledMMAQK_, // Tiling for Q*K GEMM + class TiledMMAPV_, // Tiling for P*V GEMM + int VTiles_, // # of tiles in V dimension + class TensorQ_, // Global Q/K/V tensors + class TensorK_, + class TensorV_, + class TiledCopyQ_ = void, // Optional TiledCopy for loading Q + class TiledCopyK_ = void, // Optional TiledCopy for loading K + class TiledCopyV_ = void, // Optional TiledCopy for loading V + bool LocalMask_ = false> +struct DecodeFwdMainloop { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +template < + int Stages, + bool PagedKV_, + bool CausalMask_, + class TiledMMAQK_, + class TiledMMAPV_, + int VTiles_, + class TensorQ_, + class TensorK_, + class TensorV_, + class TiledCopyQ_, + class TiledCopyK_, + class TiledCopyV_, + bool LocalMask_> +struct DecodeFwdMainloop< + XeDefault, + PagedKV_, + CausalMask_, + TiledMMAQK_, + TiledMMAPV_, + 
VTiles_, + TensorQ_, + TensorK_, + TensorV_, + TiledCopyQ_, + TiledCopyK_, + TiledCopyV_, + LocalMask_> { + // + // Type Aliases + // + using TiledMMAQK = TiledMMAQK_; + using TiledMMAPV = TiledMMAPV_; + using TileShapeQK = decltype(TiledMMAQK{}.tile_mnk()); + using TileShapePV = decltype(TiledMMAPV{}.tile_mnk()); + static constexpr int VTiles = VTiles_; + using SubgroupLayoutQK = decltype(TiledMMAQK{}.get_atom_layout_mnk()); + using SGPerWG = decltype(product(take<1, 4>(shape(typename TiledMMAQK::ThrLayoutVMNK{})))); + + using TensorQ = TensorQ_; + using TensorK = TensorK_; + using TensorV = TensorV_; + + using ElementQ = typename TensorQ::engine_type::value_type; + using ElementK = typename TensorK::engine_type::value_type; + + using TensorQ2D = decltype(TensorQ_{}(append>(make_coord(_, _), 0))); + using TensorK2D = decltype(TensorK_{}(append>(make_coord(_, _), 0))); + using TensorV2D = decltype(TensorV_{}(append>(make_coord(_, _), 0))); + + using TiledCopyQ = + conditional_t, decltype(make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{})), TiledCopyQ_>; + using TiledCopyK = + conditional_t, decltype(make_block_2d_copy_B(TiledMMAQK{}, TensorK2D{})), TiledCopyK_>; + using TiledCopyV = + conditional_t, decltype(make_block_2d_copy_B(TiledMMAPV{}, TensorV2D{})), TiledCopyV_>; + + // TODO: static_asserts on TiledMMAPV here... + + // + // Accumulator types + // + // FragS: accumulator for Q*K MMA + // FragO: accumulator for P*V MMAs. + // Note: v mode may be split into multiple pieces + // to reduce register pressure. + // Frag*Row types are reductions of the corresponding Frag* types + // over rows. + // + template + using FragC = decltype(TiledMMA{}.get_slice(0).partition_sg_fragment_C( + make_identity_tensor(select<0, 1>(TiledMMA{}.tile_mnk())))); + + using FragS = FragC; + using FragSRow = decltype(reduce<1>(FragS{}, sycl::plus{})); + using FragSCol = decltype(reduce<0>(FragS{}, sycl::plus{})); + using ElementS = typename TiledMMAQK::ValTypeD; + + using SingleFragA = FragC; // (atom val,q',v') + using FragA = expand_sg_fragment_t; // (atom val,q',v',VV) + using FragARow = decltype(reduce<1>(FragA{}, sycl::plus{})); + // static_assert(is_same_v, "dtype + // mismatched"); + using ElementA = typename TiledMMAPV::ValTypeD; + + static constexpr bool PagedKV = PagedKV_; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool Fp8KV = is_any_of_v; + static constexpr bool LocalMask = LocalMask_; + + // User-facing arguments + struct Arguments { + ElementS const scale; + void* const scale_k; + void* const scale_v; + // Paged KV Cache + int const* ptr_page_table; + int page_size; + int max_pages_per_seq; + int total_seqlen_kv; + // Local Mask + int window_size_left; + int window_size_right; + }; + + // Kernel-facing parameters + using Params = Arguments; + + // SLM data + struct SharedStorage {}; + + Params params; + + // + // Methods + // + + DecodeFwdMainloop(Params const& params_, SharedStorage&) : params(params_) {} + + static constexpr Params to_underlying_arguments(Arguments const& args, void* /* workspace */) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) + ElementS val = args.scale * static_cast(kLog2e); + return Params{ + val, + args.scale_k, + args.scale_v, + args.ptr_page_table, + args.page_size, + args.max_pages_per_seq, + args.total_seqlen_kv, + args.window_size_left, + args.window_size_right}; + } + + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const&) { + return true; + } + + template + CUTLASS_DEVICE void operator()( + TensorQ2D const& Q_2D, 
// (q,d) + TensorK2D const& K_2D, // (k,d) + TensorV2D const& V_2D, // (d,k) + FragA& tArA, // Output accumulator (q,v) + FragARow& tA_max, // Softmax row-wise max accumulator + FragARow& tA_sum, // Softmax row-wise sum accumulator + QVCoord blk_qv, // WG tile indices: (Q,V) + int const& idx_b, // WG tile indices: (B) + int blk_k0, // K block range: [K0,K1) + int blk_k1, + int total_blk, // Total # of K blocks + int thr_id, + int seq_len, + int full_tile_offset, + int discard_seq_coord) { + using namespace sycl::ext::oneapi::this_work_item; + + // Short dimension names: + // q = sequence len dimension for Q + // k = sequence len dimension for K + // d = head size dimension for K/Q + // v = head size dimension for V + // VV = MMA tile indices for V + // Capital letters (Q, K, ...) refer to WG block indices. + // Primed letters (q', k', ...) refer to atom block indices. + + auto tile_shape_v = make_shape(get<1>(TileShapePV{}) * C{}, get<2>(TileShapePV{})); + + /* Create proxy coordinate tensors for Q/K/P/V */ + Tensor cQ = make_identity_tensor(Q_2D.shape()); // (q,d) + Tensor cK = make_identity_tensor(K_2D.shape()); // (k,d) + Tensor cV = make_identity_tensor(V_2D.shape()); // (v,k) + Tensor cP = make_identity_tensor(take<0, 2>(TileShapeQK{})); // (q,k) + + /* Partition global tensors into workgroup tiles */ + Tensor gQ = local_tile(cQ, TileShapeQK{}, append(blk_qv, _), Step<_1, X, _1>{}); // (q,d,D) + Tensor gK = local_tile(cK, TileShapeQK{}, make_coord(_, _, _), Step{}); // (k,d,K,D) + Tensor gV = local_tile(cV, tile_shape_v, make_coord(get<1>(blk_qv), _)); // (v,k,K) + Tensor gV_split = local_tile(gV, TileShapePV{}, make_coord(_, _, 0), Step{}); // (v,k,VV,K) + + /* Create global -> register copies */ + TiledCopyQ copy_q{Q_2D}; + TiledCopyK copy_k{K_2D}; + TiledCopyV copy_v{V_2D}; + + /* Create MMAs */ + TiledMMAQK mma_qk{}; + TiledMMAPV mma_pv{}; + + auto copyQ = make_block_2d_copy_A(TiledMMAQK{}, TensorQ2D{}); + + /* Slice TiledCopy/TiledMMA operations down to to work-item level */ + auto thr_copy_q = copy_q.get_slice(thr_id); + auto thr_copy_k = copy_k.get_slice(thr_id); + auto thr_copy_v = copy_v.get_slice(thr_id); + auto thr_mma_qk = mma_qk.get_slice(thr_id); + auto thr_mma_pv = mma_pv.get_slice(thr_id); + + /* Partition coordinate tensors for copy */ + auto tQgQ = thr_copy_q.partition_S(gQ); // (atom_val,q',d',D) + auto tKgK = thr_copy_k.partition_S(gK); // (atom_val,k',d',K,D) + auto tVgV = thr_copy_v.partition_S(gV_split); // (atom_val,v',k',VV,K) + + /* Create register fragments for MMA and copies */ + auto tQrQ = thr_copy_q.partition_sg_fragment_D(gQ(_, _, 0)); + auto tSrQ = thr_mma_qk.partition_sg_fragment_A(gQ(_, _, 0)); + + auto tKrK = thr_copy_k.partition_sg_fragment_D(gK(_, _, 0, 0)); + auto tSrK = thr_mma_qk.partition_sg_fragment_B(gK(_, _, 0, 0)); + + auto tSrS = thr_mma_qk.partition_sg_fragment_C(cP); + auto tArP = thr_mma_pv.partition_sg_fragment_A(cP); + + auto tVrV = thr_copy_v.partition_sg_fragment_D(gV_split(_, _, 0, 0)); + auto tArV = thr_mma_pv.partition_sg_fragment_B(gV_split(_, _, 0, 0)); + + /* Create TiledCopy objects for prefetches */ + auto prefetch_q = make_block_2d_prefetch(copy_q); + auto prefetch_k = make_block_2d_prefetch(copy_k); + auto prefetch_v = make_block_2d_prefetch(tile_shape_v, V_2D); + + /* Partition global tensors for prefetch */ + auto pQgQ = prefetch_q.get_slice(thr_id).partition_S(gQ); + auto pKgK = prefetch_k.get_slice(thr_id).partition_S(gK); + auto pVgV = prefetch_v.get_slice(thr_id).partition_S(gV); + + // ------ + // Kernel + // 
------ + + // PagedKV + int tiles_per_page = params.page_size / get<1>(TileShapeQK{}); + int tile_idx = blk_k0; + int b_offset = idx_b * params.max_pages_per_seq; + if constexpr (PagedKV) { + int page_local_idx = tile_idx * get<1>(TileShapeQK{}) / params.page_size; + tile_idx = params.ptr_page_table[b_offset + page_local_idx] * tiles_per_page + tile_idx % tiles_per_page; + } + + /* Initialization steps for first block: Q/K prefetch, O init */ + /* TODO: limit D prefetch for large head size, and reorder K prefetches */ + for (int D = 0; D < size<3>(pQgQ); D++) { + prefetch(prefetch_q, pQgQ(_, _, _, D)); + } + + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, tile_idx, D)); + } + + clear(tArA); + fill(tA_max, cutlass::platform::numeric_limits::lowest()); + clear(tA_sum); + + /* Check if */ + bool check_remainder_k = (seq_len % get<1>(TileShapeQK{}) != 0); + + // FP8 KV Scale: Currently we only support per-tensor scale for KV + float scale_k = 1.f, scale_v = 1.f; + if constexpr (Fp8KV) { + scale_k = *static_cast(params.scale_k); + scale_v = *static_cast(params.scale_v); + } + + /* Main loop, blocked in k. */ + int next_tile_idx; + for (int K = blk_k0; K < blk_k1; K++) { + /* Split barrier to keep threads together */ + // barrier_arrive(ScopeWorkgroup); + + auto tKgK_cache = PagedKV ? tKgK(_, _, _, tile_idx, _) : tKgK(_, _, _, K, _); + auto tVgV_cache = PagedKV ? tVgV(_, _, _, _, tile_idx) : tVgV(_, _, _, _, K); + + /* GEMM 1: S = K * Q */ + clear(tSrS); /* TODO: fuse w/ initial gemm call */ + for (int D = 0; D < size<4>(tKgK); D++) { + copy(copy_q, tQgQ(_, _, _, D), tQrQ); + copy(copy_k, tKgK_cache(_, _, _, D), tKrK); + + reorder(tQrQ, tSrQ); + reorder(tKrK, tSrK); + if constexpr (Fp8KV) { + for (int i = 0; i < tSrK.size(); ++i) { + tSrK(i) = static_cast(scale_k * static_cast(tSrK(i))); + } + } + + cute::gemm(mma_qk, tSrQ, tSrK, tSrS); + } + /* V prefetch for GEMM 2 */ + prefetch(prefetch_v, pVgV(_, _, _, tile_idx)); + + /* Causal masking */ + // No Causal masking in decoding + // if constexpr (CausalMask) { + // if (K == blk_k1 - 1) { + // // Need to get global col and row indices to mask the elements + // Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + // Tensor gP = local_tile(cPgP, take<0,2>(TileShapeQK{}), + // make_coord(get<0>(blk_qv), K)); auto cS_thread = + // thr_mma_qk.partition_C(gP); CUTLASS_PRAGMA_UNROLL for (int i = 0; i + // < tSrS.size(); ++i) { + // int row_idx = get<0>(cS_thread(i)); + // int col_idx = get<1>(cS_thread(i)); + // if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + // tSrS(i) = ElementS(-INFINITY); + // } + // } + // } + // } + + /* Local/sliding window masking */ + if constexpr (LocalMask) { + // For decode, all packed GQA heads share the same KV position + // (seq_len_kv - 1). Use a fixed decode row for all elements. 
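To make the bounds concrete, here is a worked example under assumed values (the variable names mirror the kernel's, but the numbers are illustrative and `full_tile_offset` is taken as 0):

```cpp
// Decode sliding-window example: 4096 cached keys, window_size_left = 1024,
// window_size_right = 0 (the decode query sits at the last position).
int seq_len = 4096, window_left = 1024, window_right = 0;
int decode_row = seq_len - 1;                  // 4095, shared by all packed GQA heads
int first_visible = decode_row - window_left;  // 3071
int last_visible = decode_row + window_right;  // 4095
// Keys with col_idx outside [3071, 4095] are set to -INF before softmax,
// matching the left_mask / right_mask predicates below; whole KV tiles left
// of first_visible are already skipped by the kernel-side k_block0 computation.
```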
+ int decode_row = seq_len - 1 - full_tile_offset; + Tensor cPgP = make_identity_tensor(make_shape(seq_len, seq_len)); + Tensor gP = local_tile(cPgP, take<0, 2>(TileShapeQK{}), make_coord(get<0>(blk_qv), K)); + auto cS_thread = thr_mma_qk.partition_C(gP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); ++i) { + int col_idx = get<1>(cS_thread(i)) - full_tile_offset; + bool left_mask = col_idx < decode_row - params.window_size_left; + bool right_mask = col_idx > decode_row + params.window_size_right; + if (left_mask || right_mask) { + tSrS(i) = ElementS(-INFINITY); + } + } + } + + /* k masking for remainder tiles */ + if (check_remainder_k && K == blk_k1 - 1) { + FragSCol k_rem_mask; + int k = get<0>(tKgK(0, 0, 0, K, 0)) + get_sub_group().get_local_id()[0]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < k_rem_mask.size(); i++, k += intel::sg_size) { + k_rem_mask(i) = (k < seq_len) ? ElementS(sycl::nan(0u)) : ElementS(-INFINITY); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tSrS.size(); i++) { + tSrS(i) = sycl::fmin(tSrS(i), broadcast<1>(k_rem_mask, tSrS, i)); + } + } + + /* Apply softmax and scaling */ + softmax(K == 0, tSrS, tA_max, tA_sum, tArA); + reorder(tSrS, tArP); + + /* GEMM 2: A += P * V, split in v dimension */ + CUTLASS_PRAGMA_UNROLL + for (int VV = 0; VV < VTiles; VV++) { + copy(copy_v, tVgV_cache(_, _, _, VV), tVrV); + reorder(tVrV, tArV); + if constexpr (Fp8KV) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tArV.size(); ++i) { + tArV(i) = static_cast(scale_v * static_cast(tArV(i))); + } + } + cute::gemm(mma_pv, tArP, tArV, tArA(_, _, _, VV)); + } + + barrier(); + + // next tile_idx + next_tile_idx = K + 1; + if constexpr (PagedKV) { + int next_page_local_idx = next_tile_idx * get<1>(TileShapeQK{}) / params.page_size; + if (next_page_local_idx < params.max_pages_per_seq) { + next_tile_idx = + params.ptr_page_table[b_offset + next_page_local_idx] * tiles_per_page + next_tile_idx % tiles_per_page; + } else { + // set to last page + next_tile_idx = params.max_pages_per_seq * tiles_per_page - 1; + } + } + tile_idx = next_tile_idx; + + /* K prefetch */ + for (int D = 0; D < size<4>(pKgK); D++) { + prefetch(prefetch_k, pKgK(_, _, _, tile_idx, D)); + } + + // barrier_wait(ScopeWorkgroup); + } + } + + // Single step of blocked softmax. + CUTLASS_DEVICE + void softmax( + bool first_block, // First softmax block? 
+ FragS& tS, // Softmax src/dst block + FragSRow& tS_max, // Softmax row-wise max accumulator + FragSRow& tS_sum, // Softmax row-wise sum accumulator + FragA& tA) { // O accumulator (for rescaling) + + /* Compute row-wise maxima for this block */ + auto tS_bmax = reduce<1>(tS, sycl::maximum{}); + + /* Update (scaled) maxima */ + auto tS_prev_max = tS_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + tS_max(i) = sycl::max(tS_max(i), params.scale * tS_bmax(i)); + } + + /* Scale S and subtract maxima, then exponentiate */ + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS.size(); i++) + tS(i) = sycl::native::exp2(params.scale * tS(i) - broadcast<0>(tS_max, tS, i)); + + /* Rescale existing S sums and O accumulator */ + if (!first_block) { + FragSRow rescale; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tS_max.size(); i++) { + rescale(i) = sycl::native::exp2(tS_prev_max(i) - tS_max(i)); + tS_sum(i) *= rescale(i); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tA.size(); i++) + tA(i) *= broadcast<0>(rescale, tA, i); + } + + /* Update sums */ + auto tS_bsum = reduce<1>(tS, sycl::plus{}); + for (int i = 0; i < tS_sum.size(); i++) + tS_sum(i) += tS_bsum(i); + } +}; + template CUTLASS_HOST_DEVICE constexpr auto get_sg_layout_pv(SGLayoutQK const&) { return make_layout(get<0>(SGLayoutQK{}), Layout<_1, _0>{}, get<1>(SGLayoutQK{})); diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp index 5c964a57..d2c1bf39 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp @@ -110,6 +110,8 @@ class XeFMHAFwdKernel { using ElementO = typename CollectiveEpilogue::TensorO::element_type; using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + using ElementLSE = void; + // Kernel level shared memory storage using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; @@ -143,6 +145,7 @@ class XeFMHAFwdKernel { MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; }; // Kernel entry point API @@ -231,7 +234,7 @@ class XeFMHAFwdKernel { CUTLASS_PRAGMA_NO_UNROLL for (; tile_scheduler.is_valid(); ++tile_scheduler) { - auto [blk_q, blk_v, head_q, idx_b] = tile_scheduler.get_block_coord(); // (Q,V,h,b) + auto [blk_q, blk_v, head_q, idx_b, unused] = tile_scheduler.get_block_coord(); // (Q,V,h,b) auto blk_qv = make_coord(blk_q, blk_v); int head = head_q / head_group_q; @@ -770,4 +773,357 @@ class XeFMHAFwdDynamicSplitKernel { } }; +template +class XeFMHAFwdSplitKVKernel { + public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using CollectiveMainloop = CollectiveMainloop_; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using TiledMMAQK = typename CollectiveMainloop::TiledMMAQK; + using TiledMMAPV = typename CollectiveMainloop::TiledMMAPV; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using SubgroupLayoutQK = typename CollectiveMainloop::SubgroupLayoutQK; + using ElementQ = typename 
CollectiveMainloop::TensorQ::element_type; + using ElementK = typename CollectiveMainloop::TensorK::element_type; + using ElementV = typename CollectiveMainloop::TensorV::element_type; + + using StrideQ = decltype(stride(typename CollectiveMainloop::TensorQ{})); + using StrideK = decltype(stride(typename CollectiveMainloop::TensorK{})); + using StrideV = decltype(stride(typename CollectiveMainloop::TensorV{})); + + using SGPerWG = typename CollectiveMainloop::SGPerWG; + + using FragA = typename CollectiveMainloop::FragA; + using FragARow = typename CollectiveMainloop::FragARow; + + // Tile scheduler derived types + using TileScheduler = TileScheduler_; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileShapeO = typename CollectiveEpilogue::TileShapeO; + using ElementO = typename CollectiveEpilogue::TensorO::element_type; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using StrideO = decltype(stride(typename CollectiveEpilogue::TensorO{})); + + // Kernel level shared memory storage + using MainloopSharedStorage = typename CollectiveMainloop::SharedStorage; + using EpilogueSharedStorage = typename CollectiveEpilogue::SharedStorage; + union SharedStorage { + MainloopSharedStorage mainloop; + EpilogueSharedStorage epilogue; + }; + + static constexpr int SharedStorageSize = is_empty_v ? size_t(0) : sizeof(SharedStorage); + + static constexpr int max_num_kv_splits = SGPerWG::value * intel::sg_size; + static constexpr bool Sink = CollectiveEpilogue::Sink; + using ElementSink = typename CollectiveEpilogue::ElementSink; + + // Device side arguments + struct KernelArguments { + ProblemShape shape; + const ElementQ* Q; + StrideQ dQ; + const ElementK* K; + StrideK dK; + const ElementV* V; + StrideV dV; + ElementO* Oaccum; + StrideO dOaccum; + ElementLSE* exp_sums; + StrideO dExp_sums; + ElementLSE* max_logits; + StrideO dMax_logits; + + const ElementSink* sm_sink; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + // Kernel entry point API + struct Params { + KernelParams kernel; + MainloopParams mainloop; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return { + args.kernel, + CollectiveMainloop::to_underlying_arguments(args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)}; + } + + static bool can_implement(Arguments const& args) { + if (!is_var_len && args.kernel.shape.seq_len_qo != 1) { + // decode only + return false; + } + + if (args.num_kv_splits > max_num_kv_splits) { + return false; + } + + return CollectiveMainloop::can_implement(args.mainloop) && CollectiveEpilogue::can_implement(args.epilogue); + } + + static int get_workspace_size(Arguments const& args) { + return 0; + } + + static cutlass::Status initialize_workspace( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return 
Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { + return dim3(SGPerWG::value * intel::sg_size, 1, 1); + } + + CUTLASS_DEVICE + Shape get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) { + if constexpr (is_var_len) { + auto q_len = + cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch); + return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]}; + } else { + return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv}; + } + } + + CUTLASS_DEVICE + void operator()(Params const& params, char* smem_buf) { + using namespace sycl::ext::oneapi::this_work_item; + + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + auto& p = params.kernel; + ProblemShape const& s = p.shape; + int head_group_q = s.num_heads_q / s.num_heads_kv; + + int thr_id = int(ThreadIdxX()); + int sub_group_id = thr_id / intel::sg_size; + int q_sg_tile = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{}))); + + auto cS = make_identity_tensor(take<0, 2>(TiledMMAQK{}.tile_mnk())); + auto tScS = TiledMMAQK{}.get_slice(thr_id).partition_C(cS); + auto q_offset_wi = get<0>(tScS(0)); + auto q_offset_sg = group_broadcast(sycl::ext::oneapi::this_work_item::get_sub_group(), q_offset_wi, 0); + + TileScheduler tile_scheduler{params.scheduler}; + auto num_kv_splits = params.scheduler.num_kv_splits_; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto [blk_q, blk_v, head, idx_b, idx_kv_split] = tile_scheduler.get_block_coord(); // (Q,V,h,b,id_split) + auto blk_qv = make_coord(blk_q, blk_v); + int head_q_start = head * head_group_q; + + auto sequence_length_shape = get_sequence_length_shape(s, idx_b); + auto [seq_len_qo, seq_len_kv] = sequence_length_shape; + if (blk_q * get<0>(TileShapeQK{}) >= seq_len_qo) continue; + + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + int seq_coord = cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg)); + + if (CollectiveMainloop::CausalMask && seq_coord < discard_seq_coord) continue; + // For decode window_size_right doesn't have effect + const int seq_len = seq_len_kv; + // For decode, all packed GQA heads are at position seq_len_kv - 1. + // Use seq_len - 1 (= seq_len_kv - 1) as the decode position for + // k_block0 to match ReduceSplitK's computation. + const int k_block0 = CollectiveMainloop::LocalMask + ? cute::max(seq_len - 1 - params.mainloop.window_size_left, 0) / get<1>(TileShapeQK{}) + : 0; + const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{})); + const int windowed_k_blocks = k_blocks - k_block0; + + int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0; + int offset_exp_sums = 0, offset_max_logits = 0; + if constexpr (is_var_len) { + auto qo_cumulative = s.seq_len_qo.cumulative_length; + + offset_q = s.num_heads_q * s.head_size_qk * qo_cumulative[idx_b]; + offset_o = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b]; + offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b]; + + // for gqa packing, seq_len_qo must be 1 + seq_len_qo = 1; + } + + // neglect seq_len_qo since it's always 1 for decode + auto batch_dim = is_var_len ? 
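+      // Worked example for the window clamp above (illustrative numbers):
+      // with seq_len = 1000, window_size_left = 256 and a KV tile of
+      // get<1>(TileShapeQK{}) = 64:
+      //   k_block0          = max(1000 - 1 - 256, 0) / 64 = 11
+      //   k_blocks          = ceil_div(1000, 64)          = 16
+      //   windowed_k_blocks = 16 - 11                     = 5
+      // so only KV tiles 11..15 can intersect the local-attention window.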
1 : s.batch;
+      auto shape_Q = make_shape(head_group_q, s.head_size_qk, s.num_heads_kv, batch_dim);
+      // K/V shapes cover the total (paged) KV sequence length
+      auto total_seqlen_kv = params.mainloop.total_seqlen_kv;
+      auto shape_K = make_shape(total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim);
+      auto shape_V = make_shape(s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim);
+
+      auto shape_O = make_shape(head_group_q, s.head_size_vo, s.num_heads_kv, num_kv_splits, batch_dim);
+      auto shape_exp_sums = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim);
+      auto shape_max_logits = make_shape(head_group_q, num_kv_splits, s.num_heads_kv, batch_dim);
+      auto shape_sink = make_shape(s.num_heads_kv, head_group_q);
+
+      int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits);
+
+      // Per-sequence split decision: short sequences are treated as
+      // single-split even when num_kv_splits > 1, avoiding precision
+      // loss from the split-reduce roundtrip.
+      constexpr int kMinBlocksForSplit = 128;
+      bool is_single_split = (num_kv_splits > 1) && (windowed_k_blocks < kMinBlocksForSplit);
+
+      int kv_split_offset;
+      int num_effective_kv_blocks;
+      if (is_single_split) {
+        // Split 0 processes all blocks; splits 1+ skip entirely.
+        if (idx_kv_split > 0) {
+          continue;
+        }
+        kv_split_offset = k_block0;
+        num_effective_kv_blocks = windowed_k_blocks;
+      } else {
+        kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split;
+        num_effective_kv_blocks =
+            cute::min(windowed_k_blocks - idx_kv_split * num_blocks_per_split, num_blocks_per_split);
+      }
+
+      if (num_effective_kv_blocks <= 0) {
+        // nothing left for this split to compute
+        continue;
+      }
+
+      auto dcQ = const_cast<ElementQ*>(p.Q + offset_q);
+      auto dcK = const_cast<ElementK*>(p.K);
+      auto dcV = const_cast<ElementV*>(p.V);
+      auto ptrO = p.Oaccum + offset_o;
+      auto ptrExp_sums = p.exp_sums + offset_exp_sums;
+      auto ptrMax_logits = p.max_logits + offset_max_logits;
+
+      auto layout_q = make_ordered_layout(shape_Q, Step<_1, _0, _2, _3>{});
+      auto layout_k = make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{});
+      auto layout_v = make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{});
+
+      // auto layout_k = make_layout(shape_K, make_stride(get<0>(p.dK), _1{}, get<2>(p.dK), get<3>(p.dK)));
+      // auto layout_v = make_layout(shape_V, make_stride(_1{}, get<1>(p.dV), get<2>(p.dV), get<3>(p.dV)));
+
+      auto layout_o = make_ordered_layout(shape_O, Step<_1, _0, _2, _3, _4>{});
+      auto layout_exp_sums = make_ordered_layout(shape_exp_sums, Step<_1, _0, _2, _3>{});
+      auto layout_max_logits = make_ordered_layout(shape_max_logits, Step<_1, _0, _2, _3>{});
+      auto layout_sink = make_ordered_layout(shape_sink, Step<_1, _0>{});
+
+      Tensor Q = make_tensor(make_gmem_ptr(dcQ), layout_q);
+      Tensor K = make_tensor(make_gmem_ptr(dcK), layout_k);
+      Tensor V = make_tensor(make_gmem_ptr(dcV), layout_v);
+      Tensor O = make_tensor(make_gmem_ptr(ptrO), layout_o);
+      Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), layout_exp_sums);
+      Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), layout_max_logits);
+      Tensor sinks = make_tensor(make_gmem_ptr(const_cast<ElementSink*>(p.sm_sink)), layout_sink);
+
+      // O accumulator fragments
+      FragA tArA;
+      FragARow tA_max, tA_sum;
+
+      // Main loop
+      int l_coord = is_var_len ?
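+      // Worked example for the split partitioning above (illustrative
+      // numbers): windowed_k_blocks = 10 and num_kv_splits = 6 give
+      // num_blocks_per_split = ceil_div(10, 6) = 2; splits 0..4 process two
+      // KV tiles each, while split 5 gets min(10 - 5 * 2, 2) = 0 effective
+      // blocks and exits through the `continue` above. When
+      // windowed_k_blocks < kMinBlocksForSplit, split 0 instead takes the
+      // whole [k_block0, k_blocks) range and the other splits exit early.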
0 : idx_b; + + int start_blk = kv_split_offset; + int end_blk = kv_split_offset + num_effective_kv_blocks; + + CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop); + + mainloop( + Q(_, _, head, l_coord), + K(_, _, head, l_coord), + V(_, _, head, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + idx_b, + start_blk, + end_blk, + k_blocks, + thr_id, + seq_len, + full_tile_offset, + discard_seq_coord); + + if constexpr (!is_empty_v && !is_empty_v) { + sycl::group_barrier(get_work_group<3>()); + } + + // Epilogue + CollectiveEpilogue epilogue{params.epilogue, shared_storage.epilogue}; + if constexpr (Sink) { + auto sinks_per_kv = sinks(head, _); + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks_per_kv, + num_kv_splits, + is_single_split); + } else { + epilogue( + O(_, _, head, idx_kv_split, l_coord), + tArA, + tA_max, + tA_sum, + blk_qv, + thr_id, + exp_sums(_, _, head, l_coord), + max_logits(_, _, head, l_coord), + idx_kv_split, + head_group_q, + sinks, + num_kv_splits, + is_single_split); + } + } + } +}; + } // namespace cutlass::fmha::kernel diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp new file mode 100644 index 00000000..91a8b44d --- /dev/null +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp @@ -0,0 +1,308 @@ +/*************************************************************************************************** + * Copyright (C) 2025-2026 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Kernel performing a reduction over densely packed tensors in global memory +*/ + +#pragma once + +#include "cute/util/type_traits.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "sycl/kernels/flash_attention_v2/collective/fmha_fusion.hpp" +#include "sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_epilogue.hpp" +#include "sycl/kernels/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp" +#include "sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reduction { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReduceSplitK { + public: + using ProblemShape = ProblemShape_; + using VariableLength = cutlass::fmha::collective::VariableLength; + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v; + using TileScheduler = TileScheduler_; + static_assert( + is_same_v, + "ReduceSplitK kernel requires XeReduceSplitKTileScheduler"); + using TileSchedulerParams = typename TileScheduler::Params; + + using ElementO = typename FMHAKernel_::ElementO; + using StrideO = typename FMHAKernel_::StrideO; + using TileShapeO = typename FMHAKernel_::TileShapeO; + using TileShapeQK = typename FMHAKernel_::TileShapeQK; + + using ElementLSE = typename FMHAKernel_::ElementLSE; + + using SGPerWG = typename FMHAKernel_::SGPerWG; + + // num values (head_dim) processed by each thread + constexpr static int num_vals_per_thread = int(get<1>(TileShapeO{}) / (SGPerWG::value * intel::sg_size)); + + // + // Types + // + + struct KernelArguments { + ProblemShape shape; + // outputs: + ElementO* O; + StrideO dO; + // below are inputs + // TODO: whether same dtype as output or accum? + const ElementO* Oaccum; + StrideO dOaccum; + const ElementLSE* exp_sums; + StrideO dExp_sums; + const ElementLSE* max_logits; + StrideO dMax_logits; + int window_size_left = -1; + }; + using KernelParams = KernelArguments; + + struct Arguments { + KernelArguments kernel{}; + KernelHardwareInfo hw_info{}; + int num_kv_splits = -1; // no split by default + }; + + /// Params structure + struct Params { + KernelParams kernel; + TileSchedulerParams scheduler; + }; + + struct SharedStorage { + cutlass::Array max_logits_slm_array; + cutlass::Array exp_sums_slm_array; + }; + + static constexpr int SharedStorageSize = is_empty_v ? 
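+  // Example for num_vals_per_thread above (illustrative numbers): with
+  // get<1>(TileShapeO{}) = 128, SGPerWG = 8 and intel::sg_size = 16 a
+  // work-group has 8 * 16 = 128 threads, so num_vals_per_thread = 1 and each
+  // thread reduces exactly one element of the head dimension across all KV
+  // splits. The same thread count bounds max_num_kv_splits, presumably one
+  // SLM slot per possible split in the arrays above.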
size_t(0) : sizeof(SharedStorage);
+
+ public:
+  static Params to_underlying_arguments(Arguments const& args, void* workspace) {
+    return {
+        args.kernel,
+        TileScheduler::to_underlying_arguments(args.kernel.shape, args.hw_info, TileShapeO{}, args.num_kv_splits)};
+  }
+
+  static bool can_implement(Arguments const& args) {
+    // only decode (seq_len_qo == 1) is supported
+    if (!is_var_len && args.kernel.shape.seq_len_qo > 1) {
+      return false;
+    }
+
+    if (args.num_kv_splits > FMHAKernel_::max_num_kv_splits) {
+      return false;
+    }
+    return true;
+  }
+
+  static int get_workspace_size(Arguments const& args) {
+    return 0;
+  }
+
+  static cutlass::Status initialize_workspace(
+      Arguments const& args,
+      void* workspace = nullptr,
+      cudaStream_t stream = nullptr,
+      CudaHostAdapter* cuda_adapter = nullptr) {
+    return Status::kSuccess;
+  }
+
+  static dim3 get_grid_shape(Params const& params) {
+    return TileScheduler::template get_grid_shape(params.scheduler);
+  }
+
+  static dim3 get_block_shape() {
+    return dim3(SGPerWG::value * intel::sg_size, 1, 1);
+  }
+
+  CUTLASS_DEVICE
+  Shape<int, int> get_sequence_length_shape(ProblemShape const& problem_shape, int const& batch) {
+    if constexpr (is_var_len) {
+      auto q_len =
+          cutlass::fmha::collective::apply_variable_length(Shape{problem_shape.seq_len_qo}, batch);
+      return Shape{get<0>(q_len), problem_shape.seq_len_kv.cumulative_length[batch]};
+    } else {
+      return Shape{problem_shape.seq_len_qo, problem_shape.seq_len_kv};
+    }
+  }
+
+  /// Perform a reduction
+  CUTLASS_DEVICE
+  void operator()(Params const& params, char* smem_buf) {
+    using namespace sycl::ext::oneapi::this_work_item;
+
+    SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf);
+
+    auto& p = params.kernel;
+    ProblemShape const& s = p.shape;
+
+    int thr_id = int(ThreadIdxX());
+    int sub_group_id = thr_id / intel::sg_size;
+    int tid_in_sg = thr_id % intel::sg_size;
+
+    TileScheduler tile_scheduler{params.scheduler};
+    auto num_kv_splits = params.scheduler.num_kv_splits;
+
+    auto batch_dim = is_var_len ? 1 : s.batch;
+    auto num_heads_q = s.num_heads_q;
+    auto head_size_vo = s.head_size_vo;
+
+    CUTLASS_PRAGMA_NO_UNROLL
+    for (; tile_scheduler.is_valid(); ++tile_scheduler) {
+      auto [seq_idx, head_q, idx_b] = tile_scheduler.get_block_coord();
+
+      auto sequence_length_shape = get_sequence_length_shape(s, idx_b);
+      auto [seq_len_qo, seq_len_kv] = sequence_length_shape;
+
+      // with varlen, the grid is sized by the largest seq_len_qo, so skip
+      // rows beyond this sequence's length
+      if (seq_idx >= seq_len_qo) continue;
+
+      const int k_blocks = cute::ceil_div(seq_len_kv, get<1>(TileShapeQK{}));
+      // Sliding window: skip blocks before the window
+      constexpr bool LocalMask = FMHAKernel_::CollectiveMainloop::LocalMask;
+      const int k_block0 = LocalMask ? cute::max(seq_len_kv - 1 - p.window_size_left, 0) / get<1>(TileShapeQK{}) : 0;
+      const int windowed_k_blocks = k_blocks - k_block0;
+      int num_blocks_per_split = cute::ceil_div(windowed_k_blocks, num_kv_splits);
+
+      int offset_o = 0, offset_o_accum = 0;
+      int offset_exp_sums = 0, offset_max_logits = 0;
+
+      if constexpr (is_var_len) {
+        auto qo_cumulative = s.seq_len_qo.cumulative_length;
+
+        offset_o_accum = s.num_heads_q * s.head_size_vo * num_kv_splits * qo_cumulative[idx_b];
+        offset_exp_sums = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b];
+        offset_max_logits = s.num_heads_q * num_kv_splits * qo_cumulative[idx_b];
+
+        offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b];
+      }
+
+      auto shape_O = make_shape(seq_len_qo, head_size_vo, num_heads_q, batch_dim);
+      // identical for varlen and fixed-length; only the strides built below differ
+      auto shape_Oaccum = make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch_dim);
+
+      auto shape_exp_sums = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim);
+      auto shape_max_logits = make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch_dim);
+
+      auto dcOaccum = const_cast<ElementO*>(p.Oaccum + offset_o_accum);
+      auto ptrExp_sums = const_cast<ElementLSE*>(p.exp_sums + offset_exp_sums);
+      auto ptrMax_logits = const_cast<ElementLSE*>(p.max_logits + offset_max_logits);
+      auto ptrO = p.O + offset_o;
+
+      auto stride_o = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_O) : p.dO;
+      auto stride_o_accum = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_Oaccum) : p.dOaccum;
+      auto stride_exp_sums = is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_exp_sums) : p.dExp_sums;
+      auto stride_max_logits =
+          is_var_len ? cutlass::make_cute_packed_stride(StrideO{}, shape_max_logits) : p.dMax_logits;
+
+      Tensor Oaccum = make_tensor(make_gmem_ptr(dcOaccum), make_layout(shape_Oaccum, stride_o_accum));
+      Tensor O = make_tensor(make_gmem_ptr(ptrO), make_layout(shape_O, stride_o));
+
+      Tensor exp_sums = make_tensor(make_gmem_ptr(ptrExp_sums), make_layout(shape_exp_sums, stride_exp_sums));
+      Tensor max_logits = make_tensor(make_gmem_ptr(ptrMax_logits), make_layout(shape_max_logits, stride_max_logits));
+
+      int l_coord = is_var_len ? 0 : idx_b;
+
+      // Step 1: reduce the max logits across partitions, staging each
+      // split's (max, exp-sum) pair in SLM for the combine step below
+
+      ElementLSE global_max_logits{cutlass::platform::numeric_limits<ElementLSE>::lowest()};
+      ElementLSE global_exp_sums{0};
+      // only the first num_kv_splits threads participate
+      if (thr_id < num_kv_splits && thr_id * num_blocks_per_split < windowed_k_blocks) {
+        ElementLSE cur_max_logit = max_logits(seq_idx, thr_id, head_q, l_coord);
+        global_max_logits = sycl::max(global_max_logits, cur_max_logit);
+        shared_storage.max_logits_slm_array[thr_id] = cur_max_logit;
+
+        ElementLSE cur_exp_sum = exp_sums(seq_idx, thr_id, head_q, l_coord);
+        shared_storage.exp_sums_slm_array[thr_id] = cur_exp_sum;
+      }
+
+      // barrier: make the SLM writes visible to the whole work-group
+      sycl::group_barrier(get_work_group<3>());
+
+      // reduce the max across the work-group
+      global_max_logits = reduce_over_group(get_work_group<1>(), global_max_logits, sycl::maximum<>());
+
+      // broadcast to all other threads
+      global_max_logits = sycl::group_broadcast(get_work_group<1>(), global_max_logits, 0);
+
+      for (int idx = thr_id; idx < s.head_size_vo; idx += SGPerWG::value * intel::sg_size) {
+        ElementLSE acc = 0;
+        global_exp_sums = 0;
+        for (int i = 0; i < num_kv_splits; ++i) {
+          if (i * num_blocks_per_split >= windowed_k_blocks) {
+            break;
+          }
+          ElementLSE local_max_logit = shared_storage.max_logits_slm_array[i];
+          ElementLSE local_exp_sum = shared_storage.exp_sums_slm_array[i];
+
+          // Skip splits with no valid data (short sequences treated as
+          // single-split have exp_sums=0 / max_logits=-inf for unused splits).
+          if (local_exp_sum <= ElementLSE(0)) continue;
+
+          ElementLSE rescale = sycl::native::exp2(local_max_logit - global_max_logits);
+
+          // Partial outputs are unnormalized (not divided by exp_sum in the
+          // epilogue), so combine them directly with the rescale factor.
+          ElementLSE o_accum_val = static_cast<ElementLSE>(Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord));
+          acc += o_accum_val * rescale;
+
+          // update the global exp sum
+          global_exp_sums += local_exp_sum * rescale;
+        }
+
+        ElementLSE inv_global_exp_sums = 1.
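+        // The combine loop above computes, in exp2 space,
+        //   O = (sum_i exp2(m_i - m) * Oaccum_i) / (sum_i exp2(m_i - m) * s_i)
+        // with m = max_i m_i, where (m_i, s_i, Oaccum_i) are the i-th split's
+        // max logit, exp-sum and unnormalized partial output. Rebasing every
+        // split onto the global max keeps the exponentials in range, and the
+        // division below is the only normalization the output receives.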
/ global_exp_sums; + + acc *= inv_global_exp_sums; + O(seq_idx, idx, head_q, l_coord) = static_cast(acc); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace reduction +} // namespace cutlass diff --git a/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp b/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp index edef3695..6f6daa47 100644 --- a/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp +++ b/src/sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp @@ -42,6 +42,8 @@ struct XeFHMAIndividualTileScheduler { struct Params { dim3 grid; FastDivmod divmod_num_heads; + FastDivmod divmod_batch; + int num_kv_splits_ = -1; }; bool valid_ = true; @@ -51,15 +53,25 @@ struct XeFHMAIndividualTileScheduler { XeFHMAIndividualTileScheduler(Params const& params) : params(params) {} template - static Params - to_underlying_arguments(ProblemShape const& shape, KernelHardwareInfo hw_info, TileShape const& tile_shape) { + static Params to_underlying_arguments( + ProblemShape const& shape, + KernelHardwareInfo hw_info, + TileShape const& tile_shape, + const int& num_kv_splits = -1) { using namespace cute; dim3 grid( - size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V - size(ceil_div((int)shape.seq_len_qo, get<0>(tile_shape))), // Q - size(shape.batch * shape.num_heads_q)); // (h,b) -- split later - return Params{grid, {shape.num_heads_q}}; + size(ceil_div(shape.head_size_vo, get<1>(tile_shape))), // V + size(ceil_div(shape.seq_len_qo, get<0>(tile_shape))), // Q + size(shape.batch * shape.num_heads_q)); // (h,b) -- split later + int num_head = shape.num_heads_q; + if (num_kv_splits >= 1) { + // for splitKV, each wg handles group query heads + grid.z = size(shape.batch * shape.num_heads_kv); + grid.z *= num_kv_splits; + num_head = shape.num_heads_kv; + } + return Params{grid, {num_head}, {shape.batch * num_head}, num_kv_splits}; } template @@ -75,10 +87,18 @@ struct XeFHMAIndividualTileScheduler { CUTLASS_DEVICE auto get_block_coord() { using namespace cute; - int idx_b = BlockIdxZ(); - int head; + int idx_kv_split = BlockIdxZ(); + int head, idx_b; + + if (params.num_kv_splits_ >= 1) { + params.divmod_batch(idx_kv_split, idx_b, idx_kv_split); + params.divmod_num_heads(idx_b, head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, idx_kv_split); + } + + idx_b = idx_kv_split; params.divmod_num_heads(idx_b, head, idx_b); - return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b); + return make_coord(BlockIdxY(), BlockIdxX(), head, idx_b, (int)-1); } CUTLASS_DEVICE @@ -157,4 +177,52 @@ struct XeFHMAIndividualPersistentTileScheduler { } }; +struct XeReduceSplitKTileScheduler { + struct Params { + dim3 grid; + FastDivmod divmod_num_heads; + int num_kv_splits = -1; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& shape, + KernelHardwareInfo hw_info, + TileShape const& tile_shape, + const int& num_kv_splits = -1) { + using namespace cute; + + dim3 grid(shape.seq_len_qo, shape.num_heads_q, shape.batch); + return Params{grid, {shape.num_heads_q}, num_kv_splits}; + } + + template + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using 
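+  // Coordinate decomposition example for the split-KV path of
+  // XeFHMAIndividualTileScheduler::get_block_coord above (illustrative
+  // numbers): batch = 2, num_heads_kv = 4, num_kv_splits = 3 gives
+  // grid.z = 2 * 4 * 3 = 24; divmod_batch divides by batch * num_heads = 8
+  // and divmod_num_heads by 4. For BlockIdxZ() = 13:
+  //   idx_kv_split = 13 / 8 = 1, remainder 5
+  //   idx_b        = 5 / 4  = 1, head = 5 % 4 = 1
+  // i.e. block coordinate (blk_q, blk_v, head=1, idx_b=1, idx_kv_split=1).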
namespace cute; + + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeReduceSplitKTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; } // namespace cutlass::fmha::kernel diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp index eb27f217..58756aad 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_dispatch.hpp @@ -31,13 +31,14 @@ **************************************************************************************************/ #pragma once -namespace decode { +#include "xe_fmha_fwd_decode_runner.hpp" -struct Arguments; +namespace decode { -// Declarations for generated FMHA decode kernel launch functions. -// Each function is defined in a separate generated .cpp file from -// xe_fmha_fwd_decode_kernel.cpp.in, compiled as its own library. +// Struct functor declarations for FMHA decode kernel launchers. +// Each template specialization is explicitly instantiated in a separate +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in). // // Naming: launch_fmha_decode___ // Parameters: @@ -45,31 +46,54 @@ struct Arguments; // HEAD_DIM in {64, 96, 128, 192, 256, 512} // PAGE_SIZE in {32, 64, 128} (with NUM_SG = PAGE_SIZE / 16) -#define DECLARE_LAUNCH_FMHA_DECODE(QG, HD, PS) \ - void launch_fmha_decode_##QG##_##HD##_##PS(bool use_sink, const Arguments& params); +// Explicit instantiation declarations — tell the compiler these are compiled +// in separate translation units (generated from the .cpp.in templates). + +#define EXTERN_FMHA_DECODE_RUNNER(QG, HD, PS) extern template struct FmhaDecodeRunner; + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, PS) extern template struct FmhaSplitDecodeRunner; + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(QG, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 64) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER(QG, HD, 128) + +#define EXTERN_FMHA_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) -#define DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(QG, HD) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 32) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 64) \ - DECLARE_LAUNCH_FMHA_DECODE(QG, HD, 128) +#define EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(1, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(2, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(4, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(8, HD) \ + EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES(16, HD) -#define DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(1, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(2, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(4, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(8, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(16, HD) \ - DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES(32, HD) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(128) 
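+// For reference, each EXTERN_FMHA_DECODE_RUNNER(QG, HD, PS) above expands to
+// an explicit instantiation declaration of the form
+//   extern template struct FmhaDecodeRunner<QG, HD, PS>;
+// which promises the definition lives in another translation unit (the
+// generated .cpp), so including this header never re-instantiates the heavy
+// kernel templates.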
+EXTERN_FMHA_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(256) +EXTERN_FMHA_DECODE_RUNNER_ALL_QG(512) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(64) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(96) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(128) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(192) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(256) -DECLARE_LAUNCH_FMHA_DECODE_ALL_QG(512) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(64) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(96) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(128) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(192) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(256) +EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG(512) -#undef DECLARE_LAUNCH_FMHA_DECODE -#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_PAGE_SIZES -#undef DECLARE_LAUNCH_FMHA_DECODE_ALL_QG +#undef EXTERN_FMHA_DECODE_RUNNER +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_PAGE_SIZES +#undef EXTERN_FMHA_DECODE_RUNNER_ALL_QG +#undef EXTERN_FMHA_SPLIT_DECODE_RUNNER_ALL_QG } // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp index 21307b42..77046ac6 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp @@ -44,8 +44,8 @@ #include "sycl/comm/common.h" #include "sycl/kernels/flash_attention_v2/collective/fmha_fusion.hpp" #include "sycl/kernels/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp" +#include "sycl/kernels/flash_attention_v2/kernel/xe_reduce_split_k.hpp" #include "sycl/kernels/flash_attention_v2/kernel/xe_tile_scheduler.hpp" - using namespace cute; namespace decode { struct Arguments { @@ -54,6 +54,12 @@ struct Arguments { void* __restrict__ k_ptr; void* __restrict__ v_ptr; + void* __restrict__ k_scale_ptr = nullptr; + void* __restrict__ v_scale_ptr = nullptr; + + void* __restrict__ temp_out_ptr = nullptr; + void* __restrict__ exp_sums_ptr = nullptr; + void* __restrict__ max_logits_ptr = nullptr; // The stride between rows of the Q, K and V matrices. int64_t q_batch_stride; int64_t k_batch_stride; @@ -66,9 +72,21 @@ struct Arguments { int64_t v_head_stride; int64_t v_dim_stride; + int64_t k_stride_page = 0; + int64_t k_stride_seq = 0; + int64_t k_stride_heads = 0; + int64_t v_stride_page = 0; + int64_t v_stride_seq = 0; + int64_t v_stride_heads = 0; + // The number of heads. int h, h_k; int q_group_size = 1; + int num_kv_splits = -1; // For split-KV version + bool use_split_kv = false; + bool use_sink = false; + bool is_causal = false; + bool is_local = false; // The O matrix (output). void* __restrict__ o_ptr; @@ -141,7 +159,7 @@ struct Arguments { // The indices to index into the KV cache. int* __restrict__ kv_batch_idx; - // Paged KV cache + // PagedKV KV cache int* __restrict__ page_table; int max_num_pages_per_seq; int64_t page_table_batch_stride; @@ -156,8 +174,9 @@ struct Arguments { // Scale factor of 1 / (1 - p_dropout). float rp_dropout; - // Local window size - int window_size_left, window_size_right; + // LocalMask window size + int window_size_left = -1; + int window_size_right = -1; // Pointer to the RNG seed (idx 0) and offset (idx 1). 
uint64_t* rng_state; @@ -165,8 +184,6 @@ struct Arguments { bool is_bf16; bool is_fp32; bool is_e4m3; - bool is_causal; - bool is_local; bool is_rotary_interleaved; @@ -217,16 +234,21 @@ struct DecodeRunner { auto initialize_varlen(const Arguments& params, const ProblemShape& problem_size) { ProblemShape problem_size_for_init = problem_size; get<0>(problem_size_for_init) = 1; // concentrated batch - get<1>(problem_size_for_init) = params.h / params.q_group_size; - get<3>(problem_size_for_init) = params.total_q * params.q_group_size; + get<1>(problem_size_for_init) = params.use_split_kv ? params.h : params.h_k; + get<3>(problem_size_for_init) = params.use_split_kv ? params.total_q : params.total_q * params.q_group_size; get<4>(problem_size_for_init) = params.total_knew; get<5>(problem_size_for_init) = params.total_k; ProblemShapeType problem_size_for_launch{ .batch = get<0>(problem_size), - .num_heads_q = get<1>(problem_size) / params.q_group_size, + .num_heads_q = params.use_split_kv ? get<1>(problem_size) : get<2>(problem_size), .num_heads_kv = get<2>(problem_size), - .seq_len_qo = {params.seqlen_q, params.total_q * params.q_group_size, nullptr, params.q_group_size}, + .seq_len_qo = + {params.use_split_kv ? params.seqlen_q * params.q_group_size : params.seqlen_q, + params.use_split_kv ? params.total_q : params.total_q * params.q_group_size, + nullptr, + params.use_split_kv ? 1 : params.q_group_size}, + .seq_len_kv = {params.seqlen_knew, params.total_knew}, .seq_len_kv_cache = {params.seqlen_k, params.total_k}, .head_size_qk = get<6>(problem_size), @@ -320,6 +342,204 @@ struct DecodeRunner { return cutlass::Status::kSuccess; } }; + +template +struct SplitDecodeKernelRunner { + using StrideQ = typename FMHAKernel::StrideQ; + using StrideK = typename FMHAKernel::StrideK; + using StrideV = typename FMHAKernel::StrideV; + using StrideO = typename FMHAKernel::StrideO; + + using ElementQ = typename FMHAKernel::ElementQ; + using ElementK = typename FMHAKernel::ElementK; + using ElementV = typename FMHAKernel::ElementV; + using ElementO = typename FMHAKernel::ElementO; + using ElementLSE = typename FMHAKernel::ElementLSE; + + using CollectiveMainloop = typename FMHAKernel::CollectiveMainloop; + using ElementS = typename CollectiveMainloop::ElementS; + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + using ProblemShapeTypeInit = cutlass::fmha::kernel::FMHAProblemShape; + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideO stride_O; + StrideO stride_Oaccum; + StrideO stride_exp_sums; + StrideO stride_max_logits; + + int num_kv_splits; + + ProblemShapeType initialize(const Arguments& params) { + ProblemShapeType shape; + ProblemShapeTypeInit shape_init; + auto batch = shape.batch = shape_init.batch = params.b; + auto num_heads_q = shape.num_heads_q = shape_init.num_heads_q = params.h; + auto num_heads_kv = shape.num_heads_kv = shape_init.num_heads_kv = params.h_k; + auto head_size_qk = shape.head_size_qk = shape_init.head_size_qk = params.d; + auto head_size_vo = shape.head_size_vo = shape_init.head_size_vo = params.d; + + if constexpr (isVarLen) { + batch = shape_init.batch = 1; + shape_init.seq_len_qo = params.total_q; + shape_init.seq_len_kv = params.total_k; + + shape.seq_len_qo = cutlass::fmha::collective::VariableLength{params.seqlen_q}; + shape.seq_len_qo.cumulative_length = reinterpret_cast(params.cu_seqlens_q); + shape.seq_len_kv = cutlass::fmha::collective::VariableLength{params.seqlen_k}; + shape.seq_len_kv.cumulative_length = 
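+    // Scratch buffers assumed by this split-KV runner (shapes match the
+    // packed strides built just below):
+    //   temp_out_ptr   : (seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch)
+    //   exp_sums_ptr   : (seq_len_qo, num_kv_splits, num_heads_q, batch)
+    //   max_logits_ptr : (seq_len_qo, num_kv_splits, num_heads_q, batch)
+    // The caller is expected to allocate them before invoking run().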
reinterpret_cast(params.cu_seqlens_k); + } else { + shape.seq_len_qo = shape_init.seq_len_qo = params.seqlen_q; + shape.seq_len_kv = shape_init.seq_len_kv = params.seqlen_k; + } + + auto seq_len_qo = shape_init.seq_len_qo; + auto seq_len_kv = shape_init.seq_len_kv; + + num_kv_splits = params.num_kv_splits; + + stride_Q = + cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, head_size_qk, num_heads_q, batch)); + if (params.k_stride_seq > 0) { + // Use actual strides from KV cache tensors (supports non-contiguous + // layouts such as MLA combined KV cache) + constexpr int64_t kIntMax = static_cast(std::numeric_limits::max()); + TORCH_CHECK( + params.k_stride_seq <= kIntMax && params.k_stride_heads <= kIntMax && params.k_stride_page <= kIntMax && + params.v_stride_seq <= kIntMax && params.v_stride_heads <= kIntMax && params.v_stride_page <= kIntMax, + "KV cache stride exceeds int32 max (", + kIntMax, + "): k_stride_seq=", + params.k_stride_seq, + " k_stride_heads=", + params.k_stride_heads, + " k_stride_page=", + params.k_stride_page, + " v_stride_seq=", + params.v_stride_seq, + " v_stride_heads=", + params.v_stride_heads, + " v_stride_page=", + params.v_stride_page); + stride_K = StrideK{ + static_cast(params.k_stride_seq), + _1{}, + static_cast(params.k_stride_heads), + static_cast(params.k_stride_page)}; + stride_V = StrideV{ + _1{}, + static_cast(params.v_stride_seq), + static_cast(params.v_stride_heads), + static_cast(params.v_stride_page)}; + } else { + stride_K = + cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, head_size_qk, num_heads_kv, batch)); + stride_V = + cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo, seq_len_kv, num_heads_kv, batch)); + } + stride_O = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q, batch)); + stride_Oaccum = cutlass::make_cute_packed_stride( + StrideO{}, cute::make_shape(seq_len_qo, head_size_vo, num_heads_q * num_kv_splits, batch)); + + stride_exp_sums = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + + stride_max_logits = + cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_kv_splits, num_heads_q, batch)); + + return shape; + } + + cutlass::Status run(const Arguments& params, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType shape = initialize(params); + + typename FMHAKernel::Arguments arguments{ + { + shape, + reinterpret_cast(params.q_ptr), + stride_Q, + reinterpret_cast(params.k_ptr), + stride_K, + reinterpret_cast(params.v_ptr), + stride_V, + reinterpret_cast(params.temp_out_ptr), + stride_Oaccum, + reinterpret_cast(params.exp_sums_ptr), + stride_exp_sums, + reinterpret_cast(params.max_logits_ptr), + stride_max_logits, + reinterpret_cast(params.softmax_sink_ptr), + }, + {params.softmax_scale, + params.k_scale_ptr, + params.v_scale_ptr, + static_cast(params.page_table), + params.page_size, + params.max_num_pages_per_seq, + params.total_k, + params.window_size_left, + params.window_size_right}, + {}, + hw_info, + params.num_kv_splits}; + + typename ReductionSplitKernel::Arguments reduce_arg{ + {shape, + reinterpret_cast(params.o_ptr), + stride_O, + reinterpret_cast(params.temp_out_ptr), + stride_Oaccum, + reinterpret_cast(params.exp_sums_ptr), + stride_exp_sums, + reinterpret_cast(params.max_logits_ptr), + stride_max_logits, + params.window_size_left}, + hw_info, + params.num_kv_splits}; + + // Define device-global scratch 
memory + size_t workspace_size = FMHAKernel::get_workspace_size(arguments); + size_t reduce_workspace_size = ReductionSplitKernel::get_workspace_size(reduce_arg); + torch::Tensor workspace = torch::empty( + {static_cast(workspace_size + reduce_workspace_size)}, torch::device(torch::kXPU).dtype(torch::kByte)); + uint8_t* workspace_ptr = static_cast(workspace.data_ptr()); + + if (!FMHAKernel::can_implement(arguments)) { + // std::cout << "Invalid Problem Size: " << params.batch_size << 'x' + // << params.num_heads_q << 'x' << params.max_queries << 'x' + // << params.max_keys << 'x' << params.head_size << 'x' + // << params.head_size << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + FMHAKernel::initialize_workspace(arguments, workspace_ptr); + + // Convert host-side arguments to device-side arguments to be passed to the + // kernel + auto kernel_params = FMHAKernel::to_underlying_arguments(arguments, workspace_ptr); + auto reduce_params = ReductionSplitKernel::to_underlying_arguments(reduce_arg, workspace_ptr + workspace_size); + + ReductionSplitKernel::initialize_workspace(reduce_arg, workspace_ptr + workspace_size); + run(kernel_params, reduce_params, params.num_kv_splits > 1); + + return cutlass::Status::kSuccess; + } + + static void + run(typename FMHAKernel::Params params, typename ReductionSplitKernel::Params reduce_params, bool need_reduce) { + launch(params); + + if (need_reduce) { + launch(reduce_params); + } + } +}; + template < bool Causal, bool LocalMask, @@ -344,7 +564,7 @@ template < typename GmemTiledCopyK = void, typename GmemTiledCopyV = void, typename GmemTiledCopyO = void> -struct FMHAConfig { +struct DecodeConfig { static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); using MMAOperation = cute::conditional_t< is_void_v, @@ -439,4 +659,118 @@ struct FMHAConfig { return run(params); } }; + +template < + bool Causal, + bool LocalMask, + bool Sink, + typename TileShapeQK, + typename TileShapePV, + typename TileShapeOutput, + typename SubgroupLayoutQK, + typename SubgroupLayoutPV_ = void /* void -> default */, + int PipelineStages = 1, + typename ElementQ = bfloat16_t, + typename ElementK = bfloat16_t, + typename ElementV = bfloat16_t, + typename ElementO = bfloat16_t, + typename MMAOperation_ = void, /* void -> default */ + typename StrideQ = Stride, + typename StrideK = Stride, + typename StrideV = Stride<_1, int, int, int>, + typename StrideO = Stride, + typename StrideOaccum = Stride, + typename GmemTiledCopyQ = void, /* void -> default block 2D */ + typename GmemTiledCopyK = void, + typename GmemTiledCopyV = void, + typename GmemTiledCopyO = void> +struct SplitDecodeConfig { + static constexpr int SGTileQ = get<0>(shape_div(TileShapeQK{}, shape(SubgroupLayoutQK{})))(); + using MMAOperation = + cute::conditional_t, XE_DPAS_TT, MMAOperation_>; + using SubgroupLayoutPV = cute::conditional_t< + is_void_v, + decltype(cutlass::fmha::collective::get_sg_layout_pv(SubgroupLayoutQK{})), + SubgroupLayoutPV_>; + + template + static void run(const Arguments& params) { + // constexpr bool isVarLen = true; + // constexpr bool PagedKV = true; + cutlass::KernelHardwareInfo hw_info; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using ProblemShapeType = cutlass::fmha::kernel::FMHAProblemShape; + + using TiledMMAQK = typename TiledMMAHelper, Layout, SubgroupLayoutQK>::TiledMMA; + using TiledMMAPV = typename TiledMMAHelper, Layout, 
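+    // Execution flow sketch for the runner above: the split-KV kernel is
+    // always launched, with each (KV head, batch, split) work-group covering
+    // the whole group of query heads for its KV slice and writing an
+    // unnormalized partial output into temp_out plus per-split exp_sums /
+    // max_logits; ReduceSplitK is launched afterwards only when
+    // num_kv_splits > 1, rebasing the partials onto the global max and
+    // writing the final normalized O. Short sequences flagged is_single_split
+    // still pass through the reduction, which sees one valid split and skips
+    // the rest (their exp_sums stay zero).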
SubgroupLayoutPV>::TiledMMA; + + static_assert( + get<0>(TileShapeOutput{}) == get<0>(TileShapePV{}), + "Output tile and P*V tile have different sizes in Q dimension"); + constexpr int VTiles = get<1>(TileShapeOutput{}) / get<1>(TileShapePV{}); + + auto make_dummy_tensor = [&](auto val, auto stride) { + return make_tensor(make_gmem_ptr(&val), make_layout(repeat>(1), stride)); + }; + + using TensorQ = decltype(make_dummy_tensor(ElementQ{}, StrideQ{})); + using TensorK = decltype(make_dummy_tensor(ElementK{}, StrideK{})); + using TensorV = decltype(make_dummy_tensor(ElementV{}, StrideV{})); + using TensorO = decltype(make_dummy_tensor(ElementO{}, StrideOaccum{})); + using TensorLSE = decltype(make_dummy_tensor(float{}, StrideO{})); + + // Mainloop + using MainloopDispatchPolicy = cutlass::fmha::XeDefault; + using CollectiveMainloop = cutlass::fmha::collective::DecodeFwdMainloop< + MainloopDispatchPolicy, + PagedKV, + Causal, + TiledMMAQK, + TiledMMAPV, + VTiles, + TensorQ, + TensorK, + TensorV, + GmemTiledCopyQ, + GmemTiledCopyK, + GmemTiledCopyV, + LocalMask>; + + // Epilogue + using CollectiveEpilogue = cutlass::fmha::collective:: + DecodeFwdEpilogue; + + using FMHAKernel = cutlass::fmha::kernel:: + XeFMHAFwdSplitKVKernel; + + using ReduceSplitKernel = cutlass::reduction::kernel:: + ReduceSplitK; + + SplitDecodeKernelRunner launcher; + + launcher.run(params, hw_info); + } + + static void run(const Arguments& params) { + return run(params); + } +}; + +// Struct functors for decode kernel dispatch. +// operator() is declared here; each specialization's body is defined in a +// generated .cpp file (from xe_fmha_fwd_decode_kernel.cpp.in / +// xe_fmha_fwd_split_decode_kernel.cpp.in) so the compiler only emits code +// for the combinations that are actually needed. + +template +struct FmhaDecodeRunner { + void operator()(const Arguments& params) const; +}; + +template +struct FmhaSplitDecodeRunner { + void operator()(const Arguments& params) const; +}; + } // namespace decode diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp new file mode 100644 index 00000000..77c7b818 --- /dev/null +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_dispatch.hpp @@ -0,0 +1,54 @@ +/*************************************************************************************************** + * Copyright (C) 2026 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "xe_fmha_fwd_prefill_runner.hpp" + +namespace prefill { + +// Explicit instantiation declarations — tell the compiler these are compiled +// in separate translation units (generated from the .cpp.in templates). +// +// Parameters: +// HEAD_DIM in {64, 96, 128, 192, 256, 512} + +#define EXTERN_FMHA_PREFILL_RUNNER(HD) extern template struct FmhaPrefillRunner; + +EXTERN_FMHA_PREFILL_RUNNER(64) +EXTERN_FMHA_PREFILL_RUNNER(96) +EXTERN_FMHA_PREFILL_RUNNER(128) +EXTERN_FMHA_PREFILL_RUNNER(192) +EXTERN_FMHA_PREFILL_RUNNER(256) +EXTERN_FMHA_PREFILL_RUNNER(512) + +#undef EXTERN_FMHA_PREFILL_RUNNER + +} // namespace prefill diff --git a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp index 813b1f9c..e8a70854 100644 --- a/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp +++ b/src/sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp @@ -439,285 +439,15 @@ struct FMHAConfig { return run(params); } }; -std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, - // h_k, d) if there is page_table. - const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, - // page_size, h_k, dv) if there is page_table. - std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - const at::Tensor& cu_seqlens_q, // b+1 - const at::Tensor& cu_seqlens_k, // b+1 - int max_seqlen_q, - int max_seqlen_k, - std::optional& page_table, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. 
indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) - const float softmax_scale_, - std::optional& sinks_, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional& scheduler_metadata_, // (b + 1) - int num_splits, - std::optional pack_gqa_, - int const sm_margin) { - auto q_type = q.scalar_type(); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "mha_fwd only supports Half and BFloat16, got", - q_type); - - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); - CHECK_LAST_DIM_CONTIGUOUS_INPUT(q); - CHECK_LAST_DIM_CONTIGUOUS_INPUT(k); - CHECK_LAST_DIM_CONTIGUOUS_INPUT(v); - - TORCH_CHECK(page_table.value().dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.value().stride(-1) == 1, "page_table must have contiguous last dimension"); - - TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); - CHECK_INPUT(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - - CHECK_INPUT(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - - auto const sizes = q.sizes(); - const int batch_size = cu_seqlens_q.size(0) - 1; - int seqlen_q = max_seqlen_q; - int total_q = q.size(0); - int num_heads = q.size(-2); - int const head_size = q.size(-1); - int const head_size_v = v.size(-1); - int const max_num_pages_per_seq = page_table.value().size(1); - int const num_pages = k.size(0); - int const page_size = k.size(1); - int const seqlen_k = max_num_pages_per_seq * page_size; - int const total_k = num_pages * page_size; - int const num_heads_k = k.size(-2); - - int const batch_size_k = page_table.value().size(0); - float softmax_scale = softmax_scale_; - - if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); - } - - // Currently only support head dims <= 256 - static constexpr int max_headdim = 256; - TORCH_CHECK(head_size <= max_headdim, "FlashAttention forward only supports head dimension at most ", max_headdim); - TORCH_CHECK(num_heads == num_heads_k, "Only support number of heads in key/value equals to number of heads in query"); - - // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - // TODO: check this - - if (window_size_left >= seqlen_k - 1) { - window_size_left = -1; - } - window_size_right = min(window_size_right, seqlen_q); - // causal=true is the same as causal=false in this case - if (is_causal) { - window_size_right = 0; - } - - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table.value(), batch_size_k, max_num_pages_per_seq); - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - 
CHECK_INPUT(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - static constexpr int alignment = 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); - - auto opts = q.options(); - at::Tensor out; - out = torch::empty({total_q, num_heads, head_size_v}, opts); - - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - c10::DeviceGuard device_guard(q.device()); - - at::Tensor softmax_lse; - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - - // align with FA3 - Arguments params; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.v_dim_stride = v.stride(-1); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - params.cu_seqlens_q = cu_seqlens_q.data_ptr(); - params.cu_seqlens_k = cu_seqlens_k.data_ptr(); - - // Softmax sum - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - // Set the dimensions. - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.q_group_size = 1; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.d = head_size; - params.d_rounded = head_size_rounded; - - // Set the different scale values. - params.softmax_scale = softmax_scale; - bool use_sink = sinks_.has_value(); - params.softmax_sink_ptr = use_sink ? sinks_.value().data_ptr() : nullptr; - - params.softcap = softcap; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f; - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
- params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - - // TODO: check this - if (window_size_left < 0) { - window_size_left = seqlen_k - 1; - } - if (window_size_right < 0) { - window_size_right = seqlen_q - 1; - } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - params.total_q = total_q; - params.total_k = total_k; - params.b_k = batch_size_k; - params.dv = head_size_v; - params.page_table = page_table.value().data_ptr(); - params.page_table_batch_stride = page_table.value().stride(0); - params.max_num_pages_per_seq = max_num_pages_per_seq; - params.page_size = page_size; - params.num_pages = num_pages; - - if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(false, "q_v is not supported yet"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); - CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); - params.qv_ptr = q_v.data_ptr(); - // All stride are in elements, not bytes. - params.qv_row_stride = q_v.stride(-3); - params.qv_head_stride = q_v.stride(-2); - } +// Struct functor for prefill kernel dispatch. +// operator() is declared here; each specialization's body is defined in a +// generated .cpp file (from xe_fmha_fwd_prefill_kernel.cpp.in) so the compiler +// only emits code for the combinations that are actually needed. - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_INPUT(rotary_cos); - params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); - const int seqlen_ro = rotary_cos.size(0); - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); - CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); - auto rotary_sin = rotary_sin_.value(); - CHECK_INPUT(rotary_sin); - CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - params.rotary_cos_ptr = rotary_cos.data_ptr(); - params.rotary_sin_ptr = rotary_sin.data_ptr(); - params.is_rotary_interleaved = is_rotary_interleaved; - if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); - CHECK_INPUT(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); - CHECK_SHAPE(seqlens_rotary, batch_size); - params.seqlens_rotary = seqlens_rotary.data_ptr(); - } - } else { - params.rotary_dim = 0; - } - - if (kv_batch_idx_.has_value()) { - auto kv_batch_idx = kv_batch_idx_.value(); - CHECK_INPUT(kv_batch_idx); - TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); - params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); - } +template +struct 
FmhaPrefillRunner { + void operator()(const Arguments& params) const; +}; - params.tensor_opts = torch::TensorOptions().dtype(torch::kUInt8).device(q.device()); - - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - - auto launch_kernel = [&](auto _TILED_Q, auto _TILED_KV, auto _HEAD_DIM, auto _NUM_SG) { - using TileShapeQK = cute::Shape; - using TileShapePV = cute::Shape; - using TileShapeOutput = cute::Shape; - using SubgroupLayoutQK = cute::Layout, cute::Stride<_1, _1, _1>>; - - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_causal, Causal, { - FMHAConfig::run(params); - }); - }); - }); - }; - - switch (params.d) { - case 64: - launch_kernel(_128{}, _64{}, _64{}, _8{}); - break; - case 96: - launch_kernel(_128{}, _64{}, _96{}, _8{}); - break; - case 128: - launch_kernel(_256{}, _32{}, _128{}, _16{}); - break; - case 192: - launch_kernel(_256{}, _64{}, _192{}, _32{}); - break; - default: - TORCH_CHECK(false, "Unsupported head size for prefill attention: ", params.d); - } - return {out, softmax_lse, out_accum, softmax_lse_accum}; -} } // namespace prefill diff --git a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in index c2ac6ba2..7d36473d 100644 --- a/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in +++ b/src/sycl/xe_fmha_fwd_decode_kernel.cpp.in @@ -30,27 +30,27 @@ * **************************************************************************************************/ // Auto-generated from xe_fmha_fwd_decode_kernel.cpp.in -// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@, NUM_SG=@NUM_SG@ +// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@ #define SYCL_INTEL_TARGET 20 #include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp" namespace decode { -void launch_fmha_decode_@QG_SZ@_@HEAD_DIM@_@PAGE_SIZE@(bool use_sink, const Arguments& params) { - using namespace cute; - - constexpr bool Causal = false; +template <> +void FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const { using TileShapeQK = cute::Shape, cute::Int<@PAGE_SIZE@>, cute::_64>; using TileShapePV = cute::Shape, cute::_32, cute::Int<@PAGE_SIZE@>>; using TileShapeOutput = cute::Shape, cute::Int<@HEAD_DIM@>>; - using SubgroupLayoutQK = cute::Layout, cute::_1>>; + using SubgroupLayoutQK = cute::Layout, cute::_1>>; - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { - FMHAConfig::run(params); - }); + AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, { + DecodeConfig::run(params); + }); }); } +template struct FmhaDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>; + } // namespace decode diff --git a/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in b/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in new file mode 100644 index 00000000..e71b42e1 --- /dev/null +++ b/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (C) 2026 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
diff --git a/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in b/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in
new file mode 100644
index 00000000..e71b42e1
--- /dev/null
+++ b/src/sycl/xe_fmha_fwd_prefill_kernel.cpp.in
@@ -0,0 +1,59 @@
+/***************************************************************************************************
+ * Copyright (C) 2026 Intel Corporation, All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+// Auto-generated from xe_fmha_fwd_prefill_kernel.cpp.in
+// Template parameters: HEAD_DIM=@HEAD_DIM@, TILED_Q=@TILED_Q@, TILED_KV=@TILED_KV@, NUM_SG=@NUM_SG@
+#define SYCL_INTEL_TARGET 20
+
+#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_prefill_runner.hpp"
+
+namespace prefill {
+
+template <>
+void FmhaPrefillRunner<@HEAD_DIM@>::operator()(const Arguments& params) const {
+  using TileShapeQK = cute::Shape<cute::Int<@TILED_Q@>, cute::Int<@TILED_KV@>, cute::_32>;
+  using TileShapePV = cute::Shape<cute::Int<@TILED_Q@>, cute::_32, cute::Int<@TILED_KV@>>;
+  using TileShapeOutput = cute::Shape<cute::Int<@TILED_Q@>, cute::Int<@HEAD_DIM@>>;
+  using SubgroupLayoutQK =
+      cute::Layout<cute::Shape<cute::Int<@NUM_SG@>, cute::_1, cute::_1>, cute::Stride<cute::_1, cute::_1, cute::_1>>;
+
+  bool use_sink = (params.softmax_sink_ptr != nullptr);
+  AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, {
+    AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, {
+      AT_DISPATCH_BOOL_NO_RETURN(params.is_causal, Causal, {
+        FMHAConfig::run(params);
+      });
+    });
+  });
+}
+
+template struct FmhaPrefillRunner<@HEAD_DIM@>;
+
+} // namespace prefill
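In the generated prefill runner, `use_sink` is derived from `softmax_sink_ptr`. Assuming the sink is a per-head logit that joins the softmax normalization without contributing a value (the usual attention-sink formulation), its effect on the attention weights is sketched below; `softmax_with_sink` is a hypothetical helper for illustration, not part of the kernel.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Softmax over attention logits with one extra "sink" logit that enters the
// denominator but emits no value: the returned weights sum to less than 1.
std::vector<float> softmax_with_sink(const std::vector<float>& logits, float sink_logit) {
  float m = sink_logit;
  for (float x : logits) m = std::fmax(m, x);   // running max for numerical stability
  float denom = std::exp(sink_logit - m);       // sink joins the normalization...
  for (float x : logits) denom += std::exp(x - m);
  std::vector<float> w;
  for (float x : logits) w.push_back(std::exp(x - m) / denom);  // ...but has no value row
  return w;
}

int main() {
  float sum = 0.f;
  for (float x : softmax_with_sink({1.0f, 2.0f, 3.0f}, 0.5f)) {
    std::printf("%.4f ", x);
    sum += x;
  }
  std::printf("\nsum = %.4f (< 1 because of the sink)\n", sum);
}
```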
diff --git a/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in
new file mode 100644
index 00000000..e8ccdd14
--- /dev/null
+++ b/src/sycl/xe_fmha_fwd_split_decode_kernel.cpp.in
@@ -0,0 +1,57 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
+ * Copyright (C) 2025 Intel Corporation, All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+// Auto-generated from xe_fmha_fwd_split_decode_kernel.cpp.in
+// Template parameters: QG_SZ=@QG_SZ@, HEAD_DIM=@HEAD_DIM@, PAGE_SIZE=@PAGE_SIZE@
+#define SYCL_INTEL_TARGET 20
+
+#include "sycl/kernels/flash_attention_v2/xe_fmha_fwd_decode_runner.hpp"
+
+namespace decode {
+
+template <>
+void FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>::operator()(const Arguments& params) const {
+  using TileShapeQK = cute::Shape<cute::Int<@QG_SZ@>, cute::Int<@PAGE_SIZE@>, cute::_64>;
+  using TileShapePV = cute::Shape<cute::Int<@QG_SZ@>, cute::_32, cute::Int<@PAGE_SIZE@>>;
+  using TileShapeOutput = cute::Shape<cute::Int<@QG_SZ@>, cute::Int<@HEAD_DIM@>>;
+  using SubgroupLayoutQK = cute::Layout<cute::Shape<cute::Int<@QG_SZ@>, cute::_1>>;
+
+  AT_DISPATCH_BOOL_NO_RETURN(params.use_sink, Sink, {
+    AT_DISPATCH_BOOL_NO_RETURN(params.is_local, LocalMask, {
+      SplitDecodeConfig::run(params);
+    });
+  });
+}
+
+template struct FmhaSplitDecodeRunner<@QG_SZ@, @HEAD_DIM@, @PAGE_SIZE@>;
+
+} // namespace decode
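`FmhaSplitDecodeRunner` backs the split-KV decode path (see the `num_kv_splits` schema rename below): each split attends over a slice of the KV cache and reports a partial output plus its log-sum-exp (LSE), and a combine step rescales the partials into the exact full-softmax result. A simplified sketch of that combine step, assuming one head and scalar per-split outputs; `Partial` and `combine_splits` are illustrative names, not the kernel's API.

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

struct Partial {
  float out;  // softmax-weighted value sum over this split's KV slice
  float lse;  // log(sum(exp(logits))) over the same slice
};

// Merge per-split partials: out = sum_i exp(lse_i) * out_i / sum_i exp(lse_i),
// computed relative to the max LSE for numerical stability.
float combine_splits(const std::vector<Partial>& parts) {
  float max_lse = -INFINITY;
  for (const auto& p : parts) max_lse = std::fmax(max_lse, p.lse);
  float num = 0.f, denom = 0.f;
  for (const auto& p : parts) {
    float scale = std::exp(p.lse - max_lse);  // relative weight of this split
    num += scale * p.out;
    denom += scale;
  }
  return num / denom;
}

int main() {
  // Two splits whose combination equals softmax attention over the union of slices.
  std::printf("combined = %f\n", combine_splits({{1.0f, 0.3f}, {2.0f, 1.1f}}));
}
```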
diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc
index 853b380c..e6c8b2ee 100644
--- a/src/torch_extension_sycl.cc
+++ b/src/torch_extension_sycl.cc
@@ -94,7 +94,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    * From cutlass attention
    */
   m.def(
-      "fwd(Tensor! q,"
+      "fwd(Tensor q,"
       " Tensor k,"
       " Tensor v,"
       " Tensor? q_v,"
@@ -119,7 +119,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       " float softcap,"
      " bool is_rotary_interleaved,"
       " Tensor? scheduler_metadata,"
-      " int num_splits,"
+      " int num_kv_splits,"
       " bool? pack_gqa,"
       " int sm_margin) -> Tensor[]");
   m.impl("fwd", torch::kXPU, make_pytorch_shim(&mha_fwd));
diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py
index 22e7eb80..ace0150d 100644
--- a/tests/test_flash_attention.py
+++ b/tests/test_flash_attention.py
@@ -498,7 +498,7 @@ def generate_qkv(
 @pytest.mark.parametrize("has_leftpad", [False])
 @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("varlen_q", [True])
-@pytest.mark.parametrize("d", [512])
+@pytest.mark.parametrize("d", [64, 128, 256, 512])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -985,7 +985,7 @@ def test_flash_attn_kvcache(
 @pytest.mark.parametrize("mha_type", ["mha"])
 @pytest.mark.parametrize("new_kv", [False])
 @pytest.mark.parametrize("causal", [False])
-@pytest.mark.parametrize("local", [False])
+@pytest.mark.parametrize("local", [True, False])
 @pytest.mark.parametrize("use_sinks", [False])
 @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True])
 @pytest.mark.parametrize("has_rotary_seqlens", [False])
@@ -1006,21 +1006,18 @@
 @pytest.mark.parametrize("varlen_q", [True])
 @pytest.mark.parametrize("d", [64, 128, 256, 512])
 @pytest.mark.parametrize("seqlen_q", [1])
+@pytest.mark.parametrize("batch_size", [1, 4])
 @pytest.mark.parametrize(
     "seqlen_k",
     [
         128,
-        256,
-        339,
         1024,
-        800,
-        256,
-        799,
-        2048,
-        20000,
+        4096,
+        8192,
     ],
 )
 def test_flash_attn_decode_kvcache(
+    batch_size,
     seqlen_q,
     seqlen_k,
     d,
@@ -1051,10 +1048,9 @@ def test_flash_attn_decode_kvcache(
         pytest.skip()
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 16
     batch_size_cache = batch_size if not has_batch_idx else batch_size * 2
     nheads = 16
-    nheads_k = 4  # nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
+    nheads_k = 16  # nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3)
     assert nheads % nheads_k == 0
     if seqlen_k <= seqlen_q:
@@ -1191,7 +1187,7 @@ def test_flash_attn_decode_kvcache(
     cache_seqlens = torch.randint(
         seqlen_q,  # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough
-        seqlen_k,
+        seqlen_k + 1,
         (batch_size,),
         dtype=torch.int32,
         device=device,
@@ -1471,6 +1467,7 @@ def test_flash_attn_decode_kvcache(
     assert (out - out_ref).abs().mean().item() <= mult_mean * (
         out_pt - out_ref
     ).abs().mean().item()
+    torch.xpu.empty_cache()
 
 
 def _generate_block_kvcache(