Skip to content

Commit 5526971

Browse files
[ExecuTorch][WebGPU] SymInt live-scalar mechanism + et_vk.select_as_symint
Pull Request resolved: pytorch#20085 Adds the dynamic-scalar (SymInt) mechanism to the WebGPU graph as a standalone enabler, ahead of the SDPA op that consumes it. Mirrors the Vulkan delegate's SymInt = live uniform-buffer design: a `ValueType::SymInt` backed by a 16-byte `Uniform|CopyDst` buffer, `set_symint`/`read_symint`/`symint_buffer` accessors with dirty-tracking, a `SymIntSource` + `add_symint_source`/`update_symints_from_inputs` host-read path, and an `add_resize_hook`/`propagate_resize`/`dispatch_at` recompute plumbing. `WebGPUBackend::execute` calls `propagate_resize` after refreshing the SymInts from the runtime inputs. The `et_vk.select_as_symint` op handler records `out SymInt = x[index]` along a dim at build time. This diff has no in-graph consumer yet — the SDPA op (stacked above) reads the SymInt value via `read_symint()` for dynamic `input_pos`. Building it as its own diff keeps the enabler separate from the op, matching the update_cache → mechanism → SDPA layering. Authored with assistance from Claude. ghstack-source-id: 391979584 @exported-using-ghexport Differential Revision: [D107584280](https://our.internmc.facebook.com/intern/diff/D107584280/)
1 parent 14168a3 commit 5526971

5 files changed

Lines changed: 241 additions & 1 deletion

File tree

backends/webgpu/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ set(WEBGPU_SRCS
3434
runtime/ops/add/BinaryOp.cpp
3535
runtime/ops/rms_norm/RmsNorm.cpp
3636
runtime/ops/update_cache/UpdateCache.cpp
37+
runtime/ops/select_as_symint/SelectAsSymint.cpp
3738
)
3839

3940
add_library(webgpu_backend ${WEBGPU_SRCS})

backends/webgpu/runtime/WebGPUBackend.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ Error WebGPUBackend::execute(
106106
}
107107
graph->copy_inputs(inputs);
108108

109+
// Fail loud as a runtime Error so a throw never crosses the backend boundary.
110+
try {
111+
graph->update_symints_from_inputs(inputs);
112+
graph->propagate_resize();
113+
} catch (const std::exception& e) {
114+
ET_LOG(Error, "WebGPU symint refresh/resize failed: %s", e.what());
115+
return Error::Internal;
116+
}
117+
109118
// Execute the compute graph
110119
graph->execute();
111120

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,86 @@ WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
5959
return buffer;
6060
}
6161

