Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h>
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>

#include <webgpu/webgpu.h>
Expand Down Expand Up @@ -92,10 +93,17 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {

graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));

// Select the vec4 kernel when the row width is a multiple of 4 (every Llama
// hidden size qualifies); fall back to the scalar kernel otherwise. The two
// kernels are equivalent up to floating-point reassociation (the vec4
// reduction reorders the sum, so not bit-identical) and share the same bind
// group + dispatch.
const bool use_vec4 = (row_width % 4u == 0u);

// Create shader module from built-in WGSL source
WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};
wgsl_desc.code = {use_vec4 ? kRmsNormVec4WGSL : kRmsNormWGSL, WGPU_STRLEN};

WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
Expand Down Expand Up @@ -176,6 +184,9 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
static_assert(
kRmsNormWorkgroupSizeX == 64,
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
static_assert(
kRmsNormVec4WorkgroupSizeX == 64,
"must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl");
graph.add_dispatch({pipeline, bind_group, num_rows});

// Release intermediate objects (pipeline and bind_group are kept by dispatch)
Expand Down
74 changes: 74 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Storage bindings for the vec4 rms_norm kernel. Every array element is a
// vec4<f32>, i.e. four consecutive f32 values of a row.
// Output buffer: main() stores the normalized, weighted row here.
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
// Input buffer: main() reads each row twice (sum of squares, then scaling).
@group(0) @binding(1) var<storage, read> t_in: array<vec4<f32>>;
// Per-column scale: indexed by column only, so one vector serves every row.
@group(0) @binding(2) var<storage, read> t_weight: array<vec4<f32>>;

// Kernel parameters, read from the uniform buffer at binding 3.
struct Params {
// Number of rows; a workgroup whose id.x is >= this exits immediately.
num_rows: u32,
// Row length in f32 ELEMENTS (not vec4s); main() divides it by 4 for indexing.
row_width: u32,
// Added to the mean of squares before the inverse square root is taken.
epsilon: f32,
// Not read by this shader. NOTE(review): presumably keeps the struct at 16
// bytes to match the host-side RmsNormParams - confirm.
_pad: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

// Invocations per workgroup along x. Must stay equal to the x value of
// @workgroup_size on main(); it also sizes shared_sum.
const WG_SIZE: u32 = 64u;

// One partial sum of squares per invocation, combined by reduce_shared().
var<workgroup> shared_sum: array<f32, WG_SIZE>;

// Sums the WG_SIZE entries of shared_sum in place, halving the number of
// active slots each round. When it returns, shared_sum[0] holds the total.
// Every invocation of the workgroup has to call this: the barriers sit outside
// the worker_id test so that all invocations reach them.
fn reduce_shared(worker_id: u32) {
    // Make each invocation's own partial visible before any neighbour reads it.
    workgroupBarrier();
    for (var gap: u32 = WG_SIZE / 2u; gap != 0u; gap >>= 1u) {
        if (worker_id < gap) {
            shared_sum[worker_id] += shared_sum[worker_id + gap];
        }
        // Finish this round before the next, narrower one starts.
        workgroupBarrier();
    }
}

// vec4 variant of rms_norm: one workgroup normalizes one row, and each
// invocation handles every WG_SIZE-th vec4<f32> of that row. row_width counts
// f32 ELEMENTS, so indexing uses row_width / 4 while the mean of squares
// divides by row_width itself. The host selects this kernel only when
// row_width % 4 == 0.
@compute @workgroup_size(64, 1, 1)
fn main(
    @builtin(workgroup_id) wid: vec3<u32>,
    @builtin(local_invocation_id) lid: vec3<u32>) {
    let row = wid.x;
    let lane = lid.x;

    // The test uses only the workgroup id and a uniform value, so either the
    // whole workgroup returns here or none of it does.
    if (row >= params.num_rows) {
        return;
    }

    let vecs_per_row = params.row_width / 4u;
    let row_start = row * vecs_per_row;

    // Pass 1: per-invocation sum of squares over this invocation's vec4s.
    var sq_sum: f32 = 0.0;
    for (var i: u32 = lane; i < vecs_per_row; i += WG_SIZE) {
        let v = t_in[row_start + i];
        sq_sum += dot(v, v);
    }

    // Combine the per-invocation sums; the total ends up in shared_sum[0].
    shared_sum[lane] = sq_sum;
    reduce_shared(lane);

    let mean_sq = shared_sum[0] / f32(params.row_width);
    let rstd = inverseSqrt(mean_sq + params.epsilon);

    // Pass 2: scale each input vec4 by rstd and the matching weight vec4.
    for (var j: u32 = lane; j < vecs_per_row; j += WG_SIZE) {
        let idx = row_start + j;
        t_out[idx] = t_in[idx] * rstd * t_weight[j];
    }
}
98 changes: 98 additions & 0 deletions backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from rms_norm_vec4.wgsl - DO NOT EDIT.
// wgsl-sha256: 4c0ba56708bf125a7ec6ea3c51d1288e05ac00a8e2cfa10e38e9a208e230b8df
inline constexpr const char* kRmsNormVec4WGSL = R"(
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
@group(0) @binding(1) var<storage, read> t_in: array<vec4<f32>>;
@group(0) @binding(2) var<storage, read> t_weight: array<vec4<f32>>;

struct Params {
num_rows: u32,
row_width: u32,
epsilon: f32,
_pad: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

const WG_SIZE: u32 = 64u;

var<workgroup> shared_sum: array<f32, WG_SIZE>;

fn reduce_shared(worker_id: u32) {
workgroupBarrier();
var stride: u32 = WG_SIZE / 2u;
loop {
if (stride == 0u) {
break;
}
if (worker_id < stride) {
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
}
workgroupBarrier();
stride = stride >> 1u;
}
}

// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4
// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq
// divides by it (not rw4). The host selects this only when row_width % 4 == 0.
@compute @workgroup_size(64, 1, 1)
fn main(
@builtin(workgroup_id) wid: vec3<u32>,
@builtin(local_invocation_id) lid: vec3<u32>) {
let row_idx = wid.x;
let worker_id = lid.x;

if (row_idx >= params.num_rows) {
return;
}

let rw4 = params.row_width / 4u;
let base4 = row_idx * rw4;

var local_sq_sum: f32 = 0.0;
var x4: u32 = worker_id;
loop {
if (x4 >= rw4) {
break;
}
let v = t_in[base4 + x4];
local_sq_sum = local_sq_sum + dot(v, v);
x4 = x4 + WG_SIZE;
}

shared_sum[worker_id] = local_sq_sum;
reduce_shared(worker_id);

let mean_sq = shared_sum[0] / f32(params.row_width);
let rstd = inverseSqrt(mean_sq + params.epsilon);

x4 = worker_id;
loop {
if (x4 >= rw4) {
break;
}
t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4];
x4 = x4 + WG_SIZE;
}
}
)";

inline constexpr uint32_t kRmsNormVec4WorkgroupSizeX = 64;
inline constexpr uint32_t kRmsNormVec4WorkgroupSizeY = 1;
inline constexpr uint32_t kRmsNormVec4WorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
2 changes: 1 addition & 1 deletion backends/webgpu/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
RmsNormModule,
)

