|
| 1 | +/* |
| 2 | + * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | + * All rights reserved. |
| 4 | + * |
| 5 | + * This source code is licensed under the BSD-style license found in the |
| 6 | + * LICENSE file in the root directory of this source tree. |
| 7 | + */ |
| 8 | + |
| 9 | +#pragma once |
| 10 | + |
| 11 | +#include <cstdint> |
| 12 | + |
| 13 | +namespace executorch::backends::webgpu { |
| 14 | + |
| 15 | +// @generated from rms_norm_vec4.wgsl - DO NOT EDIT. |
| 16 | +// wgsl-sha256: 4c0ba56708bf125a7ec6ea3c51d1288e05ac00a8e2cfa10e38e9a208e230b8df |
| 17 | +inline constexpr const char* kRmsNormVec4WGSL = R"( |
| 18 | +@group(0) @binding(0) var<storage, read_write> t_out: array<vec4<f32>>; |
| 19 | +@group(0) @binding(1) var<storage, read> t_in: array<vec4<f32>>; |
| 20 | +@group(0) @binding(2) var<storage, read> t_weight: array<vec4<f32>>; |
| 21 | +
|
| 22 | +struct Params { |
| 23 | + num_rows: u32, |
| 24 | + row_width: u32, |
| 25 | + epsilon: f32, |
| 26 | + _pad: u32, |
| 27 | +} |
| 28 | +@group(0) @binding(3) var<uniform> params: Params; |
| 29 | +
|
| 30 | +const WG_SIZE: u32 = 64u; |
| 31 | +
|
| 32 | +var<workgroup> shared_sum: array<f32, WG_SIZE>; |
| 33 | +
|
| 34 | +fn reduce_shared(worker_id: u32) { |
| 35 | + workgroupBarrier(); |
| 36 | + var stride: u32 = WG_SIZE / 2u; |
| 37 | + loop { |
| 38 | + if (stride == 0u) { |
| 39 | + break; |
| 40 | + } |
| 41 | + if (worker_id < stride) { |
| 42 | + shared_sum[worker_id] = shared_sum[worker_id] + shared_sum[worker_id + stride]; |
| 43 | + } |
| 44 | + workgroupBarrier(); |
| 45 | + stride = stride >> 1u; |
| 46 | + } |
| 47 | +} |
| 48 | +
|
| 49 | +// vec4 variant of rms_norm: each lane strides by WG_SIZE over rw4 = row_width/4 |
| 50 | +// texels and accumulates dot(v, v). row_width is the ELEMENT count, so mean_sq |
| 51 | +// divides by it (not rw4). The host selects this only when row_width % 4 == 0. |
| 52 | +@compute @workgroup_size(64, 1, 1) |
| 53 | +fn main( |
| 54 | + @builtin(workgroup_id) wid: vec3<u32>, |
| 55 | + @builtin(local_invocation_id) lid: vec3<u32>) { |
| 56 | + let row_idx = wid.x; |
| 57 | + let worker_id = lid.x; |
| 58 | +
|
| 59 | + if (row_idx >= params.num_rows) { |
| 60 | + return; |
| 61 | + } |
| 62 | +
|
| 63 | + let rw4 = params.row_width / 4u; |
| 64 | + let base4 = row_idx * rw4; |
| 65 | +
|
| 66 | + var local_sq_sum: f32 = 0.0; |
| 67 | + var x4: u32 = worker_id; |
| 68 | + loop { |
| 69 | + if (x4 >= rw4) { |
| 70 | + break; |
| 71 | + } |
| 72 | + let v = t_in[base4 + x4]; |
| 73 | + local_sq_sum = local_sq_sum + dot(v, v); |
| 74 | + x4 = x4 + WG_SIZE; |
| 75 | + } |
| 76 | +
|
| 77 | + shared_sum[worker_id] = local_sq_sum; |
| 78 | + reduce_shared(worker_id); |
| 79 | +
|
| 80 | + let mean_sq = shared_sum[0] / f32(params.row_width); |
| 81 | + let rstd = inverseSqrt(mean_sq + params.epsilon); |
| 82 | +
|
| 83 | + x4 = worker_id; |
| 84 | + loop { |
| 85 | + if (x4 >= rw4) { |
| 86 | + break; |
| 87 | + } |
| 88 | + t_out[base4 + x4] = t_in[base4 + x4] * rstd * t_weight[x4]; |
| 89 | + x4 = x4 + WG_SIZE; |
| 90 | + } |
| 91 | +} |
| 92 | +)"; |
| 93 | + |
| 94 | +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeX = 64; |
| 95 | +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeY = 1; |
| 96 | +inline constexpr uint32_t kRmsNormVec4WorkgroupSizeZ = 1; |
| 97 | + |
| 98 | +} // namespace executorch::backends::webgpu |
0 commit comments