Commit 134050b

Add WebGPU sigmoid operator
1 parent e3d5de2 commit 134050b

14 files changed

Lines changed: 360 additions & 19 deletions

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ set(WEBGPU_SRCS
     runtime/ops/select_as_symint/SelectAsSymint.cpp
     runtime/ops/quantized_linear/QuantizedLinear.cpp
     runtime/ops/mul/BinaryOp.cpp
+    runtime/ops/sigmoid/Sigmoid.cpp
     runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
     runtime/ops/rope/RotaryEmbedding.cpp
     runtime/ops/prepack/Prepack.cpp

backends/webgpu/README.md

Lines changed: 8 additions & 12 deletions
@@ -2,7 +2,7 @@

 Run ExecuTorch models on the GPU via [WebGPU](https://www.w3.org/TR/webgpu/). The backend compiles delegated subgraphs into WGSL compute shaders executed natively through [Dawn](https://dawn.googlesource.com/dawn), whose Tint compiler is the reference WGSL implementation (Metal on macOS, Vulkan on Linux/Windows).

-> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.
+> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `mul`, `sigmoid`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.

 ## Progress

@@ -20,14 +20,7 @@ Milestones landed on `main`:
 | 2026-06 | Added the attention core of transformer inference — fused scaled-dot-product attention (`sdpa_with_kv_cache`) with an `update_cache` operator for autoregressive decode | [#20086](https://github.com/pytorch/executorch/pull/20086), [#20087](https://github.com/pytorch/executorch/pull/20087) |
 | 2026-06 | Added on-GPU kernel timing via WebGPU timestamp queries, for true GPU-side profiling | [#20201](https://github.com/pytorch/executorch/pull/20201) |
 | 2026-06 | Added the dominant compute in quantized LLMs — 4-bit weight-only quantized linear (`linear_q4gsw`), a dequantize-and-matmul kernel | [#20226](https://github.com/pytorch/executorch/pull/20226), [#20227](https://github.com/pytorch/executorch/pull/20227) |
-
-In review:
-
-| Milestone | Pull Request |
-|---|---|
-| Adds 4-bit quantized embedding (`embedding_q4gsw`) | [#20263](https://github.com/pytorch/executorch/pull/20263) |
-| Adds rotary position embedding / RoPE (`apply_rotary_emb`) | [#20264](https://github.com/pytorch/executorch/pull/20264) |
-| Adds constant prepacking (`prepack`) for end-to-end model weight handling | [#20265](https://github.com/pytorch/executorch/pull/20265) |
+| 2026-06 | Added token embedding, rotary position embedding, and constant prepacking for end-to-end model weight handling | [#20414](https://github.com/pytorch/executorch/pull/20414) |

 ## Architecture

@@ -61,14 +54,17 @@ Key design choices:
 | Operator | WGSL Shader | Notes |
 |---|---|---|
 | `aten.add.Tensor` | `binary_add.wgsl` | Element-wise with alpha: `out = in1 + alpha * in2` |
+| `aten.mul.Tensor` | `binary_mul.wgsl` | Element-wise multiply with broadcasting |
+| `aten.sigmoid.default` | `sigmoid.wgsl` | Element-wise sigmoid activation |
 | `et_vk.rms_norm.default` | `rms_norm.wgsl` | Root-mean-square normalization |
 | `sdpa_with_kv_cache.default` | `sdpa_compute_attn_weights.wgsl`, `sdpa_softmax.wgsl`, `sdpa_compute_out.wgsl` | Fused scaled-dot-product attention (QK / softmax / AV) with KV cache |
 | `llama.update_cache.default` | `update_cache.wgsl` | In-place KV cache update for autoregressive decode |
 | `et_vk.linear_q4gsw.default` | `q4gsw_linear.wgsl` | 4-bit weight-only quantized linear (dequantize + matmul) |
+| `et_vk.embedding_q4gsw.default` | `embedding_q4gsw.wgsl` | 4-bit groupwise-symmetric quantized embedding |
+| `et_vk.apply_rotary_emb.default` | `rotary_embedding.wgsl` | Interleaved rotary positional embedding |
+| `et_vk.prepack.default` | N/A | Constant materialization into GPU buffers |

-**In review:** quantized embedding (`embedding_q4gsw`), rotary embedding (`apply_rotary_emb`), and constant prepacking (`prepack`).
-
-**Planned:** `mul`, `sigmoid`, shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`), and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.
+**Planned:** shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`) and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.

 ## Quick Start

backends/webgpu/TODO.md

Lines changed: 5 additions & 2 deletions
@@ -1,7 +1,10 @@
 # WebGPU Backend — TODO

 ## Current State (Prototype)
-- Single op: `aten.add.Tensor` (fp32, buffer storage)
+- Runtime support for transformer-oriented fp32 and LLM custom ops, including
+  `aten.add.Tensor`, `aten.mul.Tensor`, `aten.sigmoid.default`,
+  `et_vk.rms_norm.default`,
+  fused SDPA with KV cache, 4-bit quantized linear/embedding, RoPE, and prepack.
 - No Python AOT code — directly consumes Vulkan delegate (.pte exported via VulkanPartitioner)
 - Reuses Vulkan FlatBuffer format (VH00 header + VK00 payload)
 - Registers as `"VulkanBackend"` at runtime — mutually exclusive with Vulkan backend at link time
@@ -30,7 +33,7 @@ element-wise ops (add→relu→mul→clamp) at compile time. Embed via the exist
 `shaders: [VkBytes]` field in schema.fbs.

 ## Next Steps
-1. **More ops**: sub, mul, relu, linear (matmul), softmax, layer_norm
+1. **More ops**: sub, relu, linear (matmul), softmax, layer_norm, shape ops
 2. **fp16 support**: Feature-detect `shader-f16`, fallback to fp32
 3. **Buffer pooling**: Reuse GPU buffers to avoid OOM at scale
 4. **Pipeline caching**: Cache compiled pipelines across runs
Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
+#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
+#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
+#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
+#include <executorch/backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h>
+
+#include <webgpu/webgpu.h>
+
+#include <stdexcept>
+#include <vector>
+
+namespace executorch::backends::webgpu {
+
+namespace {
+
+void sigmoid_impl(WebGPUGraph& graph, const std::vector<int>& args) {
+  // aten.sigmoid.default args: [in, out]
+  const int in_id = args.at(0);
+  const int out_id = args.at(1);
+
+  WGPUDevice device = graph.device();
+
+  const auto& in_tensor = graph.get_tensor(in_id);
+  const auto& out_tensor = graph.get_tensor(out_id);
+
+  if (in_tensor.dims != out_tensor.dims) {
+    throw std::runtime_error("sigmoid: input and output shapes must match");
+  }
+
+  TensorMeta out_meta;
+  fill_tensor_meta(out_tensor, &out_meta);
+
+  if (out_tensor.nbytes !=
+          static_cast<size_t>(out_meta.numel) * sizeof(float) ||
+      in_tensor.nbytes != static_cast<size_t>(out_meta.numel) * sizeof(float)) {
+    throw std::runtime_error("sigmoid: non-fp32 operand (nbytes != numel * 4)");
+  }
+
+  uint32_t wg_size =
+      utils::clamp_workgroup_size(device, kSigmoidWorkgroupSizeX);
+  uint32_t workgroup_count = utils::compute_1d_workgroup_count(
+      device, out_meta.numel, wg_size, "sigmoid");
+
+  WGPUConstantEntry wg_size_constant = {};
+  wg_size_constant.key = {"wg_size", WGPU_STRLEN};
+  wg_size_constant.value = static_cast<double>(wg_size);
+
+  WGPUBuffer out_meta_buf =
+      utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
+  graph.add_uniform_buffer_bytes(sizeof(TensorMeta));
+
+  WGPUShaderSourceWGSL wgsl_desc = {};
+  wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
+  wgsl_desc.code = {kSigmoidWGSL, WGPU_STRLEN};
+
+  WGPUShaderModuleDescriptor shader_desc = {};
+  shader_desc.nextInChain = &wgsl_desc.chain;
+  WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);
+
+  WGPUBindGroupLayoutEntry entries[3] = {};
+
+  entries[0].binding = 0;
+  entries[0].visibility = WGPUShaderStage_Compute;
+  entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;
+
+  entries[1].binding = 1;
+  entries[1].visibility = WGPUShaderStage_Compute;
+  entries[1].buffer.type = WGPUBufferBindingType_Storage;
+
+  entries[2].binding = 2;
+  entries[2].visibility = WGPUShaderStage_Compute;
+  entries[2].buffer.type = WGPUBufferBindingType_Uniform;
+
+  WGPUBindGroupLayoutDescriptor bgl_desc = {};
+  bgl_desc.entryCount = 3;
+  bgl_desc.entries = entries;
+  WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);
+
+  WGPUPipelineLayoutDescriptor pl_desc = {};
+  pl_desc.bindGroupLayoutCount = 1;
+  pl_desc.bindGroupLayouts = &bgl;
+  WGPUPipelineLayout pipeline_layout =
+      wgpuDeviceCreatePipelineLayout(device, &pl_desc);
+
+  WGPUComputePipelineDescriptor pipeline_desc = {};
+  pipeline_desc.layout = pipeline_layout;
+  pipeline_desc.compute.module = shader;
+  pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
+  pipeline_desc.compute.constantCount = 1;
+  pipeline_desc.compute.constants = &wg_size_constant;
+  WGPUComputePipeline pipeline =
+      wgpuDeviceCreateComputePipeline(device, &pipeline_desc);
+
+  WGPUBindGroupEntry bg_entries[3] = {};
+
+  bg_entries[0].binding = 0;
+  bg_entries[0].buffer = in_tensor.buffer;
+  bg_entries[0].size = in_tensor.nbytes;
+
+  bg_entries[1].binding = 1;
+  bg_entries[1].buffer = out_tensor.buffer;
+  bg_entries[1].size = out_tensor.nbytes;
+
+  bg_entries[2].binding = 2;
+  bg_entries[2].buffer = out_meta_buf;
+  bg_entries[2].size = sizeof(TensorMeta);
+
+  WGPUBindGroupDescriptor bg_desc = {};
+  bg_desc.layout = bgl;
+  bg_desc.entryCount = 3;
+  bg_desc.entries = bg_entries;
+  WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);
+
+  graph.add_dispatch({pipeline, bind_group, workgroup_count});
+
+  wgpuShaderModuleRelease(shader);
+  wgpuBindGroupLayoutRelease(bgl);
+  wgpuPipelineLayoutRelease(pipeline_layout);
+  // Drop our ref; the bind group keeps the uniform alive until release.
+  wgpuBufferRelease(out_meta_buf);
+}
+
+} // namespace
+
+WEBGPU_REGISTER_OPERATORS {
+  WEBGPU_REGISTER_OP(aten.sigmoid.default, sigmoid_impl);
+}
+
+} // namespace executorch::backends::webgpu
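
The host-side checks in `sigmoid_impl` (matching shapes, fp32-sized buffers, then a 1D workgroup count) can be sketched in a few lines of Python. This is a minimal sketch, not the backend's code: `plan_sigmoid_dispatch` is a hypothetical name, and the ceil-division is an assumption about what `utils::compute_1d_workgroup_count` computes. The real helper also takes the device and may enforce device limits, which this sketch does not model.

```python
import math

WG_SIZE_DEFAULT = 64  # mirrors kSigmoidWorkgroupSizeX in the generated header
FP32_BYTES = 4


def plan_sigmoid_dispatch(in_dims, out_dims, in_nbytes, out_nbytes,
                          wg_size=WG_SIZE_DEFAULT):
    """Return the workgroup count for a 1D element-wise dispatch.

    Hypothetical helper: the ceil-division below is an assumption about
    utils::compute_1d_workgroup_count, not a copy of it.
    """
    if tuple(in_dims) != tuple(out_dims):
        raise ValueError("sigmoid: input and output shapes must match")
    numel = math.prod(out_dims)
    expected = numel * FP32_BYTES
    if in_nbytes != expected or out_nbytes != expected:
        raise ValueError("sigmoid: non-fp32 operand (nbytes != numel * 4)")
    # One invocation per element, rounded up to whole workgroups.
    return (numel + wg_size - 1) // wg_size
```

Under these assumptions a 130-element tensor with the default size of 64 needs three workgroups, so the last workgroup has 62 idle invocations that the shader's bounds guard has to skip.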
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+@group(0) @binding(0) var<storage, read> input: array<f32>;
+@group(0) @binding(1) var<storage, read_write> output: array<f32>;
+
+struct TensorMeta {
+  ndim: u32,
+  numel: u32,
+  sizes: vec4<u32>,
+  strides: vec4<u32>,
+}
+@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
+
+override wg_size: u32 = 64u;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let idx = gid.x;
+  if (idx >= out_meta.numel) {
+    return;
+  }
+  output[idx] = 1.0 / (1.0 + exp(-input[idx]));
+}
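
To make the shader's control flow concrete, the following Python sketch emulates the dispatch on the CPU: it iterates over every invocation in the grid and applies the same bounds guard and formula. It is an illustration only. `sigmoid_kernel` is a hypothetical name, Python floats are 64-bit while the shader works in `f32`, and the overflow handling is an assumption that stands in for `exp` saturating to infinity.

```python
import math


def sigmoid_kernel(inp, numel, wg_size=64):
    """Emulate the 1D dispatch: each invocation handles at most one element."""
    out = [0.0] * numel
    workgroup_count = (numel + wg_size - 1) // wg_size
    for idx in range(workgroup_count * wg_size):  # gid.x across the whole grid
        if idx >= numel:  # same guard as the shader's early return
            continue
        try:
            e = math.exp(-inp[idx])
        except OverflowError:
            # Assumption: model exp() saturating to +inf for very negative inputs.
            e = math.inf
        out[idx] = 1.0 / (1.0 + e)
    return out
```

For example, `sigmoid_kernel([0.0, 1000.0, -1000.0], 3)` yields `0.5`, `1.0`, and `0.0`, which shows why the plain `1 / (1 + exp(-x))` form still produces finite outputs at both extremes.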
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+
+namespace executorch::backends::webgpu {
+
+// @generated from sigmoid.wgsl - DO NOT EDIT.
+// wgsl-sha256: 73a26ddce78d1cbd6cbb0c586791b338153cea9af13790dc1400516128a4c278
+inline constexpr const char* kSigmoidWGSL = R"(
+@group(0) @binding(0) var<storage, read> input: array<f32>;
+@group(0) @binding(1) var<storage, read_write> output: array<f32>;
+
+struct TensorMeta {
+  ndim: u32,
+  numel: u32,
+  sizes: vec4<u32>,
+  strides: vec4<u32>,
+}
+@group(0) @binding(2) var<uniform> out_meta: TensorMeta;
+
+override wg_size: u32 = 64u;
+
+@compute @workgroup_size(wg_size, 1, 1)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+  let idx = gid.x;
+  if (idx >= out_meta.numel) {
+    return;
+  }
+  output[idx] = 1.0 / (1.0 + exp(-input[idx]));
+}
+)";
+
+inline constexpr uint32_t kSigmoidWorkgroupSizeX = 64;
+inline constexpr uint32_t kSigmoidWorkgroupSizeY = 1;
+inline constexpr uint32_t kSigmoidWorkgroupSizeZ = 1;
+
+} // namespace executorch::backends::webgpu

backends/webgpu/test/TARGETS

Lines changed: 13 additions & 0 deletions
@@ -17,6 +17,19 @@ python_unittest(
     ],
 )

+python_unittest(
+    name = "test_sigmoid",
+    srcs = [
+        "ops/sigmoid/test_sigmoid.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/vulkan/partitioner:vulkan_partitioner",
+        "//executorch/backends/vulkan:vulkan_preprocess",
+        "//executorch/exir:lib",
+    ],
+)
+
 runtime.python_library(
     name = "tester",
     srcs = ["tester.py"],

backends/webgpu/test/op_tests/cases.py

Lines changed: 47 additions & 0 deletions
@@ -36,6 +36,12 @@
     _ramp,
     RmsNormModule,
 )
+from executorch.backends.webgpu.test.ops.sigmoid.test_sigmoid import (
+    _sigmoid_range,
+    _sigmoid_wide_range,
+    SigmoidChainedModule,
+    SigmoidModule,
+)

 # rms_norm coverage is exactly the 14 cases the native test covered.
 RMS_NORM_CASES = _CASES
@@ -49,6 +55,13 @@ def _add_factory(variant: str = "regular") -> torch.nn.Module:
     }[variant]()


+def _sigmoid_factory(variant: str = "regular") -> torch.nn.Module:
+    return {
+        "regular": SigmoidModule,
+        "chained": SigmoidChainedModule,
+    }[variant]()
+
+
 @register_op_test("add")
 def _add_suite() -> WebGPUTestSuite:
     # Same-shape numeric coverage only: broadcast adds stay export-smoke in
@@ -83,6 +96,40 @@ def _add_suite() -> WebGPUTestSuite:
     )


+@register_op_test("sigmoid")
+def _sigmoid_suite() -> WebGPUTestSuite:
+    return WebGPUTestSuite(
+        module_factory=_sigmoid_factory,
+        cases=[
+            Case(
+                name="regular_1d",
+                construct={"variant": "regular"},
+                inputs=(InputSpec(shape=(M,), gen=_sigmoid_range),),
+            ),
+            Case(
+                name="regular_2d",
+                construct={"variant": "regular"},
+                inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_range),),
+            ),
+            Case(
+                name="regular_4d",
+                construct={"variant": "regular"},
+                inputs=(InputSpec(shape=(XS, S, S1, S2), gen=_sigmoid_range),),
+            ),
+            Case(
+                name="wide_range",
+                construct={"variant": "regular"},
+                inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_wide_range),),
+            ),
+            Case(
+                name="chained",
+                construct={"variant": "chained"},
+                inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_range),),
+            ),
+        ],
+    )
+
+
 def _rms_norm_factory(hidden: int, eps: float, weight_fn) -> torch.nn.Module:
     model = RmsNormModule(hidden, eps=eps)
     with torch.no_grad():

backends/webgpu/test/op_tests/generate_op_tests.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 Per case: export the module to `<id>.pte`, write its inputs + torch golden as raw
 little-endian fp32, and emit `manifest.json` for the C++ gtest driver to consume.
-Run: `python -m ...generate_op_tests --output <dir> [--ops add,rms_norm]`.
+Run: `python -m ...generate_op_tests --output <dir> [--ops add,sigmoid,rms_norm]`.
 """

 from __future__ import annotations
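
The docstring describes writing inputs and goldens as raw little-endian fp32 next to a `manifest.json`. A minimal standard-library sketch of that layout is shown below. The file names, the manifest keys, and the `emit_case` helper are all hypothetical, since the generator's actual schema is not shown in this diff, and the `.pte` export step is omitted.

```python
import json
import struct
import tempfile
from pathlib import Path


def write_fp32_le(path, values):
    """Write floats as raw little-endian fp32 with no header."""
    path.write_bytes(struct.pack(f"<{len(values)}f", *values))


def read_fp32_le(path):
    """Read a headerless little-endian fp32 file back into a list."""
    data = path.read_bytes()
    return list(struct.unpack(f"<{len(data) // 4}f", data))


def emit_case(out_dir, case_id, inputs, golden):
    # Hypothetical file names and manifest keys, for illustration only.
    write_fp32_le(out_dir / f"{case_id}.input0.bin", inputs)
    write_fp32_le(out_dir / f"{case_id}.golden.bin", golden)
    return {
        "id": case_id,
        "pte": f"{case_id}.pte",
        "inputs": [f"{case_id}.input0.bin"],
        "golden": f"{case_id}.golden.bin",
    }


with tempfile.TemporaryDirectory() as tmp:
    out_dir = Path(tmp)
    # sigmoid(0.0) == 0.5, so this golden is exact in fp32.
    entry = emit_case(out_dir, "sigmoid_regular_1d", [0.0, 0.0], [0.5, 0.5])
    (out_dir / "manifest.json").write_text(json.dumps({"cases": [entry]}, indent=2))
    round_trip = read_fp32_le(out_dir / "sigmoid_regular_1d.golden.bin")
    manifest = json.loads((out_dir / "manifest.json").read_text())
```

The explicit `<` in the `struct` format string fixes the byte order regardless of the host, which matters when the consumer is a separate C++ binary that reads the bytes directly.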

backends/webgpu/test/op_tests/test_generator.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def test_generate_manifest(tmp_path):
 def test_every_case_delegates():
     # Contract: every registered case must lower to a VulkanBackend delegate. An op that
     # silently CPU-falls-back would otherwise produce a misleading golden-equals-golden pass.
-    for op in ("add", "rms_norm"):
+    for op in ("add", "sigmoid", "rms_norm"):
         suite = op_test_registry[op]
         for case in suite.cases:
             _module, _inputs, prog = g.export_case(suite, case)
