fa2 xe2: hoist Q/KV dtype dispatch out of policy_dispatch_impl #349
Draft
jasonboukheir wants to merge 1 commit into
Each generated chunk_prefill / paged_decode TU previously instantiated exactly one (policy, bool...) specialisation of *_dispatch_impl, but the function body itself ran a runtime if-tree over six (Q dtype, KV dtype) combinations and instantiated a full FMHAConfig / PagedDecodeConfig kernel pipeline in every branch. icpx had to hold all six pipelines in the frontend AST + middle-end IR + SYCL device-image bundler within a single process; on the heavy q16_h256_p128 paged_decode TU that spiked to ~40 GB RSS and OOM'd 24-core builders even at modest --max-jobs.

Move Q and KV dtypes to template parameters of *_dispatch_impl, drop the unused CutlassQKType& argument from the impl signature, and surface the runtime dispatch in the *_dispatch_func helpers in chunk_prefill_utils.hpp / paged_decode_utils.hpp. Those helpers now trampoline into one of six extern-declared specialisations per (policy, bool...) tuple. The cmake generators emit one .cpp per (policy, bool..., Q, KV); the X-macro chains in *_extern.hpp fan 6-wide on the leaf so all specialisations are declared as extern templates.

Effect: ~6x more TUs, but each compiles a single SYCL kernel pipeline. Per-TU peak RSS drops to ~7 GB on the worst-case q16_h256_p128 TU, which is below the 12-worker budget on a 24-core 96-GB box without serialising the heavy tail. The accompanying PCH change amortises the duplicated header parse cost once per (cmake configure) instead of once per TU.

For non-Nix consumers (Bazel, distributed CI, sccache shards) the change is byte-equivalent at the .so level; only the kernel TU graph is wider and shallower.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
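The dispatch shape described in the commit message can be sketched roughly as follows. This is a minimal, single-file illustration and not the code from this PR: the dtype enum, the tag structs, `ExamplePolicy`, the particular six (Q, KV) pairs, and the string return value (standing in for a kernel launch) are all hypothetical. The point it tries to show is that the runtime branch lives only in the `dispatch_func` helper, while each `dispatch_impl` instantiation covers exactly one (Q, KV) pair.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>

// Hypothetical runtime dtype enum and compile-time tags (illustrative only).
enum class DType { F16, BF16, FP8_E4M3, FP8_E5M2 };

struct F16Tag     { static const char* name() { return "f16"; } };
struct BF16Tag    { static const char* name() { return "bf16"; } };
struct FP8E4M3Tag { static const char* name() { return "fp8_e4m3"; } };
struct FP8E5M2Tag { static const char* name() { return "fp8_e5m2"; } };

// Hypothetical policy tag standing in for one generated policy.
struct ExamplePolicy { static const char* name() { return "example_policy"; } };

// Q and KV are template parameters, so one instantiation corresponds to one
// kernel pipeline. The body returns a label in place of launching a kernel.
template <class Policy, bool Causal, class Q, class KV>
std::string dispatch_impl() {
  return std::string(Policy::name()) + (Causal ? ":causal:" : ":full:") +
         Q::name() + "/" + KV::name();
}

// Runtime trampoline: the only place that branches on the runtime dtypes.
// The six pairs below are an assumed matrix, not the one used by the PR.
template <class Policy, bool Causal>
std::string dispatch_func(DType q, DType kv) {
  if (q == DType::F16 && kv == DType::F16)
    return dispatch_impl<Policy, Causal, F16Tag, F16Tag>();
  if (q == DType::F16 && kv == DType::FP8_E4M3)
    return dispatch_impl<Policy, Causal, F16Tag, FP8E4M3Tag>();
  if (q == DType::F16 && kv == DType::FP8_E5M2)
    return dispatch_impl<Policy, Causal, F16Tag, FP8E5M2Tag>();
  if (q == DType::BF16 && kv == DType::BF16)
    return dispatch_impl<Policy, Causal, BF16Tag, BF16Tag>();
  if (q == DType::BF16 && kv == DType::FP8_E4M3)
    return dispatch_impl<Policy, Causal, BF16Tag, FP8E4M3Tag>();
  if (q == DType::BF16 && kv == DType::FP8_E5M2)
    return dispatch_impl<Policy, Causal, BF16Tag, FP8E5M2Tag>();
  throw std::invalid_argument("unsupported (Q, KV) dtype combination");
}

// Small self-check showing how a caller would use the trampoline.
inline void demo_dispatch() {
  assert((dispatch_func<ExamplePolicy, true>(DType::BF16, DType::FP8_E5M2) ==
          "example_policy:causal:bf16/fp8_e5m2"));
}
```

In a multi-TU build, the six `dispatch_impl` specialisations would be declared `extern template` where the helper is visible and defined in separate generated sources, so no single compiler process has to instantiate more than one pipeline.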
Each generated chunk_prefill / paged_decode TU previously instantiated exactly one (policy, bool...) specialisation of *_dispatch_impl, but the function body itself ran a runtime if/else tree over six (Q dtype, KV dtype) combinations and instantiated a full FMHAConfig / PagedDecodeConfig kernel pipeline in every branch. icpx had to hold all six pipelines in the frontend AST + middle-end IR + SYCL device-image bundler within a single process; on the heavy q16_h256_p128 paged_decode TU that spiked to ~40 GB RSS and OOM'd 24-core builders even at modest --max-jobs.

Change

Move Q and KV dtypes to template parameters of *_dispatch_impl, drop the unused CutlassQKType& argument from the impl signature, and surface the runtime dispatch in the *_dispatch_func helpers in chunk_prefill_utils.hpp / paged_decode_utils.hpp. Those helpers now trampoline into one of six extern-declared specialisations per (policy, bool...) tuple. The cmake generators emit one .cpp per (policy, bool..., Q, KV); the X-macro chains in *_extern.hpp fan 6-wide on the leaf so all specialisations are declared as extern templates.

Effect

~6× more TUs, but each compiles a single SYCL kernel pipeline. Per-TU peak RSS drops to ~7 GB on the worst-case q16_h256_p128 TU, which is below the 12-worker budget on a 24-core 96-GB box without serialising the heavy tail.

The companion PCH PR (#350) amortises the now-duplicated header parse cost once per (cmake configure) instead of once per TU, which makes the "more TUs" wall-time penalty disappear entirely.

For non-Nix consumers (Bazel, distributed CI, sccache shards) the change is byte-equivalent at the .so level; only the kernel TU graph is wider and shallower.

Draft: willing to iterate on the X-macro fan-out style or extract the dtype matrix into a config header.
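For reviewers weighing the X-macro fan-out style, here is a rough, single-file sketch of how a 6-wide leaf could both declare the extern templates and drive the explicit instantiations. It is an assumption-laden illustration, not the contents of *_extern.hpp: the `fanout_sketch` namespace, the `FA2_SKETCH_*` macro names, the tag structs, `PolicyA`, and the specific six (Q, KV) pairs are all hypothetical, and the instantiations that the generators would place in separate .cpp files are collapsed into one file so the example compiles on its own.

```cpp
#include <cassert>
#include <string>
#include <vector>

namespace fanout_sketch {

// Hypothetical dtype and policy tags (illustrative only).
struct F16     { static const char* name() { return "f16"; } };
struct BF16    { static const char* name() { return "bf16"; } };
struct FP8E4M3 { static const char* name() { return "fp8_e4m3"; } };
struct FP8E5M2 { static const char* name() { return "fp8_e5m2"; } };
struct PolicyA { static const char* name() { return "policy_a"; } };

// "Header" part: the template is declared here and defined further down.
template <class Policy, bool Causal, class Q, class KV>
std::string dispatch_impl();

// Assumed dtype matrix: one X(...) leaf per (Q, KV) pair, six in total.
#define FA2_SKETCH_FOR_EACH_QKV(X, POLICY, CAUSAL) \
  X(POLICY, CAUSAL, F16,  F16)                     \
  X(POLICY, CAUSAL, F16,  FP8E4M3)                 \
  X(POLICY, CAUSAL, F16,  FP8E5M2)                 \
  X(POLICY, CAUSAL, BF16, BF16)                    \
  X(POLICY, CAUSAL, BF16, FP8E4M3)                 \
  X(POLICY, CAUSAL, BF16, FP8E5M2)

// Leaf 1: explicit instantiation declarations (suppress implicit instantiation).
#define FA2_SKETCH_DECLARE_EXTERN(POLICY, CAUSAL, Q, KV) \
  extern template std::string dispatch_impl<POLICY, CAUSAL, Q, KV>();
FA2_SKETCH_FOR_EACH_QKV(FA2_SKETCH_DECLARE_EXTERN, PolicyA, true)

// Leaf 2: count the fan-out at compile time.
#define FA2_SKETCH_COUNT(POLICY, CAUSAL, Q, KV) +1
constexpr int kLeafCount = 0 FA2_SKETCH_FOR_EACH_QKV(FA2_SKETCH_COUNT, PolicyA, true);
static_assert(kLeafCount == 6, "expected a 6-wide (Q, KV) fan-out");

// "Generated source" part: the definition plus one explicit instantiation per
// leaf. In a real build each instantiation would sit in its own .cpp file.
template <class Policy, bool Causal, class Q, class KV>
std::string dispatch_impl() {
  return std::string(Policy::name()) + (Causal ? ":causal:" : ":full:") +
         Q::name() + "/" + KV::name();
}
#define FA2_SKETCH_INSTANTIATE(POLICY, CAUSAL, Q, KV) \
  template std::string dispatch_impl<POLICY, CAUSAL, Q, KV>();
FA2_SKETCH_FOR_EACH_QKV(FA2_SKETCH_INSTANTIATE, PolicyA, true)

// Calls every leaf once, in matrix order, and returns the labels.
inline std::vector<std::string> instantiated_leaves() {
  std::vector<std::string> out;
#define FA2_SKETCH_COLLECT(POLICY, CAUSAL, Q, KV) \
  out.push_back(dispatch_impl<POLICY, CAUSAL, Q, KV>());
  FA2_SKETCH_FOR_EACH_QKV(FA2_SKETCH_COLLECT, PolicyA, true)
#undef FA2_SKETCH_COLLECT
  return out;
}

}  // namespace fanout_sketch
```

Keeping the (Q, KV) list in a single macro means the extern declarations, the instantiations, and any runtime table cannot drift apart; moving that one macro into a config header would be one way to do the "dtype matrix" extraction mentioned above.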