Skip to content

Commit 441d9ba

Browse files
Update (base update)
[ghstack-poisoned]
1 parent 0e65ba6 commit 441d9ba

8 files changed

Lines changed: 458 additions & 131 deletions

File tree

backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ struct Q4gswParams {
3434
};
3535
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");
3636

37+
// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl.
38+
constexpr int64_t kQ4gswTileM = 4;
39+
constexpr int64_t kQ4gswTileN = 4;
40+
// ceil(a/b) for positive int64 (WebGPUUtils has no ceil-div helper).
41+
inline int64_t q4gsw_ceil_div(int64_t a, int64_t b) {
42+
return (a + b - 1) / b;
43+
}
44+
3745
// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
3846
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
3947
const int in_id = args.at(0);
@@ -85,9 +93,17 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
8593
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
8694
}
8795

88-
// One workgroup per output row (M); validate dispatch before any alloc.
89-
const uint32_t workgroup_count =
90-
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");
96+
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
97+
const uint32_t wg_size =
98+
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
99+
const int64_t total_tiles =
100+
q4gsw_ceil_div(M, kQ4gswTileM) * q4gsw_ceil_div(N, kQ4gswTileN);
101+
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
102+
throw std::runtime_error(
103+
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
104+
}
105+
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
106+
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
91107

92108
// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
93109
const uint64_t scales_numel =
@@ -186,8 +202,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
186202
WGPUPipelineLayout pipeline_layout =
187203
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
188204

189-
const uint32_t wg_size =
190-
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
191205
WGPUConstantEntry wg_size_constant = {};
192206
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
193207
wg_size_constant.value = static_cast<double>(wg_size);

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,66 @@ struct Params {
1818

1919
override wg_size: u32 = 64u;
2020

21-
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
21+
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
22+
const TM: u32 = 4u;
23+
const TN: u32 = 4u;
24+
2225
@compute @workgroup_size(wg_size, 1, 1)
23-
fn main(
24-
@builtin(workgroup_id) wid: vec3<u32>,
25-
@builtin(local_invocation_id) lid: vec3<u32>) {
26-
let m = wid.x;
27-
if (m >= params.M) {
26+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
27+
let nrt = (params.M + TM - 1u) / TM;
28+
let nct = (params.N + TN - 1u) / TN;
29+
let tiles = nrt * nct;
30+
if (gid.x >= tiles) {
2831
return;
2932
}
30-
let in_base = m * params.K;
33+
let row_tile = gid.x / nct;
34+
let col_tile = gid.x % nct;
35+
let m0 = row_tile * TM;
36+
let n0 = col_tile * TN;
37+
38+
var acc: array<f32, 16>; // TM * TN
39+
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
40+
acc[i] = 0.0;
41+
}
3142

32-
var n: u32 = lid.x;
43+
var k: u32 = 0u;
3344
loop {
34-
if (n >= params.N) {
45+
if (k >= params.K) {
3546
break;
3647
}
37-
var acc: f32 = 0.0;
38-
var k: u32 = 0u;
39-
loop {
40-
if (k >= params.K) {
41-
break;
42-
}
43-
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
44-
let byte_idx = n * params.K_packed + (k >> 1u);
48+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
49+
// Clamp to last valid column; overhang result is never stored.
50+
let n_eff = min(n0 + nl, params.N - 1u);
51+
let byte_idx = n_eff * params.K_packed + (k >> 1u);
4552
let word = t_weight[byte_idx >> 2u];
4653
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
4754
var nib: u32;
4855
if ((k & 1u) == 0u) {
49-
nib = b & 0x0Fu; // even k -> low nibble
56+
nib = b & 0x0Fu; // even k -> low nibble
5057
} else {
5158
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
5259
}
5360
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
54-
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
55-
acc = acc + t_input[in_base + k] * q * scale;
56-
k = k + 1u;
61+
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
62+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
63+
let m_eff = min(m0 + ml, params.M - 1u);
64+
acc[ml * TN + nl] = acc[ml * TN + nl] + t_input[m_eff * params.K + k] * dq;
65+
}
5766
}
58-
if (params.has_bias != 0u) {
59-
acc = acc + t_bias[n];
67+
k = k + 1u;
68+
}
69+
70+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
71+
let m = m0 + ml;
72+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
73+
let n = n0 + nl;
74+
if (m < params.M && n < params.N) {
75+
var v = acc[ml * TN + nl];
76+
if (params.has_bias != 0u) {
77+
v = v + t_bias[n];
78+
}
79+
t_out[m * params.N + n] = v;
80+
}
6081
}
61-
t_out[m * params.N + n] = acc;
62-
n = n + wg_size;
6382
}
6483
}

backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace executorch::backends::webgpu {
1414

1515
// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
16-
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
16+
// wgsl-sha256: f0fd0371418fdacd3387645888689caf86a387a623ed08f8337610e30f844ede
1717
inline constexpr const char* kQ4gswLinearWGSL = R"(
1818
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
1919
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@@ -35,48 +35,67 @@ struct Params {
3535
3636
override wg_size: u32 = 64u;
3737
38-
// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
38+
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
39+
const TM: u32 = 4u;
40+
const TN: u32 = 4u;
41+
3942
@compute @workgroup_size(wg_size, 1, 1)
40-
fn main(
41-
@builtin(workgroup_id) wid: vec3<u32>,
42-
@builtin(local_invocation_id) lid: vec3<u32>) {
43-
let m = wid.x;
44-
if (m >= params.M) {
43+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
44+
let nrt = (params.M + TM - 1u) / TM;
45+
let nct = (params.N + TN - 1u) / TN;
46+
let tiles = nrt * nct;
47+
if (gid.x >= tiles) {
4548
return;
4649
}
47-
let in_base = m * params.K;
50+
let row_tile = gid.x / nct;
51+
let col_tile = gid.x % nct;
52+
let m0 = row_tile * TM;
53+
let n0 = col_tile * TN;
54+
55+
var acc: array<f32, 16>; // TM * TN
56+
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
57+
acc[i] = 0.0;
58+
}
4859
49-
var n: u32 = lid.x;
60+
var k: u32 = 0u;
5061
loop {
51-
if (n >= params.N) {
62+
if (k >= params.K) {
5263
break;
5364
}
54-
var acc: f32 = 0.0;
55-
var k: u32 = 0u;
56-
loop {
57-
if (k >= params.K) {
58-
break;
59-
}
60-
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
61-
let byte_idx = n * params.K_packed + (k >> 1u);
65+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
66+
// Clamp to last valid column; overhang result is never stored.
67+
let n_eff = min(n0 + nl, params.N - 1u);
68+
let byte_idx = n_eff * params.K_packed + (k >> 1u);
6269
let word = t_weight[byte_idx >> 2u];
6370
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
6471
var nib: u32;
6572
if ((k & 1u) == 0u) {
66-
nib = b & 0x0Fu; // even k -> low nibble
73+
nib = b & 0x0Fu; // even k -> low nibble
6774
} else {
6875
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
6976
}
7077
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
71-
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
72-
acc = acc + t_input[in_base + k] * q * scale;
73-
k = k + 1u;
78+
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
79+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
80+
let m_eff = min(m0 + ml, params.M - 1u);
81+
acc[ml * TN + nl] = acc[ml * TN + nl] + t_input[m_eff * params.K + k] * dq;
82+
}
7483
}
75-
if (params.has_bias != 0u) {
76-
acc = acc + t_bias[n];
84+
k = k + 1u;
85+
}
86+
87+
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
88+
let m = m0 + ml;
89+
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
90+
let n = n0 + nl;
91+
if (m < params.M && n < params.N) {
92+
var v = acc[ml * TN + nl];
93+
if (params.has_bias != 0u) {
94+
v = v + t_bias[n];
95+
}
96+
t_out[m * params.N + n] = v;
97+
}
7798
}
78-
t_out[m * params.N + n] = acc;
79-
n = n + wg_size;
8099
}
81100
}
82101
)";

