Skip to content

Commit dddb501

Browse files
[ExecuTorch][WebGPU] rms_norm: add a vec4 kernel for 4-aligned row widths
Pull Request resolved: #20458

Add optimized vec4 kernel for bandwidth-bound rms_norm on Llama decode.

**Problem**: Scalar kernel loads one element per lane per iteration — bandwidth-limited on Llama decode.

**Solution**: Add vec4 kernel that loads/stores four contiguous elements as `vec4<f32>` and accumulates their sum of squares with `dot(v, v)`, cutting loop iterations 4× and widening memory transactions.

**Routing Logic**:
- Use vec4 when: row_width % 4 == 0
- Otherwise: Fall back to scalar kernel

**Constraints**:
- row_width % 4 == 0: vec4 kernel has no partial-texel tail handling
- Llama models (all hidden sizes 4-aligned) satisfy constraint

**Implementation**:
- New kernel: rms_norm_vec4.wgsl (same 64-lane workgroup)
- Shared infrastructure: Same bind layout, Params, dispatch
- Numerical: Float reassociation differs, not bit-identical to scalar

**Performance**: ~33% faster on Apple M4 Pro / Metal across benchmark shapes (largest on decode, smallest on long prefill where already bandwidth-bound).

This change was authored with assistance from Claude.

ghstack-source-id: 396677654
@exported-using-ghexport

Differential Revision: [D109333390](https://our.internmc.facebook.com/intern/diff/D109333390/)
1 parent 14c5c3c commit dddb501

6 files changed

Lines changed: 187 additions & 3 deletions

File tree

backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
1010
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
11+
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h>
1112
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>
1213

1314
#include <webgpu/webgpu.h>
@@ -92,10 +93,17 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
9293

9394
graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));
9495

96+
// Select the vec4 kernel when the row width is a multiple of 4 (every Llama
97+
// hidden size qualifies); fall back to the scalar kernel otherwise. The two
98+
// kernels are equivalent up to floating-point reassociation (the vec4
99+
// reduction reorders the sum, so not bit-identical) and share the same bind
100+
// group + dispatch.
101+
const bool use_vec4 = (row_width % 4u == 0u);
102+
95103
// Create shader module from built-in WGSL source
96104
WGPUShaderSourceWGSL wgsl_desc = {};
97105
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
98-
wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};
106+
wgsl_desc.code = {use_vec4 ? kRmsNormVec4WGSL : kRmsNormWGSL, WGPU_STRLEN};
99107

