Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions backends/webgpu/runtime/ops/quantized_linear/QuantizedLinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ struct Q4gswParams {
};
static_assert(sizeof(Q4gswParams) == 32, "Q4gswParams must be 32 bytes");

// Register-tile dimensions: each shader invocation computes one
// kQ4gswTileM x kQ4gswTileN tile of the output. These MUST match TM/TN in
// q4gsw_linear.wgsl, because the host-side tile count used to size the
// dispatch is derived from them.
constexpr int64_t kQ4gswTileM = 4; // output rows per tile (TM in the shader)
constexpr int64_t kQ4gswTileN = 4; // output columns per tile (TN in the shader)
// ceil(a / b) for int64 with b > 0 (WebGPUUtils has no ceil-div helper).
// Computed as the truncated quotient plus a remainder correction instead of
// (a + b - 1) / b, so no intermediate sum can overflow int64_t when `a` is
// close to INT64_MAX. Results are unchanged for every non-negative `a` the
// previous form handled without overflow.
// Precondition: b > 0 (division by zero is undefined).
inline int64_t q4gsw_ceil_div(int64_t a, int64_t b) {
  // Integer division truncates toward zero, so a positive remainder means the
  // quotient was rounded down and must be bumped by one.
  return a / b + (a % b > 0 ? 1 : 0);
}

// et_vk.linear_q4gsw args: [in, weight, scales, group_size, bias, out].
void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
const int in_id = args.at(0);
Expand Down Expand Up @@ -85,9 +93,17 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
"WebGPU linear_q4gsw: N*K_packed must be a multiple of 4 (u32-packed)");
}

// One workgroup per output row (M); validate dispatch before any alloc.
const uint32_t workgroup_count =
utils::compute_1d_workgroup_count(device, M, 1, "linear_q4gsw");
// Register-tiled GEMM: one thread per TM x TN tile; validate before alloc.
const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
const int64_t total_tiles =
q4gsw_ceil_div(M, kQ4gswTileM) * q4gsw_ceil_div(N, kQ4gswTileN);
if (total_tiles > static_cast<int64_t>(UINT32_MAX)) {
throw std::runtime_error(
"WebGPU linear_q4gsw: tile count exceeds the 1D dispatch limit");
}
const uint32_t workgroup_count = utils::compute_1d_workgroup_count(
device, static_cast<uint32_t>(total_tiles), wg_size, "linear_q4gsw");

// fp32-only byte-size guards (no runtime dtype); fp16 scales -> bail.
const uint64_t scales_numel =
Expand Down Expand Up @@ -186,8 +202,6 @@ void q4gsw_linear_impl(WebGPUGraph& graph, const std::vector<int>& args) {
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

const uint32_t wg_size =
utils::clamp_workgroup_size(device, kQ4gswLinearWorkgroupSizeX);
WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);
Expand Down
74 changes: 49 additions & 25 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,47 +18,71 @@ struct Params {

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
const TM: u32 = 4u;
const TN: u32 = 4u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let nrt = (params.M + TM - 1u) / TM;
let nct = (params.N + TN - 1u) / TN;
let tiles = nrt * nct;
if (gid.x >= tiles) {
return;
}
let in_base = m * params.K;
let row_tile = gid.x / nct;
let col_tile = gid.x % nct;
let m0 = row_tile * TM;
let n0 = col_tile * TN;

var acc: array<f32, 16>; // TM * TN
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
acc[i] = 0.0;
}

var n: u32 = lid.x;
var k: u32 = 0u;
loop {
if (n >= params.N) {
if (k >= params.K) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
// Load the TM input values for column k once; reused across all TN columns.
var in_reg: array<f32, TM>;
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m_eff = min(m0 + ml, params.M - 1u);
in_reg[ml] = t_input[m_eff * params.K + k];
}
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
// Clamp to last valid column; overhang result is never stored.
let n_eff = min(n0 + nl, params.N - 1u);
let byte_idx = n_eff * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
}
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
k = k + 1u;
}

