[ExecuTorch][WebGPU] Register-tile the q4gsw quantized-linear kernel #20456
JulianCloudNTH wants to merge 3 commits into
@claude review if there are any bugs
Claude finished @JulianCloudNTH's task in 2m 10s

Code Review: Register-tiled q4gsw kernel

I reviewed the shader (

Verdict: No correctness bugs found ✅

The transformation is sound. Specifically I verified:
Performance notes (not bugs)
Minor
Both perf items are optional and the second is explicitly deferred per the PR description, so nothing blocks this change from a correctness standpoint.
Stack from ghstack (oldest at bottom):
Register-tile the `et_vk.linear_q4gsw` GEMM: up to 3.4x faster prefill (M4 Pro, M=128).

Problem: `et_vk.linear_q4gsw` (4-bit weight-only, W4A16) computes `out[m,n] = bias[n] + sum_k input[m,k] * (nibble(weight,n,k)-8) * scale[k/group_size, n]` in a single dispatch over a raw `[N, K/2]` 4-bit weight (2 nibbles/byte, +8-shifted symmetric, groupwise scales). The shipped kernel was naive: one workgroup per output row `m`, threads striding `N`, and a scalar K-loop per `(m,n)`. For an M-row (prefill) input it re-extracts every dequantized weight `M` times (once per row) and re-reads each input value once per output column. That redundant memory traffic dominates the prefill GEMM.

Solution: a register-tiled GEMM where each thread owns a `TM x TN = 4x4` output tile, so both weights and inputs are loaded once per tile instead of once per element.

- Naive kernel: each weight `(n,k)` is dequantized once per `(m,n)` (extracted `M`x for prefill); `input[m,k]` is re-read once per output column `n`.
- Tiled kernel: each weight `(n,k)` is dequantized ONCE and reused across the `TM` rows of the tile (weight reads drop ~`TM`x); each `input[m,k]` is loaded once per `k` into a register and reused across the `TN` columns (input reads drop ~`TN`x).

Implementation:

- `q4gsw_linear.wgsl`: per `k`, hoist the `TM` input values into registers, then for each of the `TN` columns dequantize the weight once and accumulate into the `4x4` register tile.
- Dispatch: changes from `M` workgroups to `ceil(M/TM)*ceil(N/TN)` tiles over `wg_size` threads; `wg_size` is computed before the count so the dispatch is still validated against device limits before any allocation.
- Edge tiles: out-of-range lanes (`n0+nl >= N` or `m0+ml >= M`) clamp their weight/scale/input index to the last valid element (the never-stored overhang is harmless), since WGSL out-of-bounds reads are implementation-defined. This mirrors the `min(..., N-1)` clamp in the Vulkan tiled GEMM `q4gsw_linear_gemm__w_4x8.glsl`.
- Design choices: a `4x4` tile (vs Vulkan `4M x 8N`) for a conservative register budget; the RAW `[N,K/2]` layout with scalar nibble unpack and NO `W_4X8` prepack / NO wide `vec4<u32>` loads (prior on-device measurement found wide loads regress on this GPU); a 1D-flattened tile index (the backend is 1D-dispatch only).

Constraints: bindings, `Params`, the weight layout, and the single-dispatch structure are unchanged. The dequant index math is copied verbatim from the naive kernel, so the result differs from the naive output only by floating-point accumulation order, i.e. it matches up to fp rounding. The `M=1` decode GEMV path and host M-based routing are a separate follow-up.

Authored with assistance from Claude Code.
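For readers who want to check the tiling argument without a GPU, the description above can be sketched as a CPU emulation in plain Python. This is an illustrative sketch, not the WGSL shader from the PR: the function names (`nibble`, `dequant`, `linear_naive`, `linear_tiled`) are hypothetical, the nibble order (even `k` in the low nibble) is an assumption the PR text does not state, and the row-major flattening of the tile index is likewise assumed. It follows the formula, the `4x4` tile, the clamp on overhanging lanes, and the "never store the overhang" rule as described.

```python
import random

TM, TN = 4, 4  # per-thread output tile size, as stated in the PR description


def nibble(weight, n, k):
    """Unpack the 4-bit value for (n, k) from a raw [N, K/2] byte matrix.

    Assumption: even k is the low nibble, odd k the high nibble.
    """
    byte = weight[n][k // 2]
    return (byte >> 4) & 0xF if k % 2 else byte & 0xF


def dequant(weight, scale, group_size, n, k):
    # (nibble - 8) * scale[k / group_size, n], per the formula in the PR text
    return (nibble(weight, n, k) - 8) * scale[k // group_size][n]


def linear_naive(inp, weight, scale, bias, group_size):
    """Scalar K-loop per (m, n): every weight is dequantized M times."""
    M, K, N = len(inp), len(inp[0]), len(weight)
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                acc += inp[m][k] * dequant(weight, scale, group_size, n, k)
            out[m][n] = bias[n] + acc
    return out


def linear_tiled(inp, weight, scale, bias, group_size):
    """Each loop iteration over `tile` plays one thread owning a TM x TN tile.

    Returns (out, number_of_dequant_calls).
    """
    M, K, N = len(inp), len(inp[0]), len(weight)
    tiles_m = -(-M // TM)  # ceil(M / TM)
    tiles_n = -(-N // TN)  # ceil(N / TN)
    out = [[0.0] * N for _ in range(M)]
    dequants = 0
    for tile in range(tiles_m * tiles_n):  # 1D-flattened tile index (assumed row-major)
        m0 = (tile // tiles_n) * TM
        n0 = (tile % tiles_n) * TN
        acc = [[0.0] * TN for _ in range(TM)]
        for k in range(K):
            # Hoist the TM inputs for this k; clamp rows that overhang M.
            a = [inp[min(m0 + ml, M - 1)][k] for ml in range(TM)]
            for nl in range(TN):
                n = min(n0 + nl, N - 1)  # clamp columns that overhang N
                w = dequant(weight, scale, group_size, n, k)  # once per (n, k) per tile
                dequants += 1
                for ml in range(TM):
                    acc[ml][nl] += a[ml] * w
        for ml in range(TM):
            for nl in range(TN):
                if m0 + ml < M and n0 + nl < N:  # overhang lanes are never stored
                    out[m0 + ml][n0 + nl] = bias[n0 + nl] + acc[ml][nl]
    return out, dequants


if __name__ == "__main__":
    random.seed(0)
    M, K, N, G = 6, 16, 10, 8  # deliberately not multiples of the tile size
    inp = [[random.uniform(-1, 1) for _ in range(K)] for _ in range(M)]
    weight = [[random.randrange(256) for _ in range(K // 2)] for _ in range(N)]
    scale = [[random.uniform(0.01, 0.1) for _ in range(N)] for _ in range(K // G)]
    bias = [random.uniform(-1, 1) for _ in range(N)]
    ref = linear_naive(inp, weight, scale, bias, G)
    got, dequants = linear_tiled(inp, weight, scale, bias, G)
    err = max(abs(ref[m][n] - got[m][n]) for m in range(M) for n in range(N))
    print("max abs error:", err, "| tiled dequants:", dequants, "| naive:", M * N * K)
```

With the shapes in the demo (`M=6`, `N=10`, `K=16`), the tiled path performs `2 * 3 * 16 * 4 = 384` dequantizations against `6 * 10 * 16 = 960` for the naive path. The saving is below the ideal `TM`x here because both dimensions overhang the tile grid, which is the case the clamp exists to handle.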
@exported-using-ghexport
Differential Revision: D109250327