100108
WGPUShaderModuleDescriptor shader_desc = {};
101109
shader_desc.nextInChain = &wgsl_desc.chain;
@@ -176,6 +184,9 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
176184
static_assert(
177185
kRmsNormWorkgroupSizeX == 64,
178186
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
187+
static_assert(
188+
kRmsNormVec4WorkgroupSizeX == 64,
189+
"must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl");
179190
graph.add_dispatch({pipeline, bind_group, num_rows});
180191

181192
// Release intermediate objects (pipeline and bind_group are kept by dispatch)
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
2+
@group(0) @binding(1) var<storage, read> t_in: array<vec4<f32>>;
3+
@group(0) @binding(2) var<storage, read> t_weight: array<vec4<f32>>;
4+
5+
struct Params {
6+
num_rows: u32,
7+
row_width: u32,
8+
epsilon: f32,
9+
_pad: u32,
10+
}
11+
@group(0) @binding(3) var<uniform> params: Params;
12+
13+
const WG_SIZE: u32 = 64u;
14+
15+
var<workgroup> shared_sum: array<f32, WG_SIZE>;
16+
17+
fn reduce_shared(worker_id: u32) {
18+
workgroupBarrier();
19+
var stride: u32 = WG_SIZE / 2u;
20+
loop {
21+
if (stride == 0u) {
22+
break;
23+
}
24+
if (worker_id < stride) {
25+
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
26+
}
27+
workgroupBarrier();
28+
stride = stride >> 1u;
29+
}
30+
}
31+
32+
// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4
33+
// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq
34+
// divides by it (not rw4). The host selects this only when row_width % 4 == 0.
35+
@compute @workgroup_size(64, 1, 1)
36+
fn main(
37+
@builtin(workgroup_id) wid: vec3<u32>,
38+
@builtin(local_invocation_id) lid: vec3<u32>) {
39+
let row_idx = wid.x;
40+
let worker_id = lid.x;
41+
42+
if (row_idx >= params.num_rows) {
43+
return;
44+
}
45+
46+
let rw4 = params.row_width / 4u;
47+
let base4 = row_idx * rw4;
48+
49+
var local_sq_sum: f32 = 0.0;
50+
var x4: u32 = worker_id;
51+
loop {
52+
if (x4 >= rw4) {
53+
break;
54+
}
55+
let v = t_in[base4 + x4];
56+
local_sq_sum = local_sq_sum + dot(v, v);
57+
x4 = x4 + WG_SIZE;
58+
}
59+
60+
shared_sum[worker_id] = local_sq_sum;
61+
reduce_shared(worker_id);
62+
63+
let mean_sq = shared_sum[0] / f32(params.row_width);
64+
let rstd = inverseSqrt(mean_sq + params.epsilon);
65+
66+
x4 = worker_id;
67+
loop {
68+
if (x4 >= rw4) {
69+
break;
70+
}
71+
t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4];
72+
x4 = x4 + WG_SIZE;
73+
}
74+
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <cstdint>
12+
13+
namespace executorch::backends::webgpu {
14+
15+
// @generated from rms_norm_vec4.wgsl - DO NOT EDIT.
16+
// wgsl-sha256: 4c0ba56708bf125a7ec6ea3c51d1288e05ac00a8e2cfa10e38e9a208e230b8df
17+
inline constexpr const char* kRmsNormVec4WGSL = R"(
18+
@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>;
19+
@group(0) @binding(1) var<storage, read> t_in: array<vec4<f32>>;
20+
@group(0) @binding(2) var<storage, read> t_weight: array<vec4<f32>>;
21+
22+
struct Params {
23+
num_rows: u32,
24+
row_width: u32,
25+
epsilon: f32,
26+
_pad: u32,
27+
}
28+
@group(0) @binding(3) var<uniform> params: Params;
29+
30+
const WG_SIZE: u32 = 64u;
31+
32+
var<workgroup> shared_sum: array<f32, WG_SIZE>;
33+
34+
fn reduce_shared(worker_id: u32) {
35+
workgroupBarrier();
36+
var stride: u32 = WG_SIZE / 2u;
37+
loop {
38+
if (stride == 0u) {
39+
break;
40+
}
41+
if (worker_id < stride) {
42+
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
43+
}
44+
workgroupBarrier();
45+
stride = stride >> 1u;
46+
}
47+
}
48+
49+
// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4
50+
// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq
51+
// divides by it (not rw4). The host selects this only when row_width % 4 == 0.
52+
@compute @workgroup_size(64, 1, 1)
53+
fn main(
54+
@builtin(workgroup_id) wid: vec3<u32>,
55+
@builtin(local_invocation_id) lid: vec3<u32>) {
56+
let row_idx = wid.x;
57+
let worker_id = lid.x;
58+
59+
if (row_idx >= params.num_rows) {
60+
return;
61+
}
62+
63+
let rw4 = params.row_width / 4u;
64+
let base4 = row_idx * rw4;
65+
66+
var local_sq_sum: f32 = 0.0;
67+
var x4: u32 = worker_id;
68+
loop {
69+
if (x4 >= rw4) {
70+
break;
71+
}
72+
let v = t_in[base4 + x4];
73+
local_sq_sum = local_sq_sum + dot(v, v);
74+
x4 = x4 + WG_SIZE;
75+
}
76+
77+
shared_sum[worker_id] = local_sq_sum;
78+
reduce_shared(worker_id);
79+
80+
let mean_sq = shared_sum[0] / f32(params.row_width);
81+
let rstd = inverseSqrt(mean_sq + params.epsilon);
82+
83+
x4 = worker_id;
84+
loop {
85+
if (x4 >= rw4) {
86+
break;
87+
}
88+
t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4];
89+
x4 = x4 + WG_SIZE;
90+
}
91+
}
92+
)";
93+
94+
inline constexpr uint32_t kRmsNormVec4WorkgroupSizeX = 64;
95+
inline constexpr uint32_t kRmsNormVec4WorkgroupSizeY = 1;
96+
inline constexpr uint32_t kRmsNormVec4WorkgroupSizeZ = 1;
97+
98+
} // namespace executorch::backends::webgpu

backends/webgpu/test/op_tests/cases.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
RmsNormModule,
3838
)
3939

40-
# rms_norm coverage is exactly the 14 cases the native test covered.
40+
# rms_norm coverage is exactly the 15 cases the native test covered.
4141
RMS_NORM_CASES = _CASES
4242

4343

backends/webgpu/test/op_tests/test_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_add_rms_norm_registered():
4141

4242
assert {"add", "rms_norm"} <= set(op_test_registry)
4343
assert len(op_test_registry["add"].cases) >= 3 # regular/self/scalar/chained
44-
# Exact parity, no hardcoded literal (real _CASES == 14; import so it can't drift):
44+
# Exact parity, no hardcoded literal (real _CASES == 15; import so it can't drift):
4545
assert len(op_test_registry["rms_norm"].cases) == len(cases.RMS_NORM_CASES)
4646
# weight is a construction param, NOT a forward input:
4747
rms0 = op_test_registry["rms_norm"].cases[0]

backends/webgpu/test/ops/rms_norm/test_rms_norm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def _weight_zeros_neg(hidden: int) -> torch.Tensor:
140140
{"name": "distinct_rows", "shape": (1, 1, 5, 256), "input_fn": _distinct_rows},
141141
{"name": "single_row", "shape": (1, 1, 1, 896)},
142142
{"name": "mixed_sign", "shape": (1, 1, 4, 128), "input_fn": _mixed_sign},
143+
{"name": "llama_hidden_2048", "shape": (1, 1, 1, 2048)},
143144
{"name": "large_4096", "shape": (1, 1, 1, 4096)},
144145
{"name": "large_8192", "shape": (1, 1, 1, 8192)},
145146
{

0 commit comments

Comments
 (0)