for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m = m0 + ml;
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
let n = n0 + nl;
if (m < params.M && n < params.N) {
var v = acc[ml * TN + nl];
if (params.has_bias != 0u) {
v = v + t_bias[n];
}
t_out[m * params.N + n] = v;
}
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
76 changes: 50 additions & 26 deletions backends/webgpu/runtime/ops/quantized_linear/q4gsw_linear_wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace executorch::backends::webgpu {

// @generated from q4gsw_linear.wgsl - DO NOT EDIT.
// wgsl-sha256: 966cec5d4102eb7c8f6504d2a335a1bd2f235424933fe83b4d0f8f274d894f39
// wgsl-sha256: f8563b572d4f0ef9662f93f6e42f71e6126dddad0ffe6330cb5dc4d4259bb831
inline constexpr const char* kQ4gswLinearWGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
@group(0) @binding(1) var<storage, read> t_input: array<f32>;
Expand All @@ -35,48 +35,72 @@ struct Params {

override wg_size: u32 = 64u;

// One workgroup per row m, threads stride N; loop logical K only (in-bounds).
// Register-tiled GEMM: dequant weight once per (n,k), reused across TM rows.
const TM: u32 = 4u;
const TN: u32 = 4u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let m = wid.x;
if (m >= params.M) {
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let nrt = (params.M + TM - 1u) / TM;
let nct = (params.N + TN - 1u) / TN;
let tiles = nrt * nct;
if (gid.x >= tiles) {
return;
}
let in_base = m * params.K;
let row_tile = gid.x / nct;
let col_tile = gid.x % nct;
let m0 = row_tile * TM;
let n0 = col_tile * TN;

var acc: array<f32, 16>; // TM * TN
for (var i: u32 = 0u; i < TM * TN; i = i + 1u) {
acc[i] = 0.0;
}

var n: u32 = lid.x;
var k: u32 = 0u;
loop {
if (n >= params.N) {
if (k >= params.K) {
break;
}
var acc: f32 = 0.0;
var k: u32 = 0u;
loop {
if (k >= params.K) {
break;
}
// Packed weight byte for (n, k): row stride K_packed bytes, byte k/2.
let byte_idx = n * params.K_packed + (k >> 1u);
// Load the TM input values for column k once; reused across all TN columns.
var in_reg: array<f32, TM>;
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m_eff = min(m0 + ml, params.M - 1u);
in_reg[ml] = t_input[m_eff * params.K + k];
}
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
// Clamp to last valid column; overhang result is never stored.
let n_eff = min(n0 + nl, params.N - 1u);
let byte_idx = n_eff * params.K_packed + (k >> 1u);
let word = t_weight[byte_idx >> 2u];
let b = (word >> ((byte_idx & 3u) * 8u)) & 0xFFu;
var nib: u32;
if ((k & 1u) == 0u) {
nib = b & 0x0Fu; // even k -> low nibble
nib = b & 0x0Fu; // even k -> low nibble
} else {
nib = (b >> 4u) & 0x0Fu; // odd k -> high nibble
}
let q = f32(i32(nib) - 8); // +8-shifted on pack; recover signed [-8,7]
let scale = t_scales[(k / params.group_size) * params.padded_N + n];
acc = acc + t_input[in_base + k] * q * scale;
k = k + 1u;
let dq = q * t_scales[(k / params.group_size) * params.padded_N + n_eff];
for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
acc[ml * TN + nl] = acc[ml * TN + nl] + in_reg[ml] * dq;
}
}
if (params.has_bias != 0u) {
acc = acc + t_bias[n];
k = k + 1u;
}

for (var ml: u32 = 0u; ml < TM; ml = ml + 1u) {
let m = m0 + ml;
for (var nl: u32 = 0u; nl < TN; nl = nl + 1u) {
let n = n0 + nl;
if (m < params.M && n < params.N) {
var v = acc[ml * TN + nl];
if (params.has_bias != 0u) {
v = v + t_bias[n];
}
t_out[m * params.N + n] = v;
}
}
t_out[m * params.N + n] = acc;
n = n + wg_size;
}
}
)";
Expand Down
Loading