Skip to content

Commit 5b6b1ae

Browse files
[ExecuTorch][WebGPU] q4gsw: route M==1 decode to a cooperative GEMV
Pull Request resolved: #20457 Add optimized GEMV kernel for M==1 decode path in q4gsw quantized-linear. **Problem**: The register-tiled GEMM (from D109250327) wastes 75% of each 4×N tile when M=1, as only 1 of 4 rows is used. **Solution**: Add a cooperative GEMV kernel that routes M==1 decode to a more efficient path: - **GEMV**: 64 lanes per workgroup cooperate over K-dimension, each lane loads u32 words (8 K-values), reduces via shared memory - **GEMM**: M>1 prefill continues using the tiled GEMM **Routing Logic** (build-time selection, M is static per graph): - Use GEMV when: M==1 && K%8==0 && group_size%8==0 - Otherwise: Fall back to tiled GEMM **Constraints**: - K%8==0: Kernel loads 8 K-values per u32 word - group_size%8==0: Ensures no quantization-group boundary splits a word (validated via CPU cross-check) - Llama models (group_size=32/64) satisfy both constraints **Implementation**: - New kernel: q4gsw_linear_coop4.wgsl (fixed 64-lane workgroup) - New utility: clamp_workgroup_count() for grid-stride dispatch (vs compute_1d_workgroup_count which throws) - Shared infrastructure: Same bind layout, Params, weight format **Performance**: Keeps decode at measured bandwidth plateau, avoids M=1 tile waste. GEMV uses different reduction order (agrees to fp-rounding, not bit-exact). ghstack-source-id: 396619622 @exported-using-ghexport Differential Revision: [D109250570](https://our.internmc.facebook.com/intern/diff/D109250570/)
1 parent 6b1271d commit 5b6b1ae

4 files changed

Lines changed: 244 additions & 15 deletions

File tree

backends/webgpu/runtime/WebGPUUtils.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,16 @@ make_uniform(WGPUDevice device, const void* data, size_t size) {
7676
return buf;
7777
}
7878

79+
// Clamp a 1D workgroup count to the device limit, for grid-stride kernels that
80+
// loop over any excess work (vs compute_1d_workgroup_count, which throws).
81+
inline uint32_t clamp_workgroup_count(WGPUDevice device, uint32_t desired) {
82+
WGPULimits limits = {};
83+
uint32_t max_count =
84+
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
85+
limits.maxComputeWorkgroupsPerDimension > 0
86+
? limits.maxComputeWorkgroupsPerDimension
87+
: 65535u; // WebGPU spec-default floor
88+
return std::min(desired, max_count);
89+
}
90+
7991
} // namespace executorch::backends::webgpu::utils

backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
1010
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
1111
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
12+
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h>
1213
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>
1314

1415
#include <webgpu/webgpu.h>
@@ -89,18 +90,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
8990
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
9091
}
9192

92-
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
93-
const uint32_t wg_size =
94-
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
95-
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
96-
utils::div_up<int64_t>(N, kQ4gswTileN);
97-
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
98-
throw std::runtime_error(
99-
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
100-
}
101-
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
102-
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
103-
10493
// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
10594
const uint64_t scales_numel =
10695
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
@@ -128,6 +117,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
128117
"WebGPU linear_q4gsw: scales dims too small for K/N");
129118
}
130119

