Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions backends/webgpu/runtime/WebGPUUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,16 @@ make_uniform(WGPUDevice device, const void* data, size_t size) {
return buf;
}

// Caps a requested 1D workgroup count at the device's per-dimension maximum.
// Meant for grid-stride kernels that loop over any excess work themselves
// (compute_1d_workgroup_count throws instead of clamping).
//
// device:  device whose limits are queried via wgpuDeviceGetLimits.
// desired: workgroup count the caller would like to dispatch.
// Returns the smaller of `desired` and the device limit; a `desired` of 0 is
// returned unchanged.
inline uint32_t clamp_workgroup_count(WGPUDevice device, uint32_t desired) {
  // WebGPU spec-default floor; kept when the query fails or reports 0.
  uint32_t cap = 65535u;

  WGPULimits queried = {};
  if (wgpuDeviceGetLimits(device, &queried) == WGPUStatus_Success &&
      queried.maxComputeWorkgroupsPerDimension > 0) {
    cap = queried.maxComputeWorkgroupsPerDimension;
  }

  return desired < cap ? desired : cap;
}

} // namespace executorch::backends::webgpu::utils
49 changes: 34 additions & 15 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h>

#include <webgpu/webgpu.h>
Expand Down Expand Up @@ -89,18 +90,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
}

// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
utils::div_up<int64_t>(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
static_cast<uint64_t>(num_groups) * static_cast<uint64_t>(padded_N);
Expand Down Expand Up @@ -128,6 +117,35 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: scales dims too small for K/N");
}

// M==1 decode -> coop4 GEMV (needs K%8==0 && gs%8==0); else tiled GEMM.
const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const bool use_gemv = (M == 1u && K % 8u == 0u && gs % 8u == 0u);
const char* shader_src = use_gemv ? kQ4gswLinearCoop4WGSL : kQ4gswLinearWGSL;
uint32_t workgroup_count;
if (use_gemv) {
// coop4: fixed 64 lanes, 1 workgroup per output, grid-strided over M*N.
const uint64_t outputs =
static_cast<uint64_t>(M) * static_cast<uint64_t>(N);
if (outputs == 0u || outputs > UINT32_MAX) {
throw std::runtime_error("WebGPU linear_q4gsw: M*N out of range");
}
workgroup_count =
utils::clamp_workgroup_count(device, static_cast<uint32_t>(outputs));
if (workgroup_count == 0u) {
throw std::runtime_error("WebGPU linear_q4gsw: zero GEMV dispatch");
}
} else {
const int64_t total_tiles = utils::div_up<int64_t>(M, kQ4gswTileM) *
utils::div_up<int64_t>(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");
}

// Optional bias: real buffer if present, else a dummy for the fixed layout.
uint32_t has_bias = 0;
WGPUBuffer bias_buffer = nullptr;
Expand Down Expand Up @@ -168,7 +186,7 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kQ4gswLinearWGSL, WGPU_STRLEN};
wgsl_desc.code = {shader_src, WGPU_STRLEN};
WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
Expand Down Expand Up @@ -206,8 +224,9 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
// coop4 GEMV uses fixed @workgroup_size(64); only the GEMM has an override.
pipeline_desc.compute.constantCount = use_gemv ? 0u : 1u;
pipeline_desc.compute.constants = use_gemv ? nullptr : &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Output buffer: lane 0 of a workgroup writes one f32 per flat output index.
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
// f32 input, read at m * K + k.
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
// Packed weights read as u32 words; main() decodes each word into eight
// 4-bit values (two per byte, low nibble first).
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
// f32 scales, indexed as (k / group_size) * padded_N + n.
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
// f32 bias indexed by n; only read when params.has_bias != 0.
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

// Uniform parameters consumed by main(). Field notes describe how this shader
// uses each value.
struct Params {
// Output rows; main() iterates over M * N flat output indices.
M: u32,
// Output columns; a flat index is split as m = idx / N, n = idx % N.
N: u32,
// Reduction length: stride of one input row, and K >> 3 weight words are
// consumed per output.
K: u32,
// K_packed >> 2 is used as the number of u32 words per weight row
// (presumably K_packed is the packed row size in bytes -- confirm on the host).
K_packed: u32,
// Number of consecutive K values that share one scale row.
group_size: u32,
// Row stride of t_scales.
padded_N: u32,
// Non-zero => t_bias[n] is added to the result.
has_bias: u32,
// Not read by this shader; presumably pads the struct to match the host-side
// layout -- confirm against the uniform upload code.
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes).
//
// Each workgroup produces one output element at a time: its WG lanes split
// the K dimension word-by-word (8 K values per u32), accumulate private
// partial sums, then combine them through workgroup memory. Workgroups step
// through the M * N outputs by num_workgroups.x, so every output is covered
// for any dispatch of at least one workgroup.
const WG: u32 = 64u;
// One partial sum per lane, combined by the halving reduction in main().
var<workgroup> partial: array<f32, WG>;