backends/webgpu/runtime/ops/sdpa/Sdpa.cpp

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ namespace executorch::backends::webgpu {
2626

2727
namespace {
2828

29+
// Register-tile dims; MUST match TM/TN in the reg WGSL kernels.
30+
constexpr int64_t kSdpaTileM = 4;
31+
constexpr int64_t kSdpaTileN = 4;
32+
inline int64_t sdpa_ceil_div(int64_t a, int64_t b) {
33+
return (a + b - 1) / b;
34+
}
35+
2936
// Uniform param structs (all 16-byte aligned, matching the WGSL Params).
3037
struct UpdateCacheParams {
3138
uint32_t numel;
@@ -464,14 +471,16 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
464471
dynamic_pos,
465472
"update_cache(V)");
466473

467-
// --- Dispatch 3: QK -> attn_weights. One thread per (h,s,c) element.
474+
// --- Dispatch 3: QK -> attn_weights. One thread per TM x TN tile.
468475
{
469476
if (aw_floats > UINT32_MAX) {
470477
throw std::runtime_error(
471478
"WebGPU sdpa: Hq*S*context_len exceeds uint32 max");
472479
}
480+
const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) *
481+
sdpa_ceil_div(context_len, kSdpaTileN);
473482
const uint32_t wgc = utils::compute_1d_workgroup_count(
474-
device, static_cast<uint32_t>(aw_floats), qk_wg, "QK");
483+
device, static_cast<uint32_t>(qk_tiles), qk_wg, "QK");
475484
AttnWeightsParams p = make_attn_weights_params(
476485
S, Hq, Hkv, D, context_len, input_pos, g, scale);
477486
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
@@ -515,12 +524,12 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
515524
softmax_buf = ubuf;
516525
}
517526

518-
// --- Dispatch 5: AV -> out. One thread per (s,h,d) output element.
527+
// --- Dispatch 5: AV -> out. One thread per TM x TN tile.
519528
{
520-
const uint64_t out_floats = static_cast<uint64_t>(S) *
521-
static_cast<uint64_t>(Hq) * static_cast<uint64_t>(D);
529+
const int64_t av_tiles =
530+
Hq * sdpa_ceil_div(S, kSdpaTileM) * sdpa_ceil_div(D, kSdpaTileN);
522531
const uint32_t wgc = utils::compute_1d_workgroup_count(
523-
device, static_cast<uint32_t>(out_floats), av_wg, "AV");
532+
device, static_cast<uint32_t>(av_tiles), av_wg, "AV");
524533
ComputeOutParams p = make_compute_out_params(S, Hq, Hkv, D, context_len, g);
525534
WGPUBuffer ubuf = make_uniform_buffer(graph, &p, sizeof(p));
526535
BufferBinding bindings[3] = {
@@ -591,9 +600,11 @@ void sdpa_with_kv_cache_impl(WebGPUGraph& graph, const std::vector<int>& args) {
591600
AttnWeightsParams qp =
592601
make_attn_weights_params(S, Hq, Hkv, D, ctx, pos, g, scale);
593602
wgpuQueueWriteBuffer(gr.queue(), qk_buf, 0, &qp, sizeof(qp));
603+
const int64_t qk_tiles = Hq * sdpa_ceil_div(S, kSdpaTileM) *
604+
sdpa_ceil_div(ctx, kSdpaTileN);
594605
const uint32_t qk_wgc = utils::compute_1d_workgroup_count(
595606
gr.device(),
596-
static_cast<uint32_t>(aw_floats),
607+
static_cast<uint32_t>(qk_tiles),
597608
qk_wg,
598609
"QK(resize)");
599610
gr.dispatch_at(qk_idx).workgroup_count_x = qk_wgc;

0 commit comments

Comments
 (0)