1 change: 1 addition & 0 deletions backends/webgpu/CMakeLists.txt
@@ -39,6 +39,7 @@ set(WEBGPU_SRCS
runtime/ops/select_as_symint/SelectAsSymint.cpp
runtime/ops/quantized_linear/QuantizedLinear.cpp
runtime/ops/mul/BinaryOp.cpp
runtime/ops/sigmoid/Sigmoid.cpp
runtime/ops/embedding_q4gsw/EmbeddingQ4gsw.cpp
runtime/ops/rope/RotaryEmbedding.cpp
runtime/ops/prepack/Prepack.cpp
20 changes: 8 additions & 12 deletions backends/webgpu/README.md
@@ -2,7 +2,7 @@

Run ExecuTorch models on the GPU via [WebGPU](https://www.w3.org/TR/webgpu/). The backend compiles delegated subgraphs into WGSL compute shaders and executes them natively through [Dawn](https://dawn.googlesource.com/dawn) (Metal on macOS, Vulkan on Linux/Windows), whose Tint compiler is the reference WGSL implementation.

> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.
> **Status: Prototype, under active development.** The backend runs the core of transformer inference today — `add`, `mul`, `sigmoid`, `rms_norm`, fused scaled-dot-product attention with KV cache, and 4-bit weight-only quantized linear — plus quantized embedding, rotary embedding, and constant prepacking. See [Progress](#progress) for shipped milestones.

## Progress

@@ -20,14 +20,7 @@ Milestones landed on `main`:
| 2026-06 | Added the attention core of transformer inference — fused scaled-dot-product attention (`sdpa_with_kv_cache`) with an `update_cache` operator for autoregressive decode | [#20086](https://github.com/pytorch/executorch/pull/20086), [#20087](https://github.com/pytorch/executorch/pull/20087) |
| 2026-06 | Added on-GPU kernel timing via WebGPU timestamp queries, for true GPU-side profiling | [#20201](https://github.com/pytorch/executorch/pull/20201) |
| 2026-06 | Added the dominant compute in quantized LLMs — 4-bit weight-only quantized linear (`linear_q4gsw`), a dequantize-and-matmul kernel | [#20226](https://github.com/pytorch/executorch/pull/20226), [#20227](https://github.com/pytorch/executorch/pull/20227) |

In review:

| Milestone | Pull Request |
|---|---|
| Adds 4-bit quantized embedding (`embedding_q4gsw`) | [#20263](https://github.com/pytorch/executorch/pull/20263) |
| Adds rotary position embedding / RoPE (`apply_rotary_emb`) | [#20264](https://github.com/pytorch/executorch/pull/20264) |
| Adds constant prepacking (`prepack`) for end-to-end model weight handling | [#20265](https://github.com/pytorch/executorch/pull/20265) |
| 2026-06 | Added token embedding, rotary position embedding, and constant prepacking for end-to-end model weight handling | [#20414](https://github.com/pytorch/executorch/pull/20414) |

## Architecture

@@ -61,14 +54,17 @@ Key design choices:
| Operator | WGSL Shader | Notes |
|---|---|---|
| `aten.add.Tensor` | `binary_add.wgsl` | Element-wise with alpha: `out = in1 + alpha * in2` |
| `aten.mul.Tensor` | `binary_mul.wgsl` | Element-wise multiply with broadcasting |
| `aten.sigmoid.default` | `sigmoid.wgsl` | Element-wise sigmoid activation |
| `et_vk.rms_norm.default` | `rms_norm.wgsl` | Root-mean-square normalization |
| `sdpa_with_kv_cache.default` | `sdpa_compute_attn_weights.wgsl`, `sdpa_softmax.wgsl`, `sdpa_compute_out.wgsl` | Fused scaled-dot-product attention (QK / softmax / AV) with KV cache |
| `llama.update_cache.default` | `update_cache.wgsl` | In-place KV cache update for autoregressive decode |
| `et_vk.linear_q4gsw.default` | `q4gsw_linear.wgsl` | 4-bit weight-only quantized linear (dequantize + matmul) |
| `et_vk.embedding_q4gsw.default` | `embedding_q4gsw.wgsl` | 4-bit groupwise-symmetric quantized embedding |
| `et_vk.apply_rotary_emb.default` | `rotary_embedding.wgsl` | Interleaved rotary positional embedding |
| `et_vk.prepack.default` | N/A | Constant materialization into GPU buffers |

**In review:** quantized embedding (`embedding_q4gsw`), rotary embedding (`apply_rotary_emb`), and constant prepacking (`prepack`).

**Planned:** `mul`, `sigmoid`, shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`), and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.
**Planned:** shape ops (`view`, `permute`, `slice`, `select`, `cat`, `squeeze`/`unsqueeze`) and `index` — the remaining ops needed for end-to-end Llama 3.2 1B.
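
As a rough reading aid for the element-wise rows in the operator table above, the following plain-Python sketch mirrors the stated semantics on flat lists of floats. The function names `ref_add`, `ref_mul`, and `ref_sigmoid` are hypothetical and are not part of the backend. Broadcasting for `mul` is deliberately left out, and Python's `math.exp` raises `OverflowError` for very negative inputs where a GPU `exp` would typically return infinity, so this is only a sketch for moderate input ranges.

```python
import math


def ref_add(in1, in2, alpha=1.0):
    # Mirrors the table entry for aten.add.Tensor: out = in1 + alpha * in2.
    return [a + alpha * b for a, b in zip(in1, in2)]


def ref_mul(in1, in2):
    # Same-shape element-wise multiply; broadcasting is omitted (assumption).
    return [a * b for a, b in zip(in1, in2)]


def ref_sigmoid(xs):
    # Element-wise logistic function, the same formula sigmoid.wgsl uses.
    return [1.0 / (1.0 + math.exp(-x)) for x in xs]


if __name__ == "__main__":
    print(ref_add([1.0, 2.0], [3.0, 4.0], alpha=0.5))
    print(ref_mul([2.0, 3.0], [4.0, 5.0]))
    print(ref_sigmoid([0.0]))
```

A reference like this can be handy when eyeballing small outputs, but the op tests in this backend compare against torch goldens rather than hand-written references.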

## Quick Start

7 changes: 5 additions & 2 deletions backends/webgpu/TODO.md
@@ -1,7 +1,10 @@
# WebGPU Backend — TODO

## Current State (Prototype)
- Single op: `aten.add.Tensor` (fp32, buffer storage)
- Runtime support for transformer-oriented fp32 and LLM custom ops, including
`aten.add.Tensor`, `aten.mul.Tensor`, `aten.sigmoid.default`,
`et_vk.rms_norm.default`,
fused SDPA with KV cache, 4-bit quantized linear/embedding, RoPE, and prepack.
- No Python AOT code — directly consumes Vulkan delegate (.pte exported via VulkanPartitioner)
- Reuses Vulkan FlatBuffer format (VH00 header + VK00 payload)
- Registers as `"VulkanBackend"` at runtime — mutually exclusive with Vulkan backend at link time
@@ -30,7 +33,7 @@ element-wise ops (add→relu→mul→clamp) at compile time. Embed via the exist
`shaders: [VkBytes]` field in schema.fbs.

## Next Steps
1. **More ops**: sub, mul, relu, linear (matmul), softmax, layer_norm
1. **More ops**: sub, relu, linear (matmul), softmax, layer_norm, shape ops
2. **fp16 support**: Feature-detect `shader-f16`, fallback to fp32
3. **Buffer pooling**: Reuse GPU buffers to avoid OOM at scale
4. **Pipeline caching**: Cache compiled pipelines across runs
137 changes: 137 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/Sigmoid.cpp
@@ -0,0 +1,137 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
#include <executorch/backends/webgpu/runtime/WebGPUUtils.h>
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
#include <executorch/backends/webgpu/runtime/ops/TensorMeta.h>
#include <executorch/backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h>

#include <webgpu/webgpu.h>

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace executorch::backends::webgpu {

namespace {

void sigmoid_impl(WebGPUGraph& graph, const std::vector<int>& args) {
// aten.sigmoid.default args: [in, out]
const int in_id = args.at(0);
const int out_id = args.at(1);

WGPUDevice device = graph.device();

const auto& in_tensor = graph.get_tensor(in_id);
const auto& out_tensor = graph.get_tensor(out_id);

if (in_tensor.dims != out_tensor.dims) {
throw std::runtime_error("sigmoid: input and output shapes must match");
}

TensorMeta out_meta;
fill_tensor_meta(out_tensor, &out_meta);

if (out_tensor.nbytes !=
static_cast<size_t>(out_meta.numel) * sizeof(float) ||
in_tensor.nbytes != static_cast<size_t>(out_meta.numel) * sizeof(float)) {
throw std::runtime_error("sigmoid: non-fp32 operand (nbytes != numel * 4)");
}

uint32_t wg_size =
utils::clamp_workgroup_size(device, kSigmoidWorkgroupSizeX);
uint32_t workgroup_count = utils::compute_1d_workgroup_count(
device, out_meta.numel, wg_size, "sigmoid");

WGPUConstantEntry wg_size_constant = {};
wg_size_constant.key = {"wg_size", WGPU_STRLEN};
wg_size_constant.value = static_cast<double>(wg_size);

WGPUBuffer out_meta_buf =
utils::make_uniform(device, &out_meta, sizeof(TensorMeta));
graph.add_uniform_buffer_bytes(sizeof(TensorMeta));

WGPUShaderSourceWGSL wgsl_desc = {};
wgsl_desc.chain.sType = WGPUSType_ShaderSourceWGSL;
wgsl_desc.code = {kSigmoidWGSL, WGPU_STRLEN};

WGPUShaderModuleDescriptor shader_desc = {};
shader_desc.nextInChain = &wgsl_desc.chain;
WGPUShaderModule shader = wgpuDeviceCreateShaderModule(device, &shader_desc);

WGPUBindGroupLayoutEntry entries[3] = {};

entries[0].binding = 0;
entries[0].visibility = WGPUShaderStage_Compute;
entries[0].buffer.type = WGPUBufferBindingType_ReadOnlyStorage;

entries[1].binding = 1;
entries[1].visibility = WGPUShaderStage_Compute;
entries[1].buffer.type = WGPUBufferBindingType_Storage;

entries[2].binding = 2;
entries[2].visibility = WGPUShaderStage_Compute;
entries[2].buffer.type = WGPUBufferBindingType_Uniform;

WGPUBindGroupLayoutDescriptor bgl_desc = {};
bgl_desc.entryCount = 3;
bgl_desc.entries = entries;
WGPUBindGroupLayout bgl = wgpuDeviceCreateBindGroupLayout(device, &bgl_desc);

WGPUPipelineLayoutDescriptor pl_desc = {};
pl_desc.bindGroupLayoutCount = 1;
pl_desc.bindGroupLayouts = &bgl;
WGPUPipelineLayout pipeline_layout =
wgpuDeviceCreatePipelineLayout(device, &pl_desc);

WGPUComputePipelineDescriptor pipeline_desc = {};
pipeline_desc.layout = pipeline_layout;
pipeline_desc.compute.module = shader;
pipeline_desc.compute.entryPoint = {"main", WGPU_STRLEN};
pipeline_desc.compute.constantCount = 1;
pipeline_desc.compute.constants = &wg_size_constant;
WGPUComputePipeline pipeline =
wgpuDeviceCreateComputePipeline(device, &pipeline_desc);

WGPUBindGroupEntry bg_entries[3] = {};

bg_entries[0].binding = 0;
bg_entries[0].buffer = in_tensor.buffer;
bg_entries[0].size = in_tensor.nbytes;

bg_entries[1].binding = 1;
bg_entries[1].buffer = out_tensor.buffer;
bg_entries[1].size = out_tensor.nbytes;

bg_entries[2].binding = 2;
bg_entries[2].buffer = out_meta_buf;
bg_entries[2].size = sizeof(TensorMeta);

WGPUBindGroupDescriptor bg_desc = {};
bg_desc.layout = bgl;
bg_desc.entryCount = 3;
bg_desc.entries = bg_entries;
WGPUBindGroup bind_group = wgpuDeviceCreateBindGroup(device, &bg_desc);

graph.add_dispatch({pipeline, bind_group, workgroup_count});

wgpuShaderModuleRelease(shader);
wgpuBindGroupLayoutRelease(bgl);
wgpuPipelineLayoutRelease(pipeline_layout);
// Drop our ref; the bind group keeps the uniform alive until release.
wgpuBufferRelease(out_meta_buf);
}

} // namespace

WEBGPU_REGISTER_OPERATORS {
WEBGPU_REGISTER_OP(aten.sigmoid.default, sigmoid_impl);
}

} // namespace executorch::backends::webgpu
21 changes: 21 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid.wgsl
@@ -0,0 +1,21 @@
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= out_meta.numel) {
return;
}
output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
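
The dispatch shape of the shader above can be sketched on the CPU as follows. This is a minimal emulation under stated assumptions, not the backend's code: `workgroup_count_1d` and `emulate_sigmoid_dispatch` are hypothetical names, and the ceiling division is an assumption about what `utils::compute_1d_workgroup_count` computes (the real helper is not shown in this diff and may also validate device limits). The inner guard corresponds to the `idx >= out_meta.numel` early return, which keeps the invocations in the last, partially filled workgroup from writing out of bounds.

```python
import math


def workgroup_count_1d(numel, wg_size):
    # Assumed behaviour: enough workgroups to cover numel, rounding up.
    return (numel + wg_size - 1) // wg_size


def emulate_sigmoid_dispatch(inputs, wg_size=64):
    numel = len(inputs)
    output = [0.0] * numel
    for group in range(workgroup_count_1d(numel, wg_size)):
        for local in range(wg_size):
            idx = group * wg_size + local  # plays the role of gid.x
            if idx >= numel:
                # Same bounds guard as the shader: surplus invocations do nothing.
                continue
            output[idx] = 1.0 / (1.0 + math.exp(-inputs[idx]))
    return output


if __name__ == "__main__":
    print(workgroup_count_1d(100, 64))
    print(emulate_sigmoid_dispatch([0.0, 0.0, 0.0], wg_size=2))
```

With 100 elements and a workgroup size of 64, two workgroups are launched and the last 28 invocations of the second one hit the guard.
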
45 changes: 45 additions & 0 deletions backends/webgpu/runtime/ops/sigmoid/sigmoid_wgsl.h
@@ -0,0 +1,45 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstdint>

namespace executorch::backends::webgpu {

// @generated from sigmoid.wgsl - DO NOT EDIT.
// wgsl-sha256: 73a26ddce78d1cbd6cbb0c586791b338153cea9af13790dc1400516128a4c278
inline constexpr const char* kSigmoidWGSL = R"(
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

struct TensorMeta {
ndim: u32,
numel: u32,
sizes: vec4<u32>,
strides: vec4<u32>,
}
@group(0) @binding(2) var<uniform> out_meta: TensorMeta;

override wg_size: u32 = 64u;

@compute @workgroup_size(wg_size, 1, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
let idx = gid.x;
if (idx >= out_meta.numel) {
return;
}
output[idx] = 1.0 / (1.0 + exp(-input[idx]));
}
)";

inline constexpr uint32_t kSigmoidWorkgroupSizeX = 64;
inline constexpr uint32_t kSigmoidWorkgroupSizeY = 1;
inline constexpr uint32_t kSigmoidWorkgroupSizeZ = 1;

} // namespace executorch::backends::webgpu
13 changes: 13 additions & 0 deletions backends/webgpu/test/TARGETS
@@ -17,6 +17,19 @@ python_unittest(
],
)

python_unittest(
name = "test_sigmoid",
srcs = [
"ops/sigmoid/test_sigmoid.py",
],
deps = [
"//caffe2:torch",
"//executorch/backends/vulkan/partitioner:vulkan_partitioner",
"//executorch/backends/vulkan:vulkan_preprocess",
"//executorch/exir:lib",
],
)

runtime.python_library(
name = "tester",
srcs = ["tester.py"],
48 changes: 48 additions & 0 deletions backends/webgpu/test/op_tests/cases.py
@@ -16,6 +16,7 @@
from executorch.backends.webgpu.test.op_tests.test_suite import (
Case,
InputSpec,
M,
M1,
M2,
register_op_test,
@@ -36,6 +37,12 @@
_ramp,
RmsNormModule,
)
from executorch.backends.webgpu.test.ops.sigmoid.test_sigmoid import (
_sigmoid_range,
_sigmoid_wide_range,
SigmoidChainedModule,
SigmoidModule,
)

# rms_norm coverage is exactly the 14 cases the native test covered.
RMS_NORM_CASES = _CASES
@@ -49,6 +56,13 @@ def _add_factory(variant: str = "regular") -> torch.nn.Module:
}[variant]()


def _sigmoid_factory(variant: str = "regular") -> torch.nn.Module:
return {
"regular": SigmoidModule,
"chained": SigmoidChainedModule,
}[variant]()


@register_op_test("add")
def _add_suite() -> WebGPUTestSuite:
# Same-shape numeric coverage only: broadcast adds stay export-smoke in
@@ -83,6 +97,40 @@ def _add_suite() -> WebGPUTestSuite:
)


@register_op_test("sigmoid")
def _sigmoid_suite() -> WebGPUTestSuite:
return WebGPUTestSuite(
module_factory=_sigmoid_factory,
cases=[
Case(
name="regular_1d",
construct={"variant": "regular"},
inputs=(InputSpec(shape=(M,), gen=_sigmoid_range),),
),
Case(
name="regular_2d",
construct={"variant": "regular"},
inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_range),),
),
Case(
name="regular_4d",
construct={"variant": "regular"},
inputs=(InputSpec(shape=(XS, S, S1, S2), gen=_sigmoid_range),),
),
Case(
name="wide_range",
construct={"variant": "regular"},
inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_wide_range),),
),
Case(
name="chained",
construct={"variant": "chained"},
inputs=(InputSpec(shape=(M1, M2), gen=_sigmoid_range),),
),
],
)


def _rms_norm_factory(hidden: int, eps: float, weight_fn) -> torch.nn.Module:
model = RmsNormModule(hidden, eps=eps)
with torch.no_grad():
2 changes: 1 addition & 1 deletion backends/webgpu/test/op_tests/generate_op_tests.py
@@ -8,7 +8,7 @@

Per case: export the module to `<id>.pte`, write its inputs + torch golden as raw
little-endian fp32, and emit `manifest.json` for the C++ gtest driver to consume.
Run: `python -m ...generate_op_tests --output <dir> [--ops add,rms_norm]`.
Run: `python -m ...generate_op_tests --output <dir> [--ops add,sigmoid,rms_norm]`.
"""

from __future__ import annotations
2 changes: 1 addition & 1 deletion backends/webgpu/test/op_tests/test_generator.py
@@ -77,7 +77,7 @@ def test_generate_manifest(tmp_path):
def test_every_case_delegates():
# Contract: every registered case must lower to a VulkanBackend delegate. An op that
# silently CPU-falls-back would otherwise produce a misleading golden-equals-golden pass.
for op in ("add", "rms_norm"):
for op in ("add", "sigmoid", "rms_norm"):
suite = op_test_registry[op]
for case in suite.cases:
_module, _inputs, prog = g.export_case(suite, case)