Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backends/webgpu/runtime/WebGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

namespace executorch::backends::webgpu::utils {

// Ceiling division for non-negative integers (mirrors Vulkan's utils::div_up).
// Computed as the truncated quotient plus a carry for any remainder instead of
// (a + b - 1) / b: the additive form wraps for unsigned T (and is undefined
// for signed T) once a + b - 1 exceeds the maximum of T, which would yield a
// too-small count for very large inputs.
// Precondition: b > 0 (division by zero is not checked here).
template <typename T>
inline T div_up(T a, T b) {
  return static_cast<T>(a / b + (a % b != 0 ? 1 : 0));
}

// Clamp workgroup size to device limit (SwiftShader caps at 128).
inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) {
WGPULimits limits = {};
Expand All @@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count(
uint32_t num_threads,
uint32_t workgroup_size,
const char* op_name) {
uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size;
uint32_t count = div_up(num_threads, workgroup_size);
WGPULimits limits = {};
uint32_t max_count =
wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success &&
Expand Down
20 changes: 15 additions & 5 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ struct Q4gswParams {
};
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");

// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl.
// The host derives the number of tiles to dispatch from these values, while
// the shader maps each invocation to a tile using its own TM/TN constants, so
// a mismatch would make the dispatched tile count disagree with the shader's.
constexpr int64_t kQ4gswTileM = 4;
constexpr int64_t kQ4gswTileN = 4;

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
Expand Down Expand Up @@ -85,9 +89,17 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
}

// One workgroup per output row (M); validate dispatch before any alloc.
const uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
utils::div_up<int64_t>(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
Expand Down Expand Up @@ -186,8 +198,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);
Expand Down
77 changes: 52 additions & 25 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,74 @@ struct Params {

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
const TM: u32 = 4u;
const TN: u32 = 4u;
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN

@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let nrt = (params.M + TM - 1u) / TM;
let nct = (params.N + TN - 1u) / TN;
let tiles = nrt * nct;
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
// clamps below never underflow (the host also rejects M==0/N==0).
if (gid.x >= tiles) {
return;
}
let in_base = m * params.K;
let row_tile = gid.x / nct;
let col_tile = gid.x % nct;
let m0 = row_tile * TM;
let n0 = col_tile * TN;

var acc: array<f32, TILE_ELEMS>;
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
acc[i] = 0.0;
}

var n: u32 = lid.x;
var k: u32 = 0u;
loop {
if (n >= params.N) {
if (k >= params.K) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
// Load the TM input values for column k once; reused across all TN columns.
var in_reg: array<f32, TM>;
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m_eff = min(m0 + ml, params.M - 1u);
in_reg[ml] = t_input[m_eff * params.K + k];
}
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
// Clamp to last valid column; overhang result is never stored.
let n_eff = min(n0 + nl, params.N - 1u);
let byte_idx = n_eff * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
}
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
k = k + 1u;
}

for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m = m0 + ml;
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
let n = n0 + nl;
if (m < params.M && n < params.N) {
var v = acc[ml * TN + nl];
if (params.has_bias != 0u) {
v = v + t_bias[n];
}
t_out[m * params.N + n] = v;
}
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
79 changes: 53 additions & 26 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
// wgsl-sha256: dc6a55014ae4543bd80e5e22c3fb52896aca96e0589f700803327d8121ada489
inline constexpr const char* kQ4gswLinearWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
Expand All @@ -35,48 +35,75 @@ struct Params {

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
const TM: u32 = 4u;
const TN: u32 = 4u;
const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN

@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let nrt = (params.M + TM - 1u) / TM;
let nct = (params.N + TN - 1u) / TN;
let tiles = nrt * nct;
// M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u
// clamps below never underflow (the host also rejects M==0/N==0).
if (gid.x >= tiles) {
return;
}
let in_base = m * params.K;
let row_tile = gid.x / nct;
let col_tile = gid.x % nct;
let m0 = row_tile * TM;
let n0 = col_tile * TN;

var acc: array<f32, TILE_ELEMS>;
for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) {
acc[i] = 0.0;
}

var n: u32 = lid.x;
var k: u32 = 0u;
loop {
if (n >= params.N) {
if (k >= params.K) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
// Load the TM input values for column k once; reused across all TN columns.
var in_reg: array<f32, TM>;
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m_eff = min(m0 + ml, params.M - 1u);
in_reg[ml] = t_input[m_eff * params.K + k];
}
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
// Clamp to last valid column; overhang result is never stored.
let n_eff = min(n0 + nl, params.N - 1u);
let byte_idx = n_eff * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
}
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
k = k + 1u;
}

for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m = m0 + ml;
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
let n = n0 + nl;
if (m < params.M && n < params.N) {
var v = acc[ml * TN + nl];
if (params.has_bias != 0u) {
v = v + t_bias[n];
}
t_out[m * params.N + n] = v;
}
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
)";
Expand Down
Loading