diff --git a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp index 7de83330810..e73c6e23a88 100644 --- a/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp +++ b/backends/webgpu/runtime/ops/rms_norm/RmsNorm.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ -92,10 +93,17 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { graph.add_uniform_buffer_bytes(sizeof(RmsNormParams)); + // Select the vec4 kernel when the row width is a multiple of 4 (every Llama + // hidden size qualifies); fall back to the scalar kernel otherwise. The two + // kernels are equivalent up to floating-point reassociation (the vec4 + // reduction reorders the sum, so not bit-identical) and share the same bind + // group + dispatch. + const bool use_vec4 = (row_width % 4u == 0u); + // Create shader module from built-in WGSL source WGPUShaderSourceWGSL wgsl_desc = {}; wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL; - wgsl_desc.code = {kRmsNormWGSL, WGPU_STRLEN}; + wgsl_desc.code = {use_vec4 ? kRmsNormVec4WGSL : kRmsNormWGSL, WGPU_STRLEN}; WGPUShaderModuleDescriptor shader_desc = {}; shader_desc.nextInChain = &wgsl_desc.chain; @@ -176,6 +184,9 @@ void rms_norm_impl(WebGPUGraph& graph, const std::vector& args) { static_assert( kRmsNormWorkgroupSizeX == 64, "must match @workgroup_size and WG_SIZE in rms_norm.wgsl"); + static_assert( + kRmsNormVec4WorkgroupSizeX == 64, + "must match @workgroup_size and WG_SIZE in rms_norm_vec4.wgsl"); graph.add_dispatch({pipeline, bind_group, num_rows}); // Release intermediate objects (pipeline and bind_group are kept by dispatch) diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl new file mode 100644 index 00000000000..c2f731e5f60 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4.wgsl @@ -0,0 +1,74 @@ +@group(0) @binding(0) var t_out: array>; +@group(0) @binding(1) var t_in: array>; +@group(0) @binding(2) var t_weight: array>; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4 +// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq +// divides by it (not rw4). The host selects this only when row_width % 4 == 0. +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let rw4 = params.row_width / 4u; + let base4 = row_idx * rw4; + + var local_sq_sum: f32 = 0.0; + var x4: u32 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + let v = t_in[base4 + x4]; + local_sq_sum = local_sq_sum + dot(v, v); + x4 = x4 + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x4 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4]; + x4 = x4 + WG_SIZE; + } +} diff --git a/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h new file mode 100644 index 00000000000..633bf3adfc0 --- /dev/null +++ b/backends/webgpu/runtime/ops/rms_norm/rms_norm_vec4_wgsl.h @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace executorch::backends::webgpu { + +// @generated from rms_norm_vec4.wgsl - DO NOT EDIT. +// wgsl-sha256: 4c0ba56708bf125a7ec6ea3c51d1288e05ac00a8e2cfa10e38e9a208e230b8df +inline constexpr const char* kRmsNormVec4WGSL = R"( +@group(0) @binding(0) var t_out: array>; +@group(0) @binding(1) var t_in: array>; +@group(0) @binding(2) var t_weight: array>; + +struct Params { + num_rows: u32, + row_width: u32, + epsilon: f32, + _pad: u32, +} +@group(0) @binding(3) var params: Params; + +const WG_SIZE: u32 = 64u; + +var shared_sum: array; + +fn reduce_shared(worker_id: u32) { + workgroupBarrier(); + var stride: u32 = WG_SIZE / 2u; + loop { + if (stride == 0u) { + break; + } + if (worker_id < stride) { + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; + } + workgroupBarrier(); + stride = stride >> 1u; + } +} + +// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4 +// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq +// divides by it (not rw4). The host selects this only when row_width % 4 == 0. +@compute @workgroup_size(64, 1, 1) +fn main( + @builtin(workgroup_id) wid: vec3, + @builtin(local_invocation_id) lid: vec3) { + let row_idx = wid.x; + let worker_id = lid.x; + + if (row_idx >= params.num_rows) { + return; + } + + let rw4 = params.row_width / 4u; + let base4 = row_idx * rw4; + + var local_sq_sum: f32 = 0.0; + var x4: u32 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + let v = t_in[base4 + x4]; + local_sq_sum = local_sq_sum + dot(v, v); + x4 = x4 + WG_SIZE; + } + + shared_sum[worker_id] = local_sq_sum; + reduce_shared(worker_id); + + let mean_sq = shared_sum[0] / f32(params.row_width); + let rstd = inverseSqrt(mean_sq + params.epsilon); + + x4 = worker_id; + loop { + if (x4 >= rw4) { + break; + } + t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4]; + x4 = x4 + WG_SIZE; + } +} +)"; + +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeX = 64; +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeY = 1; +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeZ = 1; + +} // namespace executorch::backends::webgpu diff --git a/backends/webgpu/test/op_tests/cases.py b/backends/webgpu/test/op_tests/cases.py index 428c94d3066..febdbd507a8 100644 --- a/backends/webgpu/test/op_tests/cases.py +++ b/backends/webgpu/test/op_tests/cases.py @@ -37,7 +37,7 @@ RmsNormModule, ) -# rms_norm coverage is exactly the 14 cases the native test covered. +# rms_norm coverage is exactly the 15 cases the native test covered. RMS_NORM_CASES = _CASES diff --git a/backends/webgpu/test/op_tests/test_schema.py b/backends/webgpu/test/op_tests/test_schema.py index 9e62c9558cf..bcc03a40fd9 100644 --- a/backends/webgpu/test/op_tests/test_schema.py +++ b/backends/webgpu/test/op_tests/test_schema.py @@ -41,7 +41,7 @@ def test_add_rms_norm_registered(): assert {"add", "rms_norm"} <= set(op_test_registry) assert len(op_test_registry["add"].cases) >= 3 # regular/self/scalar/chained - # Exact parity, no hardcoded literal (real _CASES == 14; import so it can't drift): + # Exact parity, no hardcoded literal (real _CASES == 15; import so it can't drift): assert len(op_test_registry["rms_norm"].cases) == len(cases.RMS_NORM_CASES) # weight is a construction param, NOT a forward input: rms0 = op_test_registry["rms_norm"].cases[0] diff --git a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py index d4f88de672a..57679d6d097 100644 --- a/backends/webgpu/test/ops/rms_norm/test_rms_norm.py +++ b/backends/webgpu/test/ops/rms_norm/test_rms_norm.py @@ -140,6 +140,7 @@ def _weight_zeros_neg(hidden: int) -> torch.Tensor: {"name": "distinct_rows", "shape": (1, 1, 5, 256), "input_fn": _distinct_rows}, {"name": "single_row", "shape": (1, 1, 1, 896)}, {"name": "mixed_sign", "shape": (1, 1, 4, 128), "input_fn": _mixed_sign}, + {"name": "llama_hidden_2048", "shape": (1, 1, 1, 2048)}, {"name": "large_4096", "shape": (1, 1, 1, 4096)}, {"name": "large_8192", "shape": (1, 1, 1, 8192)}, {