Commit babc76f
committed
[ExecuTorch][WebGPU] Register-tile the q4gsw quantized-linear kernel
Pull Request resolved: #20456
**Register-tile the `et_vk.linear_q4gsw` GEMM — up to 3.4x faster prefill (M4 Pro, M=128).**
**Problem:** `et_vk.linear_q4gsw` (4-bit weight-only, W4A16) computes `out[m,n] = bias[n] + sum_k input[m,k] * (nibble(weight,n,k)-8) * scale[k/group_size, n]` in a single dispatch over a raw `[N, K/2]` 4-bit weight (2 nibbles/byte, +8-shifted symmetric, groupwise scales). The shipped kernel was naive: one workgroup per output row `m`, threads striding `N`, a scalar K-loop per `(m,n)`. For an M-row (prefill) input it re-extracts every dequantized weight `M` times (once per row) and re-reads each input value once per output column — redundant memory traffic that dominates the prefill GEMM.
**Solution:** a register-tiled GEMM where each thread owns a `TM x TN = 4x4` output tile, so both weights and inputs are loaded once per tile instead of once per element.
- Before: weight `(n,k)` dequantized once per `(m,n)` (extracted `M`x for prefill); `input[m,k]` re-read once per output column `n`.
- After: weight `(n,k)` dequantized ONCE and reused across the `TM` rows of the tile (weight reads drop ~`TM`x); each `input[m,k]` loaded once per `k` into a register and reused across the `TN` columns (input reads drop ~`TN`x).
**Implementation:**
- New loop nest in `q4gsw_linear.wgsl`: per `k`, hoist the `TM` input values into registers, then for each of the `TN` columns dequantize the weight once and accumulate into the `4x4` register tile.
- Host dispatch changes from `M` workgroups to `ceil(M/TM)*ceil(N/TN)` tiles over `wg_size` threads; `wg_size` is computed before the count so the dispatch is still validated against device limits before any allocation.
- Tile-edge lanes (`n0+nl >= N` or `m0+ml >= M`) clamp their weight/scale/input index to the last valid element (the never-stored overhang is harmless), since WGSL out-of-bounds reads are implementation-defined. Mirrors the Vulkan tiled GEMM `q4gsw_linear_gemm__w_4x8.glsl`'s `min(..., N-1)` clamp.
- Deliberate deviations from the Vulkan kernel (recorded in DESIGN_DECISIONS): a `4x4` tile (vs Vulkan `4M x 8N`) for a conservative register budget; the RAW `[N,K/2]` layout with scalar nibble unpack and NO `W_4X8` prepack / NO wide `vec4<u32>` loads (prior on-device measurement found wide loads regress on this GPU); a 1D-flattened tile index (the backend is 1D-dispatch only).
**Constraints:** bindings, `Params`, the weight layout, and the single-dispatch structure are unchanged; the dequant index math is copied verbatim from the naive kernel, so the result is a floating-point accumulation reorder equal to the naive output to fp-rounding. The `M=1` decode GEMV path and host M-based routing are a separate follow-up.
Authored with assistance from Claude Code.
ghstack-source-id: 396677641
@exported-using-ghexport
Differential Revision: [D109250327](https://our.internmc.facebook.com/intern/diff/D109250327/)1 parent 68bb668 commit babc76f
4 files changed
Lines changed: 127 additions & 57 deletions
File tree
- backends/webgpu/runtime
- ops/quantized_linear
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
21 | 27 | | |
22 | 28 | | |
23 | 29 | | |
| |||
34 | 40 | | |
35 | 41 | | |
36 | 42 | | |
37 | | - | |
| 43 | + | |
38 | 44 | | |
39 | 45 | | |
40 | 46 | | |
| |||
Lines changed: 15 additions & 5 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
37 | 41 | | |
38 | 42 | | |
39 | 43 | | |
| |||
85 | 89 | | |
86 | 90 | | |
87 | 91 | | |
88 | | - | |
89 | | - | |
90 | | - | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
91 | 103 | | |
92 | 104 | | |
93 | 105 | | |
| |||
186 | 198 | | |
187 | 199 | | |
188 | 200 | | |
189 | | - | |
190 | | - | |
191 | 201 | | |
192 | 202 | | |
193 | 203 | | |
| |||
Lines changed: 52 additions & 25 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
21 | | - | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
22 | 26 | | |
23 | | - | |
24 | | - | |
25 | | - | |
26 | | - | |
27 | | - | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
28 | 34 | | |
29 | 35 | | |
30 | | - | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
31 | 45 | | |
32 | | - | |
| 46 | + | |
33 | 47 | | |
34 | | - | |
| 48 | + | |
35 | 49 | | |
36 | 50 | | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
45 | 61 | | |
46 | 62 | | |
47 | 63 | | |
48 | 64 | | |
49 | | - | |
| 65 | + | |
50 | 66 | | |
51 | 67 | | |
52 | 68 | | |
53 | 69 | | |
54 | | - | |
55 | | - | |
56 | | - | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
57 | 74 | | |
58 | | - | |
59 | | - | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
60 | 89 | | |
61 | | - | |
62 | | - | |
63 | 90 | | |
64 | 91 | | |
Lines changed: 53 additions & 26 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
13 | 13 | | |
14 | 14 | | |
15 | 15 | | |
16 | | - | |
| 16 | + | |
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
38 | | - | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
39 | 43 | | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
45 | 51 | | |
46 | 52 | | |
47 | | - | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
48 | 62 | | |
49 | | - | |
| 63 | + | |
50 | 64 | | |
51 | | - | |
| 65 | + | |
52 | 66 | | |
53 | 67 | | |
54 | | - | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
62 | 78 | | |
63 | 79 | | |
64 | 80 | | |
65 | 81 | | |
66 | | - | |
| 82 | + | |
67 | 83 | | |
68 | 84 | | |
69 | 85 | | |
70 | 86 | | |
71 | | - | |
72 | | - | |
73 | | - | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
74 | 91 | | |
75 | | - | |
76 | | - | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
77 | 106 | | |
78 | | - | |
79 | | - | |
80 | 107 | | |
81 | 108 | | |
82 | 109 | | |
| |||
0 commit comments