[ExecuTorch][WebGPU] q4gsw: route M==1 decode to a cooperative GEMV #20457
JulianCloudNTH wants to merge 4 commits into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20457
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEVs
There are 1 currently active SEVs. If your PR is affected, please view them below:
⏳ No Failures, 132 Pending
As of commit d266a3d with merge base 68bb668
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
@claude review if there are any bugs
Claude finished @JulianCloudNTH's task in 4m 26s
Bug review: q4gsw coop4 GEMV
Verdict: no functional bugs found. The kernel, routing, and dispatch math all check out under the stated constraints. Details below.

Correctness checks that pass ✅
1. Weight layout matches the GEMM exactly. In
2. The
3.
4. All buffer accesses are in-bounds. weight: max
5. Barriers are in uniform control flow. The grid-stride outer loop depends only on
6. Dispatch + grid-stride.

Minor nits (non-blocking)
The 3 reported CI "new failures" are
Stack from ghstack (oldest at bottom):
Add optimized GEMV kernel for M==1 decode path in q4gsw quantized-linear.
Problem: The register-tiled GEMM (from D109250327) wastes 75% of each 4×N tile when M=1, as only 1 of 4 rows is used.
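The 75% figure follows from the tile shape: with a 4-row tile and M=1, only one of the four rows carries real output. A minimal sketch of that arithmetic, assuming a hypothetical helper `tile_row_waste` that is not part of this PR and a tile height of 4 as stated above:

```python
import math


def tile_row_waste(m: int, tile_rows: int = 4) -> float:
    """Fraction of tile rows that are padding rather than real output rows.

    `tile_rows=4` mirrors the 4xN tile described above. The helper is
    hypothetical and only illustrates the arithmetic.
    """
    if m <= 0 or tile_rows <= 0:
        raise ValueError("m and tile_rows must be positive")
    tiles = math.ceil(m / tile_rows)      # tiles needed to cover all M rows
    rows_computed = tiles * tile_rows     # rows the tiled GEMM actually processes
    return (rows_computed - m) / rows_computed


print(tile_row_waste(1))  # 0.75: the decode case, 3 of 4 rows are padding
print(tile_row_waste(4))  # 0.0: a full tile wastes nothing
```

The waste only vanishes when M is a multiple of the tile height, which is why a separate path for the M==1 case is attractive.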
Solution: Add a cooperative GEMV kernel that routes M==1 decode to a more efficient path:
Routing Logic (build-time selection, M is static per graph):
Constraints:
Implementation:
Performance: Keeps decode at the measured bandwidth plateau and avoids the M=1 tile waste. The GEMV uses a different reduction order than the GEMM, so results agree to floating-point rounding but are not bit-exact.
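To illustrate why a cooperative reduction agrees with a sequential one only up to rounding, here is a minimal CPU-side sketch in Python. It is not the WGSL kernel from this PR: the lane count, the strided split of K, and the names `sequential_dot` and `cooperative_dot` are assumptions chosen for illustration, and it uses Python doubles rather than the kernel's data types.

```python
import math
import random


def sequential_dot(x, w):
    """Accumulate products left to right, as a single thread would."""
    acc = 0.0
    for xi, wi in zip(x, w):
        acc += xi * wi
    return acc


def cooperative_dot(x, w, lanes=4):
    """Each of `lanes` hypothetical workers sums a strided slice of K,
    then the partial sums are combined with a pairwise tree reduction."""
    partial = [0.0] * lanes
    for lane in range(lanes):
        for k in range(lane, len(x), lanes):
            partial[lane] += x[k] * w[k]
    stride = 1
    while stride < lanes:  # tree reduction over the per-lane partial sums
        for i in range(0, lanes - stride, 2 * stride):
            partial[i] += partial[i + stride]
        stride *= 2
    return partial[0]


rng = random.Random(0)
x = [rng.uniform(-1, 1) for _ in range(1024)]
w = [rng.uniform(-1, 1) for _ in range(1024)]
a, b = sequential_dot(x, w), cooperative_dot(x, w)
# The two orders sum the same products, so they agree to rounding,
# but the low-order bits are not guaranteed to match.
print(math.isclose(a, b, rel_tol=1e-9, abs_tol=1e-9))
```

Because floating-point addition is not associative, a test comparing the two paths would reasonably use a tolerance rather than exact equality.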
@exported-using-ghexport
Differential Revision: D109250570