62+
void WebGPUGraph::update_symints_from_inputs(
63+
const std::vector<std::pair<const void*, size_t>>& inputs) {
64+
for (const auto& src : symint_sources_) {
65+
int pos = -1;
66+
for (size_t i = 0; i < input_ids_.size(); i++) {
67+
if (input_ids_[i] == src.input_tensor_id) {
68+
pos = static_cast<int>(i);
69+
break;
70+
}
71+
}
72+
if (pos < 0 || pos >= static_cast<int>(inputs.size())) {
73+
throw std::runtime_error(
74+
"select_as_symint: source tensor is not a graph input");
75+
}
76+
const auto& dims = tensors_[src.input_tensor_id].dims;
77+
int dim = src.dim < 0 ? src.dim + static_cast<int>(dims.size()) : src.dim;
78+
if (dim < 0 || dim >= static_cast<int>(dims.size())) {
79+
throw std::runtime_error("select_as_symint: dim out of range");
80+
}
81+
int index = src.index;
82+
if (index < 0) {
83+
index += static_cast<int>(dims[dim]);
84+
}
85+
if (index < 0 || index >= static_cast<int>(dims[dim])) {
86+
throw std::runtime_error("select_as_symint: index out of range");
87+
}
88+
int64_t numel = 1;
89+
for (int64_t d : dims) {
90+
numel *= d;
91+
}
92+
if (numel <= 0) {
93+
throw std::runtime_error("select_as_symint: empty input tensor");
94+
}
95+
int64_t stride = 1;
96+
for (size_t i = static_cast<size_t>(dim) + 1; i < dims.size(); i++) {
97+
stride *= dims[i];
98+
}
99+
// Reads the [0,..,index,..,0] element; symint sources are scalar-ish.
100+
const int64_t offset = static_cast<int64_t>(index) * stride;
101+
// elem_size back-derived from build-time numel (sources are static-shaped).
102+
const void* host = inputs[pos].first;
103+
const size_t elem_size = inputs[pos].second / static_cast<size_t>(numel);
104+
int32_t val;
105+
if (elem_size == sizeof(int64_t)) {
106+
val = static_cast<int32_t>(static_cast<const int64_t*>(host)[offset]);
107+
} else if (elem_size == sizeof(int32_t)) {
108+
val = static_cast<const int32_t*>(host)[offset];
109+
} else {
110+
throw std::runtime_error(
111+
"select_as_symint: unsupported input element size");
112+
}
113+
set_symint(src.symint_id, val);
114+
}
115+
}
116+
117+
void WebGPUGraph::set_symint(int id, int32_t val) {
118+
auto it = symints_.find(id);
119+
if (it == symints_.end()) {
120+
throw std::runtime_error("WebGPUGraph::set_symint: id is not a SymInt");
121+
}
122+
if (it->second.value != val) {
123+
it->second.value = val;
124+
wgpuQueueWriteBuffer(
125+
queue_, it->second.buffer, 0, &it->second.value, sizeof(int32_t));
126+
dirty_symints_.insert(id);
127+
}
128+
}
129+
130+
void WebGPUGraph::propagate_resize() {
131+
if (dirty_symints_.empty()) {
132+
return;
133+
}
134+
for (auto& hook : resize_hooks_) {
135+
if (dirty_symints_.count(hook.symint_id) != 0) {
136+
hook.fn(*this);
137+
}
138+
}
139+
dirty_symints_.clear();
140+
}
141+
62142
WebGPUGraph::~WebGPUGraph() {
63143
for (size_t i = 0; i < tensors_.size(); i++) {
64144
if (tensors_[i].buffer &&
@@ -76,6 +156,16 @@ WebGPUGraph::~WebGPUGraph() {
76156
wgpuBufferRelease(buf);
77157
}
78158
}
159+
for (auto& buf : owned_uniform_buffers_) {
160+
if (buf) {
161+
wgpuBufferRelease(buf);
162+
}
163+
}
164+
for (auto& kv : symints_) {
165+
if (kv.second.buffer) {
166+
wgpuBufferRelease(kv.second.buffer);
167+
}
168+
}
79169
for (auto& buf : output_staging_buffers_) {
80170
if (buf) {
81171
wgpuBufferRelease(buf);
@@ -236,6 +326,27 @@ void WebGPUGraph::build(
236326
bools_[i] = val->value_as_Bool()->bool_val();
237327
break;
238328
}
329+
case vkgraph::GraphTypes::SymInt: {
330+
// Live scalar: small Uniform buffer the CPU rewrites per execute.
331+
value_types_[i] = ValueType::SymInt;
332+
SymIntSlot slot;
333+
slot.value = static_cast<int32_t>(val->value_as_SymInt()->value());
334+
// 16B matches the backend uniform-struct alignment; int32 in first 4.
335+
constexpr size_t kSymIntUniformBytes = 16;
336+
WGPUBufferDescriptor d = {};
337+
d.size = kSymIntUniformBytes;
338+
d.usage = WGPUBufferUsage_Uniform | WGPUBufferUsage_CopyDst;
339+
d.mappedAtCreation = true;
340+
slot.buffer = wgpuDeviceCreateBuffer(device_, &d);
341+
void* mapped =
342+
wgpuBufferGetMappedRange(slot.buffer, 0, kSymIntUniformBytes);
343+
std::memset(mapped, 0, kSymIntUniformBytes);
344+
std::memcpy(mapped, &slot.value, sizeof(int32_t));
345+
wgpuBufferUnmap(slot.buffer);
346+
symints_[i] = slot;
347+
add_uniform_buffer_bytes(kSymIntUniformBytes);
348+
break;
349+
}
239350
default:
240351
value_types_[i] = ValueType::Null;
241352
break;

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111
#include <webgpu/webgpu.h>
1212

1313
#include <cstdint>
14+
#include <functional>
1415
#include <string>
1516
#include <unordered_map>
17+
#include <unordered_set>
1618
#include <vector>
1719

1820
#include <executorch/runtime/core/named_data_map.h>
@@ -104,6 +106,52 @@ class WebGPUGraph {
104106
return ints_[id];
105107
}
106108

109+
// Live-scalar (SymInt) API; mirrors the Vulkan SymInt/ParamsBuffer UBO.
110+
// set_symint writes the buffer + marks dirty only if the value changed.
111+
void set_symint(int id, int32_t val);
112+
// read_symint throws (fail-loud) if id is not a SymInt.
113+
int32_t read_symint(int id) const {
114+
return symints_.at(id).value;
115+
}
116+
// symint_buffer throws (fail-loud) if id is not a SymInt.
117+
WGPUBuffer symint_buffer(int id) const {
118+
return symints_.at(id).buffer;
119+
}
120+
121+
// Records that a SymInt's value is read from input_tensor[index] along dim.
122+
struct SymIntSource {
123+
int symint_id;
124+
int input_tensor_id;
125+
int dim;
126+
int index;
127+
};
128+
void
129+
add_symint_source(int symint_id, int input_tensor_id, int dim, int index) {
130+
symint_sources_.push_back({symint_id, input_tensor_id, dim, index});
131+
}
132+
const std::vector<SymIntSource>& symint_sources() const {
133+
return symint_sources_;
134+
}
135+
136+
// Execute-time select_as_symint read; mirrors Vulkan select_as_symint_impl.
137+
void update_symints_from_inputs(
138+
const std::vector<std::pair<const void*, size_t>>& inputs);
139+
140+
// Per-SymInt resize hook; mirrors Vulkan DynamicDispatchNode::trigger_resize.
141+
void add_resize_hook(int symint_id, std::function<void(WebGPUGraph&)> fn) {
142+
resize_hooks_.push_back({symint_id, std::move(fn)});
143+
}
144+
// Run hooks for changed SymInts then clear; call before execute().
145+
void propagate_resize();
146+
147+
// Mutable dispatch access for resize hooks (to rewrite workgroup_count_x).
148+
WebGPUDispatch& dispatch_at(size_t i) {
149+
return dispatches_[i];
150+
}
151+
size_t num_dispatches() const {
152+
return dispatches_.size();
153+
}
154+
107155
WGPUDevice device() const {
108156
return device_;
109157
}
@@ -119,6 +167,11 @@ class WebGPUGraph {
119167
uniform_buffer_bytes_ += bytes;
120168
}
121169

170+
// Keep a uniform alive for the graph's lifetime; released in the dtor.
171+
void own_uniform_buffer(WGPUBuffer buffer) {
172+
owned_uniform_buffers_.push_back(buffer);
173+
}
174+
122175
// Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
123176
WGPUBuffer create_scratch_buffer(size_t nbytes);
124177

@@ -149,7 +202,7 @@ class WebGPUGraph {
149202
return static_cast<int>(value_types_.size());
150203
}
151204

152-
enum class ValueType { Tensor, Int, Double, Bool, Null, String };
205+
enum class ValueType { Tensor, Int, Double, Bool, Null, String, SymInt };
153206

154207
ValueType get_value_type(int id) const {
155208
return value_types_[id];
@@ -168,6 +221,22 @@ class WebGPUGraph {
168221
std::vector<double> doubles_;
169222
std::vector<bool> bools_;
170223

224+
// SymInt (live scalar): id -> {live Uniform buffer, current value}, sparse.
225+
struct SymIntSlot {
226+
WGPUBuffer buffer = nullptr;
227+
int32_t value = 0;
228+
};
229+
std::unordered_map<int, SymIntSlot> symints_;
230+
std::vector<SymIntSource> symint_sources_;
231+
232+
// Resize hooks + the set of SymInts changed since the last propagate_resize.
233+
struct ResizeHook {
234+
int symint_id;
235+
std::function<void(WebGPUGraph&)> fn;
236+
};
237+
std::vector<ResizeHook> resize_hooks_;
238+
std::unordered_set<int> dirty_symints_;
239+
171240
std::vector<int> input_ids_;
172241
std::vector<int> output_ids_;
173242

@@ -179,6 +248,9 @@ class WebGPUGraph {
179248
// Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries).
180249
std::vector<WGPUBuffer> scratch_buffers_;
181250

251+
// Uniform buffers owned for the graph's lifetime; released in the dtor.
252+
std::vector<WGPUBuffer> owned_uniform_buffers_;
253+
182254
// Staging buffers for reading back outputs (MapRead | CopyDst).
183255
std::vector<WGPUBuffer> output_staging_buffers_;
184256

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/webgpu/runtime/WebGPUGraph.h>
10+
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
11+
12+
#include <algorithm>
13+
#include <stdexcept>
14+
15+
namespace executorch::backends::webgpu {
16+
17+
namespace {
18+
19+
// et_vk.select_as_symint: out SymInt = x[index] along dim; read at execute.
20+
void select_as_symint_impl(WebGPUGraph& graph, const std::vector<int>& args) {
21+
const int x_id = args.at(0);
22+
const int dim_id = args.at(1);
23+
const int index_id = args.at(2);
24+
const int out_id = args.at(3);
25+
26+
if (graph.get_value_type(out_id) != WebGPUGraph::ValueType::SymInt) {
27+
throw std::runtime_error("select_as_symint: output is not a SymInt");
28+
}
29+
const std::vector<int>& inputs = graph.input_ids();
30+
if (std::find(inputs.begin(), inputs.end(), x_id) == inputs.end()) {
31+
throw std::runtime_error(
32+
"select_as_symint: source tensor is not a graph input");
33+
}
34+
graph.add_symint_source(
35+
out_id,
36+
x_id,
37+
static_cast<int>(graph.get_int(dim_id)),
38+
static_cast<int>(graph.get_int(index_id)));
39+
}
40+
41+
} // namespace
42+
43+
WEBGPU_REGISTER_OPERATORS {
44+
WEBGPU_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint_impl);
45+
}
46+
47+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)