120+
// M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM.
121+
const uint32_t wg_size =
122+
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
123+
const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u);
124+
const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL;
125+
uint32_t workgroup_count;
126+
if (use_gemv) {
127+
// coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N.
128+
const uint64_t outputs =
129+
static_cast<uint64_t>(M) * static_cast<uint64_t>(N);
130+
if (outputs == 0u || outputs > UINT32_MAX) {
131+
throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range");
132+
}
133+
workgroup_count =
134+
utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs));
135+
if (workgroup_count == 0u) {
136+
throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch");
137+
}
138+
} else {
139+
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
140+
utils::div_up<int64_t>(N, kQ4gswTileN);
141+
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
142+
throw std::runtime_error(
143+
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
144+
}
145+
workgroup_count = utils::compute_1d_workgroup_count(
146+
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
147+
}
148+
131149
// Optional bias: real buffer if present, else a dummy for the fixed layout.
132150
uint32_t has_bias = 0;
133151
WGPUBuffer bias_buffer = nullptr;
@@ -168,7 +186,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
168186

169187
WGPUShaderSourceWGSL wgsl_desc = {};
170188
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
171-
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
189+
wgsl_desc.code = {shader_src, WGPU_STRLEN};
172190
WGPUShaderModuleDescriptor shader_desc = {};
173191
shader_desc.nextInChain = &wgsl_desc.chain;
174192
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
@@ -206,8 +224,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
206224
pipeline_desc.layout = pipeline_layout;
207225
pipeline_desc.compute.module = shader;
208226
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
209-
pipeline_desc.compute.constantCount = 1;
210-
pipeline_desc.compute.constants = &wg_size_constant;
227+
// coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override.
228+
pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u;
229+
pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant;
211230
WGPUComputePipeline pipeline =
212231
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
213232

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
2+
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
3+
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
4+
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
5+
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;
6+
7+
struct Params {
8+
M: u32,
9+
N: u32,
10+
K: u32,
11+
K_packed: u32,
12+
group_size: u32,
13+
padded_N: u32,
14+
has_bias: u32,
15+
_pad: u32,
16+
}
17+
@group(0) @binding(5) var<uniform> params: Params;
18+
19+
// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes).
20+
const WG: u32 = 64u;
21+
var<workgroup> partial: array<f32, 64>;
22+
23+
@compute @workgroup_size(64, 1, 1)
24+
fn main(
25+
@builtin(workgroup_id) wid: vec3<u32>,
26+
@builtin(num_workgroups) ngrp: vec3<u32>,
27+
@builtin(local_invocation_id) lid: vec3<u32>) {
28+
let total = params.M * params.N;
29+
let stride = ngrp.x;
30+
let num_words = params.K >> 3u; // K / 8 words per row
31+
let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8)
32+
var idx = wid.x;
33+
loop {
34+
if (idx >= total) {
35+
break;
36+
}
37+
let m = idx / params.N;
38+
let n = idx % params.N;
39+
let in_base = m * params.K;
40+
let wbase = n * row_words;
41+
42+
var acc: f32 = 0.0;
43+
var w: u32 = lid.x;
44+
loop {
45+
if (w >= num_words) {
46+
break;
47+
}
48+
let word = t_weight[wbase + w];
49+
let k0 = w << 3u; // first K of this word
50+
let scale = t_scales[(k0 / params.group_size) * params.padded_N + n];
51+
let ib = in_base + k0;
52+
// 4 bytes, low+high nibble each -> 8 consecutive K.
53+
for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) {
54+
let byte = (word >> (bi * 8u)) & 0xFFu;
55+
let lo = f32(i32(byte & 0x0Fu) - 8);
56+
let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8);
57+
let kk = bi << 1u;
58+
acc = acc + t_input[ib + kk] * lo * scale;
59+
acc = acc + t_input[ib + kk + 1u] * hi * scale;
60+
}
61+
w = w + WG;
62+
}
63+
64+
partial[lid.x] = acc;
65+
workgroupBarrier();
66+
var s: u32 = WG >> 1u;
67+
loop {
68+
if (s == 0u) {
69+
break;
70+
}
71+
if (lid.x < s) {
72+
partial[lid.x] = partial[lid.x] + partial[lid.x + s];
73+
}
74+
workgroupBarrier();
75+
s = s >> 1u;
76+
}
77+
if (lid.x == 0u) {
78+
var o = partial[0];
79+
if (params.has_bias != 0u) {
80+
o = o + t_bias[n];
81+
}
82+
t_out[idx] = o;
83+
}
84+
workgroupBarrier();
85+
idx = idx + stride;
86+
}
87+
}
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from q4gsw_linear_coop4.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 6e296f0583118d1ff0df914dd3ac078e7f4e526d99be7d233531a47fddb93f89
17+
inline constexpr const char* kQ4gswLinearCoop4WGSL = R"(
18+
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
19+
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
20+
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
21+
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
22+
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;
23+
24+
struct Params {
25+
M: u32,
26+
N: u32,
27+
K: u32,
28+
K_packed: u32,
29+
group_size: u32,
30+
padded_N: u32,
31+
has_bias: u32,
32+
_pad: u32,
33+
}
34+
@group(0) @binding(5) var<uniform> params: Params;
35+
36+
// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes).
37+
const WG: u32 = 64u;
38+
var<workgroup> partial: array<f32, 64>;
39+
40+
@compute @workgroup_size(64, 1, 1)
41+
fn main(
42+
@builtin(workgroup_id) wid: vec3<u32>,
43+
@builtin(num_workgroups) ngrp: vec3<u32>,
44+
@builtin(local_invocation_id) lid: vec3<u32>) {
45+
let total = params.M * params.N;
46+
let stride = ngrp.x;
47+
let num_words = params.K >> 3u; // K / 8 words per row
48+
let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8)
49+
var idx = wid.x;
50+
loop {
51+
if (idx >= total) {
52+
break;
53+
}
54+
let m = idx / params.N;
55+
let n = idx % params.N;
56+
let in_base = m * params.K;
57+
let wbase = n * row_words;
58+
59+
var acc: f32 = 0.0;
60+
var w: u32 = lid.x;
61+
loop {
62+
if (w >= num_words) {
63+
break;
64+
}
65+
let word = t_weight[wbase + w];
66+
let k0 = w << 3u; // first K of this word
67+
let scale = t_scales[(k0 / params.group_size) * params.padded_N + n];
68+
let ib = in_base + k0;
69+
// 4 bytes, low+high nibble each -> 8 consecutive K.
70+
for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) {
71+
let byte = (word >> (bi * 8u)) & 0xFFu;
72+
let lo = f32(i32(byte & 0x0Fu) - 8);
73+
let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8);
74+
let kk = bi << 1u;
75+
acc = acc + t_input[ib + kk] * lo * scale;
76+
acc = acc + t_input[ib + kk + 1u] * hi * scale;
77+
}
78+
w = w + WG;
79+
}
80+
81+
partial[lid.x] = acc;
82+
workgroupBarrier();
83+
var s: u32 = WG >> 1u;
84+
loop {
85+
if (s == 0u) {
86+
break;
87+
}
88+
if (lid.x < s) {
89+
partial[lid.x] = partial[lid.x] + partial[lid.x + s];
90+
}
91+
workgroupBarrier();
92+
s = s >> 1u;
93+
}
94+
if (lid.x == 0u) {
95+
var o = partial[0];
96+
if (params.has_bias != 0u) {
97+
o = o + t_bias[n];
98+
}
99+
t_out[idx] = o;
100+
}
101+
workgroupBarrier();
102+
idx = idx + stride;
103+
}
104+
}
105+
)";
106+
107+
inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeX = 64;
108+
inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeY = 1;
109+
inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeZ = 1;
110+
111+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)