Skip to content

Commit 5dd66ad

Browse files
[ExecuTorch][WebGPU] Upload named-data constants in WebGPUGraph
The Vulkan serializer, which the WebGPU backend reuses, stores every non-empty constant in the PTE's named-data map, marking it with `offset == UINT64_MAX` and a `named_key`, rather than storing it inline in the VK00 blob. `WebGPUGraph::build` previously handled only inline constants, so a delegated op's constant weights were never uploaded and the op produced all zeros.

`build` now also fetches named-data constants via `NamedDataMap::get_data`, mirroring the path `VulkanBackend` already uses. `aten.add` was unaffected since it has no constant tensors; the first consumer is the `rms_norm` op in the child diff.

Differential Revision: [D107288998](https://our.internmc.facebook.com/intern/diff/D107288998/)
ghstack-source-id: 389182397
Pull-Request: pytorch#19962
1 parent 89aed7b commit 5dd66ad

3 files changed

Lines changed: 35 additions & 3 deletions

File tree

backends/webgpu/runtime/WebGPUBackend.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ Result<DelegateHandle*> WebGPUBackend::init(
7676
}
7777

7878
try {
79-
graph->build(flatbuffer_data, constant_data);
79+
graph->build(flatbuffer_data, constant_data, context.get_named_data_map());
8080
} catch (const std::exception& e) {
8181
ET_LOG(Error, "WebGPU graph build failed: %s", e.what());
8282
graph->~WebGPUGraph();

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 28 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
#include <executorch/backends/webgpu/runtime/ops/OperatorRegistry.h>
1111

1212
#include <executorch/backends/vulkan/serialization/schema_generated.h>
13+
#include <executorch/runtime/core/named_data_map.h>
1314

1415
#include <executorch/backends/webgpu/runtime/WebGPUDevice.h>
1516
#include <webgpu/wgpu.h>
@@ -93,7 +94,8 @@ WebGPUGraph::~WebGPUGraph() {
9394

9495
void WebGPUGraph::build(
9596
const void* flatbuffer_data,
96-
const uint8_t* constant_data) {
97+
const uint8_t* constant_data,
98+
const executorch::runtime::NamedDataMap* named_data_map) {
9799
if (!device_) {
98100
auto* ctx = get_default_webgpu_context();
99101
if (ctx) {
@@ -165,6 +167,31 @@ void WebGPUGraph::build(
165167
const uint8_t* src = constant_data + vk_bytes->offset();
166168
wgpuQueueWriteBuffer(
167169
queue_, tensor.buffer, 0, src, tensor.nbytes);
170+
} else if (
171+
vk_bytes->named_key() != nullptr &&
172+
named_data_map != nullptr) {
173+
// Constant stored in the PTE named-data map.
174+
auto buf =
175+
named_data_map->get_data(vk_bytes->named_key()->c_str());
176+
if (!buf.ok()) {
177+
throw std::runtime_error(
178+
std::string("WebGPU: named constant '") +
179+
vk_bytes->named_key()->c_str() +
180+
"' not found in NamedDataMap");
181+
}
182+
if (buf->size() < tensor.nbytes) {
183+
throw std::runtime_error(
184+
std::string("WebGPU: named constant '") +
185+
vk_bytes->named_key()->c_str() + "' undersized: have " +
186+
std::to_string(buf->size()) + " bytes, need " +
187+
std::to_string(tensor.nbytes));
188+
}
189+
wgpuQueueWriteBuffer(
190+
queue_, tensor.buffer, 0, buf->data(), tensor.nbytes);
191+
buf->Free();
192+
} else {
193+
throw std::runtime_error(
194+
"WebGPU: constant has no inline offset and no named-data key");
168195
}
169196
}
170197
}

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
#include <unordered_map>
1616
#include <vector>
1717

18+
#include <executorch/runtime/core/named_data_map.h>
19+
1820
namespace executorch {
1921
namespace backends {
2022
namespace webgpu {
@@ -66,7 +68,10 @@ class WebGPUGraph {
6668

6769
// Build the graph from a deserialized VkGraph flatbuffer and constant data.
6870
// The flatbuffer_data pointer must remain valid during build().
69-
void build(const void* flatbuffer_data, const uint8_t* constant_data);
71+
void build(
72+
const void* flatbuffer_data,
73+
const uint8_t* constant_data,
74+
const executorch::runtime::NamedDataMap* named_data_map = nullptr);
7075

7176
// Copy input tensor data from host pointers into GPU buffers.
7277
void copy_inputs(const std::vector<std::pair<const void*, size_t>>& inputs);

0 commit comments

Comments
 (0)