Skip to content

Commit d6ce898

Browse files
[ExecuTorch][WebGPU] Register-tile the q4gsw quantized-linear kernel
Differential Revision: D109250327 Pull Request resolved: #20456
1 parent 842b3cb commit d6ce898

4 files changed

Lines changed: 127 additions & 57 deletions

File tree

backends/webgpu/runtime/WebGPUUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,12 @@
1818

1919
namespace executorch::backends::webgpu::utils {
2020

21+
// Ceiling division for non-negative integers (mirrors Vulkan's utils::div_up).
22+
template <typename T>
23+
inline T div_up(T a, T b) {
24+
return (a + b - 1) / b;
25+
}
26+
2127
// Clamp workgroup size to device limit (SwiftShader caps at 128).
2228
inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
2329
WGPULimits limits = {};
@@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count(
3440
uint32_t num_threads,
3541
uint32_t workgroup_size,
3642
const char* op_name) {
37-
uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size;
43+
uint32_t count = div_up(num_threads, workgroup_size);
3844
WGPULimits limits = {};
3945
uint32_t max_count =
4046
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&

backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ struct Q4gswParams {
3434
};
3535
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");
3636

37+
// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl.
38+
constexpr int64_t kQ4gswTileM = 4;
39+
constexpr int64_t kQ4gswTileN = 4;
40+
3741
// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
3842
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
3943
const int in_id = args.at(0);
@@ -85,9 +89,17 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
8589
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
8690
}
8791

88-
// One workgroup per output row (M); validate dispatch before any alloc.
89-
const uint32_t workgroup_count =
90-
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");
92+
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
93+
const uint32_t wg_size =
94+
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
95+
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
96+
utils::div_up<int64_t>(N, kQ4gswTileN);
97+
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
98+
throw std::runtime_error(
99+
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
100+
}
101+
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
102+
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
91103

92104
// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
93105
const uint64_t scales_numel =
@@ -186,8 +198,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
186198
WGPUPipelineLayout pipeline_layout =
187199
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
188200

189-
const uint32_t wg_size =
190-
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
191201
WGPUConstantEntry wg_size_constant = {};
192202
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
193203
wg_size_constant.value = static_cast<double>(wg_size);

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl

Lines changed: 52 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,74 @@ struct Params {
1818

1919
override wg_size: u32 = 64u;
2020

21-
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
21+
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
22+
const TM: u32 = 4u;
23+
const TN: u32 = 4u;
24+
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN
25+
2226
@compute @workgroup_size(wg_size, 1, 1)
23-
fn main(
24-
@builtin(workgroup_id) wid: vec3<u32>,
25-
@builtin(local_invocation_id) lid: vec3<u32>) {
26-
let m = wid.x;
27-
if (m >= params.M) {
27+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
28+
let nrt = (params.M + TM - 1u) / TM;
29+
let nct = (params.N + TN - 1u) / TN;
30+
let tiles = nrt * nct;
31+
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
32+
// clamps below never underflow (the host also rejects M==0/N==0).
33+
if (gid.x >= tiles) {
2834
return;
2935
}
30-
let in_base = m * params.K;
36+
let row_tile = gid.x / nct;
37+
let col_tile = gid.x % nct;
38+
let m0 = row_tile * TM;
39+
let n0 = col_tile * TN;
40+
41+
var acc: array<f32, TILE_ELEMS>;
42+
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
43+
acc[i] = 0.0;
44+
}
3145

32-
var n: u32 = lid.x;
46+
var k: u32 = 0u;
3347
loop {
34-
if (n >= params.N) {
48+
if (k >= params.K) {
3549
break;
3650
}
37-
var acc: f32 = 0.0;
38-
var k: u32 = 0u;
39-
loop {
40-
if (k >= params.K) {
41-
break;
42-
}
43-
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
44-
let byte_idx = n * params.K_packed + (k >> 1u);
51+
// Load the TM input values for column k once; reused across all TN columns.
52+
var in_reg: array<f32, TM>;
53+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
54+
let m_eff = min(m0 + ml, params.M - 1u);
55+
in_reg[ml] = t_input[m_eff * params.K + k];
56+
}
57+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
58+
// Clamp to last valid column; overhang result is never stored.
59+
let n_eff = min(n0 + nl, params.N - 1u);
60+
let byte_idx = n_eff * params.K_packed + (k >> 1u);
4561
let word = t_weight[byte_idx >> 2u];
4662
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
4763
var nib: u32;
4864
if ((k & 1u) == 0u) {
49-
nib = b & 0x0Fu; // even k -> low nibble
65+
nib = b & 0x0Fu; // even k -> low nibble
5066
} else {
5167
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
5268
}
5369
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
54-
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
55-
acc = acc + t_input[in_base + k] * q * scale;
56-
k = k + 1u;
70+
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
71+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
72+
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
73+
}
5774
}
58-
if (params.has_bias != 0u) {
59-
acc = acc + t_bias[n];
75+
k = k + 1u;
76+
}
77+
78+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
79+
let m = m0 + ml;
80+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
81+
let n = n0 + nl;
82+
if (m < params.M && n < params.N) {
83+
var v = acc[ml * TN + nl];
84+
if (params.has_bias != 0u) {
85+
v = v + t_bias[n];
86+
}
87+
t_out[m * params.N + n] = v;
88+
}
6089
}
61-
t_out[m * params.N + n] = acc;
62-
n = n + wg_size;
6390
}
6491
}

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
16+
// wgsl-sha256: dc6a55014ae4543bd80e5e22c3fb52896aca96e0589f700803327d8121ada489
1717
inline constexpr const char* kQ4gswLinearWGSL = R"(
1818
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
1919
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@@ -35,48 +35,75 @@ struct Params {
3535
3636
override wg_size: u32 = 64u;
3737
38-
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
38+
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
39+
const TM: u32 = 4u;
40+
const TN: u32 = 4u;
41+
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN
42+
3943
@compute @workgroup_size(wg_size, 1, 1)
40-
fn main(
41-
@builtin(workgroup_id) wid: vec3<u32>,
42-
@builtin(local_invocation_id) lid: vec3<u32>) {
43-
let m = wid.x;
44-
if (m >= params.M) {
44+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
45+
let nrt = (params.M + TM - 1u) / TM;
46+
let nct = (params.N + TN - 1u) / TN;
47+
let tiles = nrt * nct;
48+
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
49+
// clamps below never underflow (the host also rejects M==0/N==0).
50+
if (gid.x >= tiles) {
4551
return;
4652
}
47-
let in_base = m * params.K;
53+
let row_tile = gid.x / nct;
54+
let col_tile = gid.x % nct;
55+
let m0 = row_tile * TM;
56+
let n0 = col_tile * TN;
57+
58+
var acc: array<f32, TILE_ELEMS>;
59+
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
60+
acc[i] = 0.0;
61+
}
4862
49-
var n: u32 = lid.x;
63+
var k: u32 = 0u;
5064
loop {
51-
if (n >= params.N) {
65+
if (k >= params.K) {
5266
break;
5367
}
54-
var acc: f32 = 0.0;
55-
var k: u32 = 0u;
56-
loop {
57-
if (k >= params.K) {
58-
break;
59-
}
60-
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
61-
let byte_idx = n * params.K_packed + (k >> 1u);
68+
// Load the TM input values for column k once; reused across all TN columns.
69+
var in_reg: array<f32, TM>;
70+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
71+
let m_eff = min(m0 + ml, params.M - 1u);
72+
in_reg[ml] = t_input[m_eff * params.K + k];
73+
}
74+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
75+
// Clamp to last valid column; overhang result is never stored.
76+
let n_eff = min(n0 + nl, params.N - 1u);
77+
let byte_idx = n_eff * params.K_packed + (k >> 1u);
6278
let word = t_weight[byte_idx >> 2u];
6379
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
6480
var nib: u32;
6581
if ((k & 1u) == 0u) {
66-
nib = b & 0x0Fu; // even k -> low nibble
82+
nib = b & 0x0Fu; // even k -> low nibble
6783
} else {
6884
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
6985
}
7086
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
71-
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
72-
acc = acc + t_input[in_base + k] * q * scale;
73-
k = k + 1u;
87+
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
88+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
89+
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
90+
}
7491
}
75-
if (params.has_bias != 0u) {
76-
acc = acc + t_bias[n];
92+
k = k + 1u;
93+
}
94+
95+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
96+
let m = m0 + ml;
97+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
98+
let n = n0 + nl;
99+
if (m < params.M && n < params.N) {
100+
var v = acc[ml * TN + nl];
101+
if (params.has_bias != 0u) {
102+
v = v + t_bias[n];
103+
}
104+
t_out[m * params.N + n] = v;
105+
}
77106
}
78-
t_out[m * params.N + n] = acc;
79-
n = n + wg_size;
80107
}
81108
}
82109
)";

0 commit comments

Comments
 (0)