// Entry point.
//   wid.x  - first flat output index handled by this workgroup.
//   ngrp.x - step between successive output indices of this workgroup.
//   lid.x  - lane within the workgroup (0 .. WG-1); selects which weight
//            words this invocation accumulates.
@compute @workgroup_size(WG, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) ngrp: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
// Number of output elements to produce.
let total = params.M * params.N;
let stride = ngrp.x;
let num_words = params.K >> 3u; // K / 8 words per row
let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8)
var idx = wid.x;
loop {
// idx depends only on workgroup_id, num_workgroups and uniforms, so every
// lane of a workgroup leaves this loop on the same iteration; the barriers
// below are therefore reached by all lanes together.
if (idx >= total) {
break;
}
// Split the flat output index into (row m, column n).
let m = idx / params.N;
let n = idx % params.N;
let in_base = m * params.K;
let wbase = n * row_words;

// Per-lane partial dot product over words lid.x, lid.x + WG, lid.x + 2*WG, ...
var acc: f32 = 0.0;
var w: u32 = lid.x;
loop {
if (w >= num_words) {
break;
}
let word = t_weight[wbase + w];
let k0 = w << 3u; // first K of this word
// The scale is looked up once per word from k0, i.e. all 8 values of a word
// are assumed to fall in the same quantization group.
let scale = t_scales[(k0 / params.group_size) * params.padded_N + n];
let ib = in_base + k0;
// 4 bytes, low+high nibble each -> 8 consecutive K.
for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) {
let byte = (word >> (bi * 8u)) & 0xFFu;
// Each 4-bit value is shifted by -8, mapping 0..15 to -8..7.
let lo = f32(i32(byte & 0x0Fu) - 8);
let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8);
// Byte bi covers K offsets 2*bi (low nibble) and 2*bi + 1 (high nibble).
let kk = bi << 1u;
acc = acc + t_input[ib + kk] * lo * scale;
acc = acc + t_input[ib + kk + 1u] * hi * scale;
}
w = w + WG;
}

// Publish this lane's partial sum, then reduce with halving strides
// (s = WG/2, WG/4, ..., 1); afterwards partial[0] holds the sum of all lanes.
partial[lid.x] = acc;
workgroupBarrier();
var s: u32 = WG >> 1u;
loop {
if (s == 0u) {
break;
}
if (lid.x < s) {
partial[lid.x] = partial[lid.x] + partial[lid.x + s];
}
// Make this step's writes visible before the next, smaller stride reads them.
workgroupBarrier();
s = s >> 1u;
}
// Lane 0 adds the optional bias and stores the finished output element.
if (lid.x == 0u) {
var o = partial[0];
if (params.has_bias != 0u) {
o = o + t_bias[n];
}
t_out[idx] = o;
}
// Synchronize before the next iteration overwrites partial[].
workgroupBarrier();
idx = idx + stride;
}
}
111 changes: 111 additions & 0 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_coop4_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from q4gsw_linear_coop4.wgsl - DO NOT EDIT.
// wgsl-sha256: 3031886e68c375e617dfb263da39c492c6de4d8c1fb4073d70b18823a3e6a4fe
inline constexpr const char* kQ4gswLinearCoop4WGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
@group(0) @binding(2) var<storage, read> t_weight: array<u32>;
@group(0) @binding(3) var<storage, read> t_scales: array<f32>;
@group(0) @binding(4) var<storage, read> t_bias: array<f32>;

struct Params {
M: u32,
N: u32,
K: u32,
K_packed: u32,
group_size: u32,
padded_N: u32,
has_bias: u32,
_pad: u32,
}
@group(0) @binding(5) var<uniform> params: Params;

// Cooperative-over-K GEMV with u32-batched coalesced weight loads (64 lanes).
const WG: u32 = 64u;
var<workgroup> partial: array<f32, WG>;

@compute @workgroup_size(WG, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(num_workgroups) ngrp: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let total = params.M * params.N;
let stride = ngrp.x;
let num_words = params.K >> 3u; // K / 8 words per row
let row_words = params.K_packed >> 2u; // u32s per weight row (= K/8)
var idx = wid.x;
loop {
if (idx >= total) {
break;
}
let m = idx / params.N;
let n = idx % params.N;
let in_base = m * params.K;
let wbase = n * row_words;

var acc: f32 = 0.0;
var w: u32 = lid.x;
loop {
if (w >= num_words) {
break;
}
let word = t_weight[wbase + w];
let k0 = w << 3u; // first K of this word
let scale = t_scales[(k0 / params.group_size) * params.padded_N + n];
let ib = in_base + k0;
// 4 bytes, low+high nibble each -> 8 consecutive K.
for (var bi: u32 = 0u; bi < 4u; bi = bi + 1u) {
let byte = (word >> (bi * 8u)) & 0xFFu;
let lo = f32(i32(byte & 0x0Fu) - 8);
let hi = f32(i32((byte >> 4u) & 0x0Fu) - 8);
let kk = bi << 1u;
acc = acc + t_input[ib + kk] * lo * scale;
acc = acc + t_input[ib + kk + 1u] * hi * scale;
}
w = w + WG;
}

partial[lid.x] = acc;
workgroupBarrier();
var s: u32 = WG >> 1u;
loop {
if (s == 0u) {
break;
}
if (lid.x < s) {
partial[lid.x] = partial[lid.x] + partial[lid.x + s];
}
workgroupBarrier();
s = s >> 1u;
}
if (lid.x == 0u) {
var o = partial[0];
if (params.has_bias != 0u) {
o = o + t_bias[n];
}
t_out[idx] = o;
}
workgroupBarrier();
idx = idx + stride;
}
}
)";

inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeX = 64;
inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeY = 1;
inline constexpr uint32_t kQ4gswLinearCoop4WorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
Loading