Skip to content

Commit c036150

Browse files
[ExecuTorch][WebGPU] Add rms_norm op
Pull Request resolved: pytorch#19963 Adds the `et_vk.rms_norm.default` operator to the WebGPU backend: a WGSL compute shader using a cooperative tree reduction, one workgroup per row. The shader mirrors the Vulkan implementation (`backends/vulkan/runtime/graph/ops/impl/RmsNorm.cpp`, `backends/vulkan/runtime/graph/ops/glsl/rms_norm_buffer.glsl`); indexing assumes contiguous fp32 inputs. The handler fails loud (throws, mirroring Vulkan's `VK_CHECK_COND`) on invalid shape/dtype/dispatch-limit conditions, and defaults `eps` to the float32 machine epsilon. The weight constant is uploaded via the named-data path added in the parent diff. ghstack-source-id: 389206169 @exported-using-ghexport Differential Revision: [D106887028](https://our.internmc.facebook.com/intern/diff/D106887028/)
1 parent 5dd66ad commit c036150

9 files changed

Lines changed: 796 additions & 7 deletions

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ if(NOT TARGET vulkan_schema)
2626
endif()
2727

2828
set(WEBGPU_SRCS
29-
runtime/WebGPUBackend.cpp runtime/WebGPUGraph.cpp
30-
runtime/WebGPUDelegateHeader.cpp runtime/WebGPUDevice.cpp
31-
runtime/ops/OperatorRegistry.cpp runtime/ops/add/BinaryOp.cpp
29+
runtime/WebGPUBackend.cpp
30+
runtime/WebGPUGraph.cpp
31+
runtime/WebGPUDelegateHeader.cpp
32+
runtime/WebGPUDevice.cpp
33+
runtime/ops/OperatorRegistry.cpp
34+
runtime/ops/add/BinaryOp.cpp
35+
runtime/ops/rms_norm/RmsNorm.cpp
3236
)
3337

3438
add_library(webgpu_backend ${WEBGPU_SRCS})
@@ -116,4 +120,35 @@ if(EXECUTORCH_BUILD_WEBGPU_TEST)
116120

