Skip to content

Commit 6fa26d4

Browse files
[ExecuTorch][WebGPU] Per-pass compute dispatch ordering for fused multi-dispatch ops
Pull Request resolved: pytorch#20072

WebGPU has no write->read ordering between dispatches in a single compute pass, so a fused multi-dispatch op (SDPA) can read stale writes. Record one compute pass per dispatch in `execute()` (both the full and ranged paths) -- the pass boundary is WebGPU's implicit barrier (there is no equivalent of `vkCmdPipelineBarrier`). Single-dispatch ops are unchanged.

Also switches `WebGPUGraph.cpp` and `WebGPUGraph.h` to the C++17 nested namespace definition syntax (`namespace executorch::backends::webgpu`).

Consumed by the fused SDPA op above.

ghstack-source-id: 391378799
@exported-using-ghexport

Differential Revision: [D107543258](https://our.internmc.facebook.com/intern/diff/D107543258/)
1 parent 3dcb1c4 commit 6fa26d4

2 files changed

Lines changed: 15 additions & 26 deletions

File tree

backends/webgpu/runtime/WebGPUGraph.cpp

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@
1818
#include <cstring>
1919
#include <stdexcept>
2020

21-
namespace executorch {
22-
namespace backends {
23-
namespace webgpu {
21+
namespace executorch::backends::webgpu {
2422

2523
// vkgraph namespace is declared at global scope in the generated FlatBuffer
2624
// header
@@ -380,21 +378,20 @@ void WebGPUGraph::execute() {
380378
WGPUCommandEncoder encoder =
381379
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);
382380

383-
WGPUComputePassDescriptor pass_desc = {};
384-
WGPUComputePassEncoder pass =
385-
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
386-
381+
// One pass per dispatch: enforces storage RAW ordering across deps.
387382
for (const auto& dispatch : dispatches_) {
383+
WGPUComputePassDescriptor pass_desc = {};
384+
WGPUComputePassEncoder pass =
385+
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
388386
wgpuComputePassEncoderSetPipeline(pass, dispatch.pipeline);
389387
wgpuComputePassEncoderSetBindGroup(
390388
pass, 0, dispatch.bind_group, 0, nullptr);
391389
wgpuComputePassEncoderDispatchWorkgroups(
392390
pass, dispatch.workgroup_count_x, 1, 1);
391+
wgpuComputePassEncoderEnd(pass);
392+
wgpuComputePassEncoderRelease(pass);
393393
}
394394

395-
wgpuComputePassEncoderEnd(pass);
396-
wgpuComputePassEncoderRelease(pass);
397-
398395
for (const auto& copy : output_copies_) {
399396
wgpuCommandEncoderCopyBufferToBuffer(
400397
encoder, copy.src_buffer, 0, copy.staging_buffer, 0, copy.nbytes);
@@ -423,21 +420,19 @@ void WebGPUGraph::execute() {
423420
WGPUCommandEncoder encoder =
424421
wgpuDeviceCreateCommandEncoder(device_, &enc_desc);
425422

426-
WGPUComputePassDescriptor pass_desc = {};
427-
WGPUComputePassEncoder pass =
428-
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
429-
430423
for (size_t i = start; i < end; i++) {
424+
WGPUComputePassDescriptor pass_desc = {};
425+
WGPUComputePassEncoder pass =
426+
wgpuCommandEncoderBeginComputePass(encoder, &pass_desc);
431427
wgpuComputePassEncoderSetPipeline(pass, dispatches_[i].pipeline);
432428
wgpuComputePassEncoderSetBindGroup(
433429
pass, 0, dispatches_[i].bind_group, 0, nullptr);
434430
wgpuComputePassEncoderDispatchWorkgroups(
435431
pass, dispatches_[i].workgroup_count_x, 1, 1);
432+
wgpuComputePassEncoderEnd(pass);
433+
wgpuComputePassEncoderRelease(pass);
436434
}
437435

438-
wgpuComputePassEncoderEnd(pass);
439-
wgpuComputePassEncoderRelease(pass);
440-
441436
if (end == n) {
442437
for (const auto& copy : output_copies_) {
443438
wgpuCommandEncoderCopyBufferToBuffer(
@@ -545,6 +540,4 @@ WebGPUMemoryStats WebGPUGraph::memory_stats() const {
545540
return stats;
546541
}
547542

548-
} // namespace webgpu
549-
} // namespace backends
550-
} // namespace executorch
543+
} // namespace executorch::backends::webgpu

backends/webgpu/runtime/WebGPUGraph.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@
1717

1818
#include <executorch/runtime/core/named_data_map.h>
1919

20-
namespace executorch {
21-
namespace backends {
22-
namespace webgpu {
20+
namespace executorch::backends::webgpu {
2321

2422
struct WebGPUTensor {
2523
WGPUBuffer buffer = nullptr;
@@ -193,6 +191,4 @@ class WebGPUGraph {
193191
size_t uniform_buffer_bytes_ = 0;
194192
};
195193

196-
} // namespace webgpu
197-
} // namespace backends
198-
} // namespace executorch
194+
} // namespace executorch::backends::webgpu

0 commit comments

Comments
 (0)