Skip to content

Commit a1afd65

Browse files
Update (base update)
[ghstack-poisoned]
1 parent 441d9ba commit a1afd65

9 files changed

Lines changed: 109 additions & 366 deletions

File tree

backends/webgpu/runtime/WebGPUUtils.h

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,12 @@
1818

1919
namespace executorch::backends::webgpu::utils {
2020

21+
// Ceiling division for non-negative integers (mirrors Vulkan's utils::div_up).
//
// Returns the smallest integer q such that q * b >= a.
//
// Preconditions: a >= 0 and b > 0. A zero divisor is undefined behavior, as
// with any integer division; callers must validate b beforehand.
//
// Implemented as quotient plus a remainder check instead of the textbook
// (a + b - 1) / b: the intermediate sum in that form wraps around when `a` is
// close to the maximum of T (e.g. a uint32_t thread count near UINT32_MAX),
// which silently produces a far-too-small result.
template <typename T>
inline T div_up(T a, T b) {
  return static_cast<T>(a / b + (a % b != 0 ? 1 : 0));
}
26+
2127
// Clamp workgroup size to device limit (SwiftShader caps at 128).
2228
inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
2329
WGPULimits limits = {};
@@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count(
3440
uint32_t num_threads,
3541
uint32_t workgroup_size,
3642
const char* op_name) {
37-
uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size;
43+
uint32_t count = div_up(num_threads, workgroup_size);
3844
WGPULimits limits = {};
3945
uint32_t max_count =
4046
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&

backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,6 @@ static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");
3737
// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl.
3838
constexpr int64_t kQ4gswTileM = 4;
3939
constexpr int64_t kQ4gswTileN = 4;
40-
// ceil(a/b) for positive int64 (WebGPUUtils has no ceil-div helper).
41-
inline int64_t q4gsw_ceil_div(int64_t a, int64_t b) {
42-
return (a + b - 1) / b;
43-
}
4440

4541
// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
4642
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
@@ -96,8 +92,8 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
9692
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
9793
const uint32_t wg_size =
9894
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
99-
const int64_t total_tiles =
100-
q4gsw_ceil_div(M, kQ4gswTileM) * q4gsw_ceil_div(N, kQ4gswTileN);
95+
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
96+
utils::div_up<int64_t>(N, kQ4gswTileN);
10197
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
10298
throw std::runtime_error(
10399
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,15 @@ override wg_size: u32 = 64u;
2121
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
2222
const TM: u32 = 4u;
2323
const TN: u32 = 4u;
24+
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN
2425

2526
@compute @workgroup_size(wg_size, 1, 1)
2627
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
2728
let nrt = (params.M + TM - 1u) / TM;
2829
let nct = (params.N + TN - 1u) / TN;
2930
let tiles = nrt * nct;
31+
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
32+
// clamps below never underflow (the host also rejects M==0/N==0).
3033
if (gid.x >= tiles) {
3134
return;
3235
}
@@ -35,8 +38,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
3538
let m0 = row_tile * TM;
3639
let n0 = col_tile * TN;
3740

38-
var acc: array<f32, 16>; // TM * TN
39-
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
41+
var acc: array<f32, TILE_ELEMS>;
42+
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
4043
acc[i] = 0.0;
4144
}
4245

@@ -45,6 +48,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
4548
if (k >= params.K) {
4649
break;
4750
}
51+
// Load the TM input values for column k once; reused across all TN columns.
52+
var in_reg: array<f32, TM>;
53+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
54+
let m_eff = min(m0 + ml, params.M - 1u);
55+
in_reg[ml] = t_input[m_eff * params.K + k];
56+
}
4857
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
4958
// Clamp to last valid column; overhang result is never stored.
5059
let n_eff = min(n0 + nl, params.N - 1u);
@@ -60,8 +69,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
6069
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
6170
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
6271
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
63-
let m_eff = min(m0 + ml, params.M - 1u);
64-
acc[ml * TN + nl] = acc[ml * TN + nl] + t_input[m_eff * params.K + k] * dq;
72+
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
6573
}
6674
}
6775
k = k + 1u;

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: f0fd0371418fdacd3387645888689caf86a387a623ed08f8337610e30f844ede
16+
// wgsl-sha256: dc6a55014ae4543bd80e5e22c3fb52896aca96e0589f700803327d8121ada489
1717
inline constexpr const char* kQ4gswLinearWGSL = R"(
1818
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
1919
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@@ -38,12 +38,15 @@ override wg_size: u32 = 64u;
3838
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
3939
const TM: u32 = 4u;
4040
const TN: u32 = 4u;
41+
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN
4142
4243
@compute @workgroup_size(wg_size, 1, 1)
4344
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
4445
let nrt = (params.M + TM - 1u) / TM;
4546
let nct = (params.N + TN - 1u) / TN;
4647
let tiles = nrt * nct;
48+
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
49+
// clamps below never underflow (the host also rejects M==0/N==0).
4750
if (gid.x >= tiles) {
4851
return;
4952
}
@@ -52,8 +55,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
5255
let m0 = row_tile * TM;
5356
let n0 = col_tile * TN;
5457
55-
var acc: array<f32, 16>; // TM * TN
56-
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
58+
var acc: array<f32, TILE_ELEMS>;
59+
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
5760
acc[i] = 0.0;
5861
}
5962
@@ -62,6 +65,12 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
6265
if (k >= params.K) {
6366
break;
6467
}
68+
// Load the TM input values for column k once; reused across all TN columns.
69+
var in_reg: array<f32, TM>;
70+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
71+
let m_eff = min(m0 + ml, params.M - 1u);
72+
in_reg[ml] = t_input[m_eff * params.K + k];
73+
}
6574
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
6675
// Clamp to last valid column; overhang result is never stored.
6776
let n_eff = min(n0 + nl, params.N - 1u);
@@ -77,8 +86,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
7786
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
7887
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
7988
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
80-
let m_eff = min(m0 + ml, params.M - 1u);
81-
acc[ml * TN + nl] = acc[ml * TN + nl] + t_input[m_eff * params.K + k] * dq;
89+
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
8290
}
8391
}
8492
k = k + 1u;

backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,6 @@ namespace executorch::backends::webgpu {
2626

2727
namespace {
2828

29-
// Register-tile dims; MUST match TM/TN in the reg WGSL kernels.
30-
constexpr int64_t kSdpaTileM = 4;
31-
constexpr int64_t kSdpaTileN = 4;
32-
inline int64_t sdpa_ceil_div(int64_t a, int64_t b) {
33-
return (a + b - 1) / b;
34-
}
35-
3629
// Uniform param structs (all 16-byte aligned, matching the WGSL Params).
3730
struct UpdateCacheParams {
3831
uint32_t numel;
@@ -471,16 +464,14 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
471464
dynamic_pos,
472465
"update_cache(V)");
473466

474-
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
467+
// --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
475468
{
476469
if (aw_floats > UINT32_MAX) {
477470
throw std::runtime_error(
478471
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
479472
}
480-
const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) *
481-
sdpa_ceil_div(context_len, kSdpaTileN);
482473
const uint32_t wgc = utils::compute_1d_workgroup_count(
483-
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
474+
device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
484475
AttnWeightsParams p = make_attn_weights_params(
485476
S, Hq, Hkv, D, context_len, input_pos, g, scale);
486477
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
@@ -524,12 +515,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
524515
softmax_buf = ubuf;
525516
}
526517

527-
// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
518+
// --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
528519
{
529-
const int64_t av_tiles =
530-
Hq * sdpa_ceil_div(S, kSdpaTileM) * sdpa_ceil_div(D, kSdpaTileN);
520+
const uint64_t out_floats = static_cast<uint64_t>(S) *
521+
static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
531522
const uint32_t wgc = utils::compute_1d_workgroup_count(
532-
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
523+
device, static_cast<uint32_t>(out_floats), av_wg, "AV");
533524
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
534525
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
535526
BufferBinding bindings[3] = {
@@ -600,11 +591,9 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
600591
AttnWeightsParams qp =
601592
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
602593
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
603-
const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) *
604-
sdpa_ceil_div(ctx, kSdpaTileN);
605594
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
606595
gr.device(),
607-
static_cast<uint32_t>(qk_tiles),
596+
static_cast<uint32_t>(aw_floats),
608597
qk_wg,
609598
"QK(resize)");
610599
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;

backends/webgpu/runtime/ops/sdpa/sdpa_compute_attn_weights.wgsl

Lines changed: 19 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -19,102 +19,37 @@ const NEG_INF: f32 = -1.0e30;
1919

2020
override wg_size: u32 = 64;
2121

22-
const TM: u32 = 4u;
23-
const TN: u32 = 4u;
24-
25-
fn load_q_vec4(s: u32, h: u32, d4: u32) -> vec4<f32> {
26-
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
27-
if (s >= params.S) {
28-
return r;
29-
}
30-
let base = s * params.Hq * params.D + h * params.D;
31-
if (d4 + 0u < params.D) { r.x = t_q[base + d4 + 0u]; }
32-
if (d4 + 1u < params.D) { r.y = t_q[base + d4 + 1u]; }
33-
if (d4 + 2u < params.D) { r.z = t_q[base + d4 + 2u]; }
34-
if (d4 + 3u < params.D) { r.w = t_q[base + d4 + 3u]; }
35-
return r;
36-
}
37-
38-
fn load_k_vec4(c: u32, kvh: u32, d4: u32) -> vec4<f32> {
39-
var r = vec4<f32>(0.0, 0.0, 0.0, 0.0);
40-
if (c >= params.context_len) {
41-
return r;
42-
}
43-
let base = c * params.Hkv * params.D + kvh * params.D;
44-
if (d4 + 0u < params.D) { r.x = t_k_cache[base + d4 + 0u]; }
45-
if (d4 + 1u < params.D) { r.y = t_k_cache[base + d4 + 1u]; }
46-
if (d4 + 2u < params.D) { r.z = t_k_cache[base + d4 + 2u]; }
47-
if (d4 + 3u < params.D) { r.w = t_k_cache[base + d4 + 3u]; }
48-
return r;
49-
}
50-
51-
fn store_qk(s: u32, c: u32, h: u32, raw: f32) {
52-
if (s >= params.S || c >= params.context_len) {
53-
return;
54-
}
55-
var val = raw * params.scale;
56-
// Causal mask: position c may not attend beyond s + input_pos.
57-
if (c > s + params.input_pos) {
58-
val = NEG_INF;
59-
}
60-
let idx = h * params.S * params.context_len + s * params.context_len + c;
61-
t_attn_weights[idx] = val;
62-
}
63-
6422
@compute @workgroup_size(wg_size, 1, 1)
6523
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
66-
let nrt = (params.S + TM - 1u) / TM;
67-
let nct = (params.context_len + TN - 1u) / TN;
68-
let tiles = nrt * nct;
69-
let total = tiles * params.Hq;
70-
if (gid.x >= total) {
24+
let total = params.Hq * params.S * params.context_len;
25+
let idx = gid.x;
26+
if (idx >= total) {
7127
return;
7228
}
29+
let c = idx % params.context_len;
30+
let s = (idx / params.context_len) % params.S;
31+
let h = idx / (params.context_len * params.S);
7332

74-
let h = gid.x / tiles;
75-
let rem = gid.x % tiles;
76-
let row_tile = rem / nct;
77-
let col_tile = rem % nct;
7833
let kvh = h / params.g;
79-
let s0 = row_tile * TM;
80-
let c0 = col_tile * TN;
8134

82-
var acc: array<vec4<f32>, 4>;
83-
acc[0] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
84-
acc[1] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
85-
acc[2] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
86-
acc[3] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
35+
let q_base = s * params.Hq * params.D + h * params.D;
36+
let k_base = c * params.Hkv * params.D + kvh * params.D;
8737

88-
var d4: u32 = 0u;
38+
var acc: f32 = 0.0;
39+
var d: u32 = 0u;
8940
loop {
90-
if (d4 >= params.D) {
41+
if (d >= params.D) {
9142
break;
9243
}
93-
let q0 = load_q_vec4(s0 + 0u, h, d4);
94-
let q1 = load_q_vec4(s0 + 1u, h, d4);
95-
let q2 = load_q_vec4(s0 + 2u, h, d4);
96-
let q3 = load_q_vec4(s0 + 3u, h, d4);
97-
let k0 = load_k_vec4(c0 + 0u, kvh, d4);
98-
let k1 = load_k_vec4(c0 + 1u, kvh, d4);
99-
let k2 = load_k_vec4(c0 + 2u, kvh, d4);
100-
let k3 = load_k_vec4(c0 + 3u, kvh, d4);
101-
acc[0] += vec4<f32>(dot(q0, k0), dot(q0, k1), dot(q0, k2), dot(q0, k3));
102-
acc[1] += vec4<f32>(dot(q1, k0), dot(q1, k1), dot(q1, k2), dot(q1, k3));
103-
acc[2] += vec4<f32>(dot(q2, k0), dot(q2, k1), dot(q2, k2), dot(q2, k3));
104-
acc[3] += vec4<f32>(dot(q3, k0), dot(q3, k1), dot(q3, k2), dot(q3, k3));
105-
d4 = d4 + 4u;
44+
acc = acc + t_q[q_base + d] * t_k_cache[k_base + d];
45+
d = d + 1u;
10646
}
47+
acc = acc * params.scale;
10748

108-
var m: u32 = 0u;
109-
loop {
110-
if (m >= TM) {
111-
break;
112-
}
113-
let av = acc[m];
114-
store_qk(s0 + m, c0 + 0u, h, av.x);
115-
store_qk(s0 + m, c0 + 1u, h, av.y);
116-
store_qk(s0 + m, c0 + 2u, h, av.z);
117-
store_qk(s0 + m, c0 + 3u, h, av.w);
118-
m = m + 1u;
49+
// Causal mask: position c may not attend beyond s + input_pos.
50+
if (c > s + params.input_pos) {
51+
acc = NEG_INF;
11952
}
53+
54+
t_attn_weights[idx] = acc;
12055
}

0 commit comments

Comments
 (0)