[webgpu] Per-graph buffer manager for WebGPU multi-graph capture#28260
[webgpu] Per-graph buffer manager for WebGPU multi-graph capture#28260qjia7 wants to merge 32 commits into
Conversation
Replace the single shared graph_buffer_mgr_ with per-annotation-ID buffer managers so that multiple generators with different prompts each get isolated buffer caches. This prevents cross-generator buffer corruption when graph capture is enabled. Changes: - Add per_graph_buffer_mgrs_ map keyed by graph annotation ID - Route BufferManager() to the correct per-graph instance during capture/replay - Change GpuBufferAllocator from const BufferManager& to pointer with SetBufferManager() for dynamic routing - Add ReleaseGraph API through the full stack (EP base class, C API, InferenceSession, plugin EP, EP adapter) so callers can free captured graph resources when a generator is destroyed - WebGPU EP ReleaseGraph cleans up captured commands, buffer manager, and tracking state for the given annotation ID
1b08292 to
bc5cfdb
Compare
… pattern Use RefreshBufferManager to update the cached buffer manager pointer from a getter lambda, keeping Alloc/Free as simple pointer dereferences while routing logic stays in EP::BufferManager() as the single source of truth. Remove graph_default_buffer_mgr_ and set min_num_runs_before_graph_capture_ to 0.
Allocators constructed with a direct BufferManager reference store a plain pointer without wrapping it in a std::function. RefreshBufferManager becomes a no-op for these allocators, avoiding unnecessary lambda overhead for initializer and shared allocators that never need dynamic routing.
There was a problem hiding this comment.
Pull request overview
This PR extends WebGPU graph-capture support to safely handle multiple captured graphs (keyed by graph annotation ID) and adds a new end-to-end API to explicitly release captured graph resources (captured command lists + GPU buffers) when a generator/session is torn down. It aims to prevent cross-contamination between sequential generators/sessions that capture and replay different graphs.
Changes:
- Introduces per-annotation-ID captured command storage and per-graph buffer managers in the WebGPU EP to isolate graph-capture state.
- Adds a
ReleaseGraph(graph_annotation_id)API through the ORT stack (EP base → plugin EP adapter → InferenceSession → C API). - Updates WebGPU allocators to support dynamically routing allocations to the active per-graph buffer manager using a cached-pointer + refresh pattern.
Reviewed changes
Copilot reviewed 17 out of 17 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Adds ReleaseGraph override to plugin EP wrapper. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Implements plugin EP ReleaseGraph with version/NULL checks and fallback. |
| onnxruntime/core/session/ort_version_check.h | Adds a comment explaining a constexpr/MSVC limitation. |
| onnxruntime/core/session/ort_apis.h | Declares OrtApis::SessionReleaseGraph. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Implements SessionReleaseGraph and wires it into the exported OrtApi table. |
| onnxruntime/core/session/inference_session.h | Adds InferenceSession::ReleaseGraph API surface. |
| onnxruntime/core/session/inference_session.cc | Implements InferenceSession::ReleaseGraph forwarding to cached EP. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Replaces single-graph capture state with per-graph maps/sets; adds ReleaseGraph; adds allocator routing state. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Implements per-graph capture/replay/release and allocator buffer-manager routing. |
| onnxruntime/core/providers/webgpu/ep/factory.cc | Switches adapter-path device allocator to use a buffer-manager getter lambda. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Adds EP-adapter function pointer for ReleaseGraph. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires ReleaseGraph into the EP-adapter vtable. |
| onnxruntime/core/providers/webgpu/allocator.h | Adds getter-based allocator constructor and RefreshBufferManager. |
| onnxruntime/core/providers/webgpu/allocator.cc | Implements cached-pointer refresh logic and switches Alloc/Free to pointer dereferences. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Extends OrtEp C API with ReleaseGraph. |
| include/onnxruntime/core/session/onnxruntime_c_api.h | Extends OrtApi C API with SessionReleaseGraph. |
| include/onnxruntime/core/framework/execution_provider.h | Adds default virtual IExecutionProvider::ReleaseGraph. |
Comments suppressed due to low confidence (2)
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc:786
- The cached
BufferManagerrefresh happens beforem_current_graph_annotation_idis updated for this run. SinceRefreshBufferManager()consultsWebGpuExecutionProvider::BufferManager()(which usesm_current_graph_annotation_id), the allocator can cache the previous run’s per-graph buffer manager and route allocations to the wrong graph. Setm_current_graph_annotation_id(and any other state thatBufferManager()depends on) before callingRefreshBufferManager().
graph_buffer_mgr_active_ = true;
if (default_gpu_allocator_) {
default_gpu_allocator_->RefreshBufferManager();
}
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc:793
IsGraphCaptureAllowed()is called beforem_current_graph_annotation_idis set tograph_annotation_id, butIsGraphCaptureAllowed()readsm_current_graph_annotation_idto look up the per-graph run count. This means the capture decision can be made using the previous graph’s run count instead of the current graph’s. Updatem_current_graph_annotation_idbefore this check (or changeIsGraphCaptureAllowedto takegraph_annotation_id).
if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
auto& commands = captured_graphs_[graph_annotation_id];
context_.CaptureBegin(&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
}
m_current_graph_annotation_id = graph_annotation_id;
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Rename SessionReleaseGraph to SessionReleaseCapturedGraph across the full API chain for clarity. Fix API versioning static_assert for version 26, correct \since tag to 1.27, fix destructor to iterate captured_graphs_ for partial capture cleanup, wire SetDefaultGpuAllocator in adapter path for RefreshBufferManager, add clarifying comment for plugin EP fallback, and add TestReleaseCapturedGraph test in io_binding_test.cc.
Add output verification (GPU→CPU copy + VerifySingleOutput) after each capture and replay run. Call ReleaseCapturedGraph at test end to clean up the re-captured graph. Fix nodiscard warning on AddConfigEntry.
Use MatMul→Relu model so the intermediate tensor exercises the per-graph buffer manager. Change input A between capture and replay runs to verify replay re-executes rather than returning stale output.
Replace two-op model with MatMul→Relu→MatMul to create two intermediate tensors that exercise the per-graph buffer manager more thoroughly.
…in test - Use generic wording in API doc comments (remove GPU-specific references) - Rename m_current_graph_annotation_id to current_graph_annotation_id_ - Use .contains() instead of .count() > 0 - Use try_emplace to avoid double lookup in OnRunStart - Remove allocator caching: single constructor with getter lambda, eliminating RefreshBufferManager, default_gpu_allocator_, and SetDefaultGpuAllocator - Rewrite graph_capture_test.cc using public APIs (Model Editor, Ort::Session, Ort::IoBinding, CopyTensors, SessionReleaseCapturedGraph)
…void redundant lookup
The plugin build defines USE_WEBGPU but cannot instantiate the WebGPU EP at runtime, causing TestReleaseCapturedGraph to fail. Guard the test with !defined(ORT_USE_EP_API_ADAPTERS) so it only compiles for built-in WebGPU EP builds. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add WebGpuPluginRegistration RAII helper and AppendWebGpuEp abstraction so TestReleaseCapturedGraph works in plugin builds (ORT_USE_EP_API_ADAPTERS) by dynamically registering the WebGPU plugin library and using the V2 session options API with EpDevice discovery.
Replace ORT_UNUSED_PARAMETER(env) with (void)env to avoid dependency on ORT internal macros not included in this test file.
…quested The C++ WebGPU EP ApplyFlashAttention does not implement CPU EP semantics for past key/value: when no present_key/present_value outputs are requested, past inputs should be ignored (matching CPU EP behavior where past_sequence_length=0 is used in this case). The non-flash ApplyAttention already handled this correctly at line 603-604. This fix adds the same check in multihead_attention.cc before dispatching to either path, ensuring consistent behavior. This fixes the failing CI test: "[webgpu]MultiHeadAttention - MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue" which expected output [9,10,11,12] but got ~[17,18,19,20] because past values dominated attention when flash attention was used. Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/2e159834-89db-46b3-989b-e6effe3d08a1 Co-authored-by: qjia7 <4221210+qjia7@users.noreply.github.com>
Replace GPU-specific terms with generic descriptions per review feedback, keeping GPU commands as an example.
| * | ||
| * \since Version 1.27. | ||
| */ | ||
| ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); |
There was a problem hiding this comment.
do we need to account for calls that may happen concurrently with Ort::Session::Run()? in general (though not for the WebGPU EP), Run() itself may be called concurrently.
There was a problem hiding this comment.
Thread safety for concurrent ReleaseCapturedGraph + Run() is an EP implementation detail rather than a C API contract concern. For WebGPU EP, Run() itself is not concurrent (single GPU queue), so this is not a realistic scenario. For EPs that do support concurrent Run() (e.g., CUDA), the EP would need to handle synchronization internally in its ReleaseCapturedGraph implementation — same as other EP callbacks. No other session-mutating APIs (e.g., SetEpDynamicOptions) document thread-safety constraints at the C API level, so adding one here would be inconsistent.
There was a problem hiding this comment.
if an EP reports that ConcurrentRunSupported() is false, ORT will lock a mutex during the call to Ort::Session::Run() to ensure that only one actual run is happening at a time. this doesn't prevent users from calling Run() concurrently.
it may not make sense for an application to make concurrent calls to Run() + ReleaseCapturedGraph() if the WebGPU EP is the only EP to consider, but an application might support more than just the WebGPU EP.
I don't have a good answer to this now, but I think it may be worth some thought. if we provide no synchronization in ORT or EP code, I think the assumption should at least be documented somewhere.
There was a problem hiding this comment.
Done. ReleaseCapturedGraph now acquires session_mutex_ before delegating to the EP. For EPs like WebGPU where ConcurrentRunSupported() returns false, this provides mutual exclusion with Run() since Run() also acquires session_mutex_ in that case. For EPs that support concurrent runs, Run() does not acquire session_mutex_, so additional synchronization would need to be handled at the EP level if needed.
…gths - Add Session::ReleaseCapturedGraph() C++ wrapper and use it in test - Replace raw C API CopyTensors calls with Env::CopyTensor in test - Unify doc wording across all locations to use generic "resources" - Fix lines exceeding 120 chars in allocator, factory, and EP header
Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/c9a7307c-7dc5-4fde-bd62-0c62552730cd Co-authored-by: qjia7 <4221210+qjia7@users.noreply.github.com>
The test created its own Env and registered the WebGPU plugin DLL under a separate name. On cleanup, UnregisterExecutionProviderLibrary unloaded the DLL, leaving the global test infrastructure with dangling pointers. Subsequent tests crashed with SEH 0xc0000005 in ParseEpConfig. Use the global ort_env which already has the plugin registered by the test infrastructure, avoiding the duplicate registration and unload.
|
@skottmckay @edgchen1 All checks have passed. Please take another look when you have time, thanks. |
| if (present_key == nullptr && present_value == nullptr) { | ||
| past_key = nullptr; | ||
| past_value = nullptr; | ||
| parameters.past_sequence_length_ = 0; | ||
| parameters.total_sequence_length_ = parameters.kv_sequence_length_; | ||
| } |
There was a problem hiding this comment.
was this change meant to be a part of this PR? it seems unrelated.
There was a problem hiding this comment.
Yes, this is intentional. It fixes the CI test [webgpu]MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue. The fix matches CPU EP semantics: when no present_key/present_value output is requested, past_key/past_value should be ignored. See commit d0369a5 for details.
| static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); | ||
| static_assert(offsetof(OrtApi, SetPerSessionThreadPoolCallbacks) / sizeof(void*) == 418, "Size of version 25 API cannot change"); | ||
| // no additions in version 26 | ||
| static_assert(offsetof(OrtApi, SessionReleaseCapturedGraph) / sizeof(void*) == 421, "Size of version 27 API cannot change"); |
There was a problem hiding this comment.
don't need to do this now as 1.27 is not finalized yet.
There was a problem hiding this comment.
Done. Removed the static_assert.
| ASSERT_FALSE(webgpu_devices.empty()) << "No WebGPU EP device found after plugin registration"; | ||
| session_options.AppendExecutionProvider_V2(env, webgpu_devices, provider_options); | ||
| #else | ||
| (void)env; |
There was a problem hiding this comment.
| (void)env; | |
| static_cast<void>(env); |
| ORT_ENFORCE(key.length() >= prefix.length() && key.substr(0, prefix.length()) == prefix, | ||
| "Config key \"", key, "\" does not start with expected prefix \"", prefix, "\""); | ||
| return full_key + prefix.length(); | ||
| if (normalized_key.rfind(prefix, 0) == 0) { |
There was a problem hiding this comment.
we can use std::string::starts_with in C++20
There was a problem hiding this comment.
Done. Replaced rfind(prefix, 0) == 0 with starts_with(prefix) at both locations.
| if (normalized_key.rfind(prefix, 0) == 0) { | ||
| normalized_key.erase(0, prefix.length()); | ||
| } | ||
| ORT_ENFORCE(normalized_config_options.AddConfigEntry(normalized_key.c_str(), value.c_str()).IsOK()); |
There was a problem hiding this comment.
nit: ORT_ENFORCE(status.IsOK()) -> ORT_THROW_IF_ERROR(status). this will also include status error information.
There was a problem hiding this comment.
Done. Replaced all three ORT_ENFORCE(...IsOK()) calls with ORT_THROW_IF_ERROR(...).
…ixes - Acquire session_mutex_ in ReleaseCapturedGraph to prevent concurrent access with Run() - Remove premature version 27 static_assert - Use static_cast<void> instead of C-style void cast - Use std::string::starts_with (C++20) instead of rfind workaround - Use ORT_THROW_IF_ERROR instead of ORT_ENFORCE for better error info
Summary
Add per-annotation-ID buffer managers and captured command storage so multiple generators can each capture and replay their own graph independently without cross-contamination
Add ReleaseGraph API through the full ORT stack (EP base → C API → InferenceSession → plugin EP) to release captured commands and GPU buffers when a generator is destroyed
Replace the single graph_buffer_mgr_ / is_graph_captured_ bool with per_graph_buffer_mgrs_ map and captured_graph_ids_ set keyed by annotation ID
Use a std::function getter with cached pointer pattern in GpuBufferAllocator to dynamically route allocations to the active per-graph buffer manager during runs, while keeping Alloc/Free as simple pointer dereferences
Motivation
Edge's Prompt API speed benchmark creates multiple sessions/generators sequentially with graph capture enabled. With the existing single-graph design, the second generator replays the first generator's captured commands with wrong buffers, producing incorrect output and ultimately a QuotaExceededError in the browser. This PR isolates each generator's graph capture state so they don't interfere with each other.
Related PR
The GenAI side change is in microsoft/onnxruntime-genai#2106, which calls SessionReleaseGraph when a generator is destroyed to release the captured graph's GPU buffers.