Skip to content

Commit af92b60

Browse files
[ExecuTorch][WebGPU] Graph-owned scratch buffers for fused-op intermediates
Pull Request resolved: pytorch#20073 Add `WebGPUGraph::create_scratch_buffer` for fused-op intermediates (SDPA's `attn_weights`/`attn_weights_softmax`) that are not model tensors and live only between dispatches. Graph-owned, released in the destructor. Vulkan models these as graph tensors; we use raw buffers (buffer-only backend). Consumed by the fused SDPA op above. ghstack-source-id: 391378805 @exported-using-ghexport Differential Revision: [D107543259](https://our.internmc.facebook.com/intern/diff/D107543259/)
1 parent 6fa26d4 commit af92b60

2 files changed

Lines changed: 22 additions & 0 deletions

File tree

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,17 @@ size_t vk_datatype_size(vkgraph::VkDataType dtype) {
4848

4949
WebGPUGraph::WebGPUGraph() = default;
5050

51+
WGPUBuffer WebGPUGraph::create_scratch_buffer(size_t nbytes) {
52+
WGPUBufferDescriptor buf_desc = {};
53+
buf_desc.size = nbytes > 0 ? nbytes : 4;
54+
buf_desc.usage = WGPUBufferUsage_Storage | WGPUBufferUsage_CopyDst |
55+
WGPUBufferUsage_CopySrc;
56+
buf_desc.mappedAtCreation = false;
57+
WGPUBuffer buffer = wgpuDeviceCreateBuffer(device_, &buf_desc);
58+
scratch_buffers_.push_back(buffer);
59+
return buffer;
60+
}
61+
5162
WebGPUGraph::~WebGPUGraph() {
5263
for (size_t i = 0; i < tensors_.size(); i++) {
5364
if (tensors_[i].buffer &&
@@ -60,6 +71,11 @@ WebGPUGraph::~WebGPUGraph() {
6071
wgpuBufferRelease(buf);
6172
}
6273
}
74+
for (auto& buf : scratch_buffers_) {
75+
if (buf) {
76+
wgpuBufferRelease(buf);
77+
}
78+
}
6379
for (auto& buf : output_staging_buffers_) {
6480
if (buf) {
6581
wgpuBufferRelease(buf);

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,9 @@ class WebGPUGraph {
119119
uniform_buffer_bytes_ += bytes;
120120
}
121121

122+
// Graph-owned scratch storage buffer for fused-op intermediates (e.g. SDPA).
123+
WGPUBuffer create_scratch_buffer(size_t nbytes);
124+
122125
WGPUShaderModule get_or_create_shader(
123126
const std::string& key,
124127
const char* wgsl_source);
@@ -173,6 +176,9 @@ class WebGPUGraph {
173176
std::vector<WGPUBuffer> shared_buffers_;
174177
std::vector<size_t> shared_buffer_sizes_;
175178

179+
// Long-lived scratch storage buffers for fused ops (e.g. SDPA temporaries).
180+
std::vector<WGPUBuffer> scratch_buffers_;
181+
176182
// Staging buffers for reading back outputs (MapRead | CopyDst).
177183
std::vector<WGPUBuffer> output_staging_buffers_;
178184

0 commit comments

Comments
 (0)