# rms_norm coverage is exactly the 14 cases the native test covered.
# rms_norm coverage is exactly the 15 cases the native test covered.
RMS_NORM_CASES = _CASES


Expand Down
2 changes: 1 addition & 1 deletion backends/webgpu/test/op_tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_add_rms_norm_registered():

assert {"add", "rms_norm"} <= set(op_test_registry)
assert len(op_test_registry["add"].cases) >= 3 # regular/self/scalar/chained
# Exact parity, no hardcoded literal (real _CASES == 14; import so it can't drift):
# Exact parity, no hardcoded literal (real _CASES == 15; import so it can't drift):
assert len(op_test_registry["rms_norm"].cases) == len(cases.RMS_NORM_CASES)
# weight is a construction param, NOT a forward input:
rms0 = op_test_registry["rms_norm"].cases[0]
Expand Down
1 change: 1 addition & 0 deletions backends/webgpu/test/ops/rms_norm/test_rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _weight_zeros_neg(hidden: int) -> torch.Tensor:
{"name": "distinct_rows", "shape": (1, 1, 5, 256), "input_fn": _distinct_rows},
{"name": "single_row", "shape": (1, 1, 1, 896)},
{"name": "mixed_sign", "shape": (1, 1, 4, 128), "input_fn": _mixed_sign},
{"name": "llama_hidden_2048", "shape": (1, 1, 1, 2048)},
{"name": "large_4096", "shape": (1, 1, 1, 4096)},
{"name": "large_8192", "shape": (1, 1, 1, 8192)},
{
Expand Down
Loading