diff --git a/backends/webgpu/runtime/WebGPUUtils.h b/backends/webgpu/runtime/WebGPUUtils.h index 39eb3caa28b..293ad495be2 100644 --- a/backends/webgpu/runtime/WebGPUUtils.h +++ b/backends/webgpu/runtime/WebGPUUtils.h @@ -18,6 +18,12 @@ namespace executorch::backends::webgpu::utils { +// Ceiling division for non-negative integers (mirrors Vulkan's utils::div_up). +template +inline T div_up(T a, T b) { + return (a + b - 1) / b; +} + // Clamp workgroup size to device limit (SwiftShader caps at 128). inline uint32_t clamp_workgroup_size(WGPUDevice device, uint32_t desired) { WGPULimits limits = {}; @@ -34,7 +40,7 @@ inline uint32_t compute_1d_workgroup_count( uint32_t num_threads, uint32_t workgroup_size, const char* op_name) { - uint32_t count = (num_threads + workgroup_size - 1) / workgroup_size; + uint32_t count = div_up(num_threads, workgroup_size); WGPULimits limits = {}; uint32_t max_count = wgpuDeviceGetLimits(device, &limits) == WGPUStatus_Success && diff --git a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp index 2597aea10d4..89e722cdb9a 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp +++ b/backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp @@ -34,6 +34,10 @@ struct Q4gswParams { }; static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes"); +// Register-tile dims; MUST match TM/TN in q4gsw_linear.wgsl. +constexpr int64_t kQ4gswTileM = 4; +constexpr int64_t kQ4gswTileN = 4; + // et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out]. void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { const int in_id = args.at(0); @@ -85,9 +89,17 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { "WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)"); } - // One workgroup per output row (M); validate dispatch before any alloc. - const uint32_t workgroup_count = - utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw"); + // Register-tiled GEMM: one thread per TM x TN tile; validate before alloc. + const uint32_t wg_size = + utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); + const int64_t total_tiles = utils::div_up(M, kQ4gswTileM) * + utils::div_up(N, kQ4gswTileN); + if (total_tiles > static_cast(UINT32_MAX)) { + throw std::runtime_error( + "WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit"); + } + const uint32_t workgroup_count = utils::compute_1d_workgroup_count( + device, static_cast(total_tiles), wg_size, "linear_q4gsw"); // fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail. const uint64_t scales_numel = @@ -186,8 +198,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector& args) { WGPUPipelineLayout pipeline_layout = wgpuDeviceCreatePipelineLayout(device, &pl_desc); - const uint32_t wg_size = - utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX); WGPUConstantEntry wg_size_constant = {}; wg_size_constant.key = {"wg_size", WGPU_STRLEN}; wg_size_constant.value = static_cast(wg_size); diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl index d0d6e155987..8cea61d331c 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl @@ -18,47 +18,74 @@ struct Params { override wg_size: u32 = 64u; -// One workgroup per row m, threads stride N; loop logical K only (in-bounds). +// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows. +const TM: u32 = 4u; +const TN: u32 = 4u; +const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN + @compute @workgroup_size(wg_size, 1, 1) -fn main( - @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let m = wid.x; - if (m >= params.M) { +fn main(@builtin(global_invocation_id) gid: vec3) { + let nrt = (params.M + TM - 1u) / TM; + let nct = (params.N + TN - 1u) / TN; + let tiles = nrt * nct; + // M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u + // clamps below never underflow (the host also rejects M==0/N==0). + if (gid.x >= tiles) { return; } - let in_base = m * params.K; + let row_tile = gid.x / nct; + let col_tile = gid.x % nct; + let m0 = row_tile * TM; + let n0 = col_tile * TN; + + var acc: array; + for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) { + acc[i] = 0.0; + } - var n: u32 = lid.x; + var k: u32 = 0u; loop { - if (n >= params.N) { + if (k >= params.K) { break; } - var acc: f32 = 0.0; - var k: u32 = 0u; - loop { - if (k >= params.K) { - break; - } - // Packed weight byte for (n, k): row stride K_packed bytes, byte k/2. - let byte_idx = n * params.K_packed + (k >> 1u); + // Load the TM input values for column k once; reused across all TN columns. + var in_reg: array; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m_eff = min(m0 + ml, params.M - 1u); + in_reg[ml] = t_input[m_eff * params.K + k]; + } + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + // Clamp to last valid column; overhang result is never stored. + let n_eff = min(n0 + nl, params.N - 1u); + let byte_idx = n_eff * params.K_packed + (k >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; var nib: u32; if ((k & 1u) == 0u) { - nib = b & 0x0Fu; // even k -> low nibble + nib = b & 0x0Fu; // even k -> low nibble } else { nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] - let scale = t_scales[(k / params.group_size) * params.padded_N + n]; - acc = acc + t_input[in_base + k] * q * scale; - k = k + 1u; + let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff]; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq; + } } - if (params.has_bias != 0u) { - acc = acc + t_bias[n]; + k = k + 1u; + } + + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m = m0 + ml; + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + let n = n0 + nl; + if (m < params.M && n < params.N) { + var v = acc[ml * TN + nl]; + if (params.has_bias != 0u) { + v = v + t_bias[n]; + } + t_out[m * params.N + n] = v; + } } - t_out[m * params.N + n] = acc; - n = n + wg_size; } } diff --git a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h index d176a01d27f..69494bbc947 100644 --- a/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h +++ b/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h @@ -13,7 +13,7 @@ namespace executorch::backends::webgpu { // @generated from q4gsw_linear.wgsl - DO NOT EDIT. -// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39 +// wgsl-sha256: dc6a55014ae4543bd80e5e22c3fb52896aca96e0589f700803327d8121ada489 inline constexpr const char* kQ4gswLinearWGSL = R"( @group(0) @binding(0) var t_out: array; @group(0) @binding(1) var t_input: array; @@ -35,48 +35,75 @@ struct Params { override wg_size: u32 = 64u; -// One workgroup per row m, threads stride N; loop logical K only (in-bounds). +// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows. +const TM: u32 = 4u; +const TN: u32 = 4u; +const TILE_ELEMS: u32 = TM * TN; // accumulator size; keeps acc in sync with TM/TN + @compute @workgroup_size(wg_size, 1, 1) -fn main( - @builtin(workgroup_id) wid: vec3, - @builtin(local_invocation_id) lid: vec3) { - let m = wid.x; - if (m >= params.M) { +fn main(@builtin(global_invocation_id) gid: vec3) { + let nrt = (params.M + TM - 1u) / TM; + let nct = (params.N + TN - 1u) / TN; + let tiles = nrt * nct; + // M==0 or N==0 -> tiles==0 -> every thread returns here, so the M-1u/N-1u + // clamps below never underflow (the host also rejects M==0/N==0). + if (gid.x >= tiles) { return; } - let in_base = m * params.K; + let row_tile = gid.x / nct; + let col_tile = gid.x % nct; + let m0 = row_tile * TM; + let n0 = col_tile * TN; + + var acc: array; + for (var i: u32 = 0u; i < TILE_ELEMS; i = i + 1u) { + acc[i] = 0.0; + } - var n: u32 = lid.x; + var k: u32 = 0u; loop { - if (n >= params.N) { + if (k >= params.K) { break; } - var acc: f32 = 0.0; - var k: u32 = 0u; - loop { - if (k >= params.K) { - break; - } - // Packed weight byte for (n, k): row stride K_packed bytes, byte k/2. - let byte_idx = n * params.K_packed + (k >> 1u); + // Load the TM input values for column k once; reused across all TN columns. + var in_reg: array; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m_eff = min(m0 + ml, params.M - 1u); + in_reg[ml] = t_input[m_eff * params.K + k]; + } + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + // Clamp to last valid column; overhang result is never stored. + let n_eff = min(n0 + nl, params.N - 1u); + let byte_idx = n_eff * params.K_packed + (k >> 1u); let word = t_weight[byte_idx >> 2u]; let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu; var nib: u32; if ((k & 1u) == 0u) { - nib = b & 0x0Fu; // even k -> low nibble + nib = b & 0x0Fu; // even k -> low nibble } else { nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble } let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7] - let scale = t_scales[(k / params.group_size) * params.padded_N + n]; - acc = acc + t_input[in_base + k] * q * scale; - k = k + 1u; + let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff]; + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq; + } } - if (params.has_bias != 0u) { - acc = acc + t_bias[n]; + k = k + 1u; + } + + for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) { + let m = m0 + ml; + for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) { + let n = n0 + nl; + if (m < params.M && n < params.N) { + var v = acc[ml * TN + nl]; + if (params.has_bias != 0u) { + v = v + t_bias[n]; + } + t_out[m * params.N + n] = v; + } } - t_out[m * params.N + n] = acc; - n = n + wg_size; } } )";