117121
target_compile_options(webgpu_native_test PRIVATE -fexceptions)
118122
set_property(TARGET webgpu_native_test PROPERTY CXX_STANDARD 17)
123+
124+
add_executable(webgpu_rms_norm_test test/native/test_rms_norm.cpp)
125+
126+
target_include_directories(
127+
webgpu_rms_norm_test PRIVATE $<BUILD_INTERFACE:${EXECUTORCH_ROOT}/..>
128+
"${WGPU_NATIVE_DIR}/include"
129+
)
130+
131+
target_link_libraries(
132+
webgpu_rms_norm_test
133+
PRIVATE webgpu_backend
134+
wgpu_native
135+
executorch_core
136+
extension_module_static
137+
extension_data_loader
138+
extension_tensor
139+
portable_kernels
140+
portable_ops_lib
141+
)
142+
143+
if(APPLE)
144+
target_link_libraries(
145+
webgpu_rms_norm_test PRIVATE "-framework Metal" "-framework QuartzCore"
146+
"-framework CoreGraphics"
147+
)
148+
else()
149+
target_link_libraries(webgpu_rms_norm_test PRIVATE dl m pthread)
150+
endif()
151+
152+
target_compile_options(webgpu_rms_norm_test PRIVATE -fexceptions)
153+
set_property(TARGET webgpu_rms_norm_test PROPERTY CXX_STANDARD 17)
119154
endif()
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
11+
#include <executorch/backends/webgpu/runtime/ops/rms_norm/rms_norm_wgsl.h>
12+
13+
#include <webgpu/webgpu.h>
14+
15+
#include <cstdint>
16+
#include <cstring>
17+
#include <limits>
18+
#include <stdexcept>
19+
20+
namespace executorch::backends::webgpu {
21+
22+
namespace {
23+
24+
// Uniform layout matching the WGSL Params struct (16-byte aligned).
25+
struct RmsNormParams {
26+
uint32_t num_rows;
27+
uint32_t row_width;
28+
float epsilon;
29+
uint32_t _pad;
30+
};
31+
static_assert(sizeof(RmsNormParams) == 16, "RmsNormParams must be 16 bytes");
32+
33+
void rms_norm_impl(WebGPUGraph& graph, const std::vector<int>& args) {
34+
// et_vk.rms_norm.default args: [in, weight, eps, out]
35+
const int in_id = args.at(0);
36+
const int weight_id = args.at(1);
37+
const int eps_id = args.at(2);
38+
const int out_id = args.at(3);
39+
40+
WGPUDevice device = graph.device();
41+
42+
// Get epsilon (Double from a Python float; defaults to float32 eps)
43+
float epsilon = std::numeric_limits<float>::epsilon();
44+
if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Double) {
45+
epsilon = static_cast<float>(graph.get_double(eps_id));
46+
} else if (graph.get_value_type(eps_id) == WebGPUGraph::ValueType::Int) {
47+
epsilon = static_cast<float>(graph.get_int(eps_id));
48+
}
49+
50+
// row_width = last dim; num_rows = product of the rest (PyTorch NCHW order)
51+
const auto& in_tensor = graph.get_tensor(in_id);
52+
if (in_tensor.dims.empty() || in_tensor.nbytes == 0) {
53+
throw std::runtime_error("WebGPU rms_norm: empty input");
54+
}
55+
const uint32_t row_width = static_cast<uint32_t>(in_tensor.dims.back());
56+
if (row_width == 0) {
57+
throw std::runtime_error("WebGPU rms_norm: zero row width");
58+
}
59+
uint64_t in_numel = 1;
60+
for (int64_t d : in_tensor.dims) {
61+
in_numel *= static_cast<uint64_t>(d);
62+
}
63+
// fp32-only shader: bail if the bytes don't match an fp32 element count.
64+
if (in_tensor.nbytes != in_numel * sizeof(float)) {
65+
throw std::runtime_error("WebGPU rms_norm: fp32-only (byte-size mismatch)");
66+
}
67+
const uint32_t num_rows = static_cast<uint32_t>(in_numel / row_width);
68+
if (num_rows == 0) {
69+
throw std::runtime_error("WebGPU rms_norm: zero rows");
70+
}
71+
// Validate the 1D dispatch limit before allocating any GPU objects.
72+
if (num_rows > 65535u) {
73+
throw std::runtime_error(
74+
"WebGPU rms_norm: num_rows exceeds the 1D dispatch limit (65535)");
75+
}
76+
77+
// Create uniform buffer for params
78+
RmsNormParams params = {};
79+
params.num_rows = num_rows;
80+
params.row_width = row_width;
81+
params.epsilon = epsilon;
82+
83+
WGPUBufferDescriptor uniform_desc = {};
84+
uniform_desc.size = sizeof(RmsNormParams);
85+
uniform_desc.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
86+
uniform_desc.mappedAtCreation = true;
87+
WGPUBuffer uniform_buffer = wgpuDeviceCreateBuffer(device, &uniform_desc);
88+
void* mapped =
89+
wgpuBufferGetMappedRange(uniform_buffer, 0, sizeof(RmsNormParams));
90+
std::memcpy(mapped, &params, sizeof(RmsNormParams));
91+
wgpuBufferUnmap(uniform_buffer);
92+
93+
graph.add_uniform_buffer_bytes(sizeof(RmsNormParams));
94+
95+
// Create shader module from built-in WGSL source
96+
WGPUShaderSourceWGSL wgsl_desc = {};
97+
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
98+
wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN};
99+
100+
WGPUShaderModuleDescriptor shader_desc = {};
101+
shader_desc.nextInChain = &wgsl_desc.chain;
102+
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
103+
104+
// Create bind group layout: out (rw) + in/weight (ro storage) + params
105+
WGPUBindGroupLayoutEntry entries[4] = {};
106+
107+
// t_out - storage buffer, read-write
108+
entries[0].binding = 0;
109+
entries[0].visibility = WGPUShaderStage_Compute;
110+
entries[0].buffer.type = WGPUBufferBindingType_Storage;
111+
112+
// t_in - storage buffer, read-only
113+
entries[1].binding = 1;
114+
entries[1].visibility = WGPUShaderStage_Compute;
115+
entries[1].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
116+
117+
// t_weight - storage buffer, read-only
118+
entries[2].binding = 2;
119+
entries[2].visibility = WGPUShaderStage_Compute;
120+
entries[2].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
121+
122+
// params - uniform buffer
123+
entries[3].binding = 3;
124+
entries[3].visibility = WGPUShaderStage_Compute;
125+
entries[3].buffer.type = WGPUBufferBindingType_Uniform;
126+
127+
WGPUBindGroupLayoutDescriptor bgl_desc = {};
128+
bgl_desc.entryCount = 4;
129+
bgl_desc.entries = entries;
130+
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
131+
132+
// Create pipeline layout
133+
WGPUPipelineLayoutDescriptor pl_desc = {};
134+
pl_desc.bindGroupLayoutCount = 1;
135+
pl_desc.bindGroupLayouts = &bgl;
136+
WGPUPipelineLayout pipeline_layout =
137+
wgpuDeviceCreatePipelineLayout(device, &pl_desc);
138+
139+
// Create compute pipeline
140+
WGPUComputePipelineDescriptor pipeline_desc = {};
141+
pipeline_desc.layout = pipeline_layout;
142+
pipeline_desc.compute.module = shader;
143+
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
144+
WGPUComputePipeline pipeline =
145+
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
146+
147+
// Create bind group with actual buffers
148+
const auto& out_tensor = graph.get_tensor(out_id);
149+
const auto& weight_tensor = graph.get_tensor(weight_id);
150+
151+
WGPUBindGroupEntry bg_entries[4] = {};
152+
153+
bg_entries[0].binding = 0;
154+
bg_entries[0].buffer = out_tensor.buffer;
155+
bg_entries[0].size = out_tensor.nbytes;
156+
157+
bg_entries[1].binding = 1;
158+
bg_entries[1].buffer = in_tensor.buffer;
159+
bg_entries[1].size = in_tensor.nbytes;
160+
161+
bg_entries[2].binding = 2;
162+
bg_entries[2].buffer = weight_tensor.buffer;
163+
bg_entries[2].size = weight_tensor.nbytes;
164+
165+
bg_entries[3].binding = 3;
166+
bg_entries[3].buffer = uniform_buffer;
167+
bg_entries[3].size = sizeof(RmsNormParams);
168+
169+
WGPUBindGroupDescriptor bg_desc = {};
170+
bg_desc.layout = bgl;
171+
bg_desc.entryCount = 4;
172+
bg_desc.entries = bg_entries;
173+
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
174+
175+
// One workgroup per row (kRmsNormWorkgroupSize threads cooperate per row)
176+
static_assert(
177+
kRmsNormWorkgroupSize == 64,
178+
"must match @workgroup_size and WG_SIZE in rms_norm.wgsl");
179+
graph.add_dispatch({pipeline, bind_group, num_rows});
180+
181+
// Release intermediate objects (pipeline and bind_group are kept by dispatch)
182+
wgpuShaderModuleRelease(shader);
183+
wgpuBindGroupLayoutRelease(bgl);
184+
wgpuPipelineLayoutRelease(pipeline_layout);
185+
// Drop our ref; the bind group keeps the uniform buffer alive until release.
186+
wgpuBufferRelease(uniform_buffer);
187+
}
188+
189+
} // namespace
190+
191+
WEBGPU_REGISTER_OPERATORS {
192+
WEBGPU_REGISTER_OP(et_vk.rms_norm.default, rms_norm_impl);
193+
}
194+
195+
} // namespace executorch::backends::webgpu
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// NOTE: This file is for editor/tooling support only. The runtime consumes the
2+
// inline copy of this shader in `rms_norm_wgsl.h` (kRmsNormWGSL). Keep the two
3+
// in sync by hand — any edit here must be mirrored there.
4+
@group(0) @binding(0) var<storage, read_write> t_out: array<f32>;
5+
@group(0) @binding(1) var<storage, read> t_in: array<f32>;
6+
@group(0) @binding(2) var<storage, read> t_weight: array<f32>;
7+
8+
struct Params {
9+
num_rows: u32,
10+
row_width: u32,
11+
epsilon: f32,
12+
_pad: u32,
13+
}
14+
@group(0) @binding(3) var<uniform> params: Params;
15+
16+
const WG_SIZE: u32 = 64u;
17+
18+
var<workgroup> shared_sum: array<f32, WG_SIZE>;
19+
20+
fn reduce_shared(worker_id: u32) {
21+
workgroupBarrier();
22+
var stride: u32 = WG_SIZE / 2u;
23+
loop {
24+
if (stride == 0u) {
25+
break;
26+
}
27+
if (worker_id < stride) {
28+
shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride];
29+
}
30+
workgroupBarrier();
31+
stride = stride >> 1u;
32+
}
33+
}
34+
35+
@compute @workgroup_size(64, 1, 1)
36+
fn main(
37+
@builtin(workgroup_id) wid: vec3<u32>,
38+
@builtin(local_invocation_id) lid: vec3<u32>) {
39+
let row_idx = wid.x;
40+
let worker_id = lid.x;
41+
42+
if (row_idx >= params.num_rows) {
43+
return;
44+
}
45+
46+
let base = row_idx * params.row_width;
47+
48+
var local_sq_sum: f32 = 0.0;
49+
var x: u32 = worker_id;
50+
loop {
51+
if (x >= params.row_width) {
52+
break;
53+
}
54+
let v = t_in[base + x];
55+
local_sq_sum = local_sq_sum + v * v;
56+
x = x + WG_SIZE;
57+
}
58+
59+
shared_sum[worker_id] = local_sq_sum;
60+
reduce_shared(worker_id);
61+
62+
let mean_sq = shared_sum[0] / f32(params.row_width);
63+
let rstd = inverseSqrt(mean_sq + params.epsilon);
64+
65+
x = worker_id;
66+
loop {
67+
if (x >= params.row_width) {
68+
break;
69+
}
70+
let v = t_in[base + x];
71+
let w = t_weight[x];
72+
t_out[base + x] = v * rstd * w;
73+
x = x + WG_SIZE;
74+
}
75+
}

0 commit comments

Comments
 (0)