
[webgpu] Per-graph buffer manager for WebGPU multi-graph capture #28260

Open
qjia7 wants to merge 32 commits into main from per-graph-buffer-manager-webgpu

Contributor

@qjia7 qjia7 commented Apr 29, 2026

Summary

- Add per-annotation-ID buffer managers and captured command storage so multiple generators can each capture and replay their own graph independently without cross-contamination
- Add ReleaseGraph API through the full ORT stack (EP base → C API → InferenceSession → plugin EP) to release captured commands and GPU buffers when a generator is destroyed
- Replace the single graph_buffer_mgr_ / is_graph_captured_ bool with per_graph_buffer_mgrs_ map and captured_graph_ids_ set keyed by annotation ID
- Use a std::function getter with cached pointer pattern in GpuBufferAllocator to dynamically route allocations to the active per-graph buffer manager during runs, while keeping Alloc/Free as simple pointer dereferences
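
The per-annotation-ID bookkeeping in the summary can be sketched roughly as follows. This is a simplified illustration, not the code in this PR: the member names per_graph_buffer_mgrs_, captured_graph_ids_ and captured_graphs_ come from the PR text, while BufferManager, CapturedCommand, GraphCaptureState and every method body are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

// Hypothetical stand-in for the EP's buffer manager: it only counts buffers.
struct BufferManager {
  std::size_t live_buffers = 0;
  void Allocate() { ++live_buffers; }
};

// Hypothetical stand-in for one recorded GPU command.
struct CapturedCommand {
  int op_id;
};

// Hypothetical container for per-graph capture state, keyed by annotation ID.
class GraphCaptureState {
 public:
  // Returns the buffer manager for a graph, creating it on first use.
  BufferManager& BufferManagerFor(int graph_annotation_id) {
    auto [it, inserted] = per_graph_buffer_mgrs_.try_emplace(graph_annotation_id);
    if (inserted) {
      it->second = std::make_unique<BufferManager>();
    }
    return *it->second;
  }

  // Stores the captured commands for a graph and marks it as captured.
  void Capture(int graph_annotation_id, std::vector<CapturedCommand> commands) {
    captured_graphs_[graph_annotation_id] = std::move(commands);
    captured_graph_ids_.insert(graph_annotation_id);
  }

  bool IsGraphCaptured(int graph_annotation_id) const {
    return captured_graph_ids_.count(graph_annotation_id) > 0;
  }

  // Drops commands, buffer manager and tracking state for one graph only.
  void ReleaseGraph(int graph_annotation_id) {
    captured_graphs_.erase(graph_annotation_id);
    per_graph_buffer_mgrs_.erase(graph_annotation_id);
    captured_graph_ids_.erase(graph_annotation_id);
  }

  std::size_t NumBufferManagers() const { return per_graph_buffer_mgrs_.size(); }

 private:
  std::unordered_map<int, std::unique_ptr<BufferManager>> per_graph_buffer_mgrs_;
  std::unordered_map<int, std::vector<CapturedCommand>> captured_graphs_;
  std::unordered_set<int> captured_graph_ids_;
};
```

Because every lookup is keyed by annotation ID, releasing one generator's graph leaves the other generators' commands and buffers untouched.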

Motivation

Edge's Prompt API speed benchmark creates multiple sessions/generators sequentially with graph capture enabled. With the existing single-graph design, the second generator replays the first generator's captured commands with wrong buffers, producing incorrect output and ultimately a QuotaExceededError in the browser. This PR isolates each generator's graph capture state so they don't interfere with each other.

Related PR

The GenAI side change is in microsoft/onnxruntime-genai#2106, which calls SessionReleaseGraph when a generator is destroyed to release the captured graph's GPU buffers.

Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/ep/factory.cc Outdated
qjia7 added 2 commits May 7, 2026 14:13
Replace the single shared graph_buffer_mgr_ with per-annotation-ID buffer
managers so that multiple generators with different prompts each get isolated
buffer caches. This prevents cross-generator buffer corruption when graph
capture is enabled.

Changes:
- Add per_graph_buffer_mgrs_ map keyed by graph annotation ID
- Route BufferManager() to the correct per-graph instance during capture/replay
- Change GpuBufferAllocator from const BufferManager& to pointer with
  SetBufferManager() for dynamic routing
- Add ReleaseGraph API through the full stack (EP base class, C API,
  InferenceSession, plugin EP, EP adapter) so callers can free captured
  graph resources when a generator is destroyed
- WebGPU EP ReleaseGraph cleans up captured commands, buffer manager,
  and tracking state for the given annotation ID
… pattern

Use RefreshBufferManager to update the cached buffer manager pointer from a
getter lambda, keeping Alloc/Free as simple pointer dereferences while routing
logic stays in EP::BufferManager() as the single source of truth. Remove
graph_default_buffer_mgr_ and set min_num_runs_before_graph_capture_ to 0.
Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Outdated
qjia7 added 2 commits May 8, 2026 15:20
Allocators constructed with a direct BufferManager reference store a
plain pointer without wrapping it in a std::function. RefreshBufferManager
becomes a no-op for these allocators, avoiding unnecessary lambda overhead
for initializer and shared allocators that never need dynamic routing.
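
The two construction paths in these commit messages (getter lambda versus direct reference) might look roughly like the sketch below. It illustrates the cached-pointer pattern as it stood at this point in the PR; a later commit in this thread removes the caching. GpuBufferAllocator and RefreshBufferManager are names from the PR, but the BufferManager struct, the members and the method bodies are assumptions made for illustration.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical stand-in for the EP's buffer manager.
struct BufferManager {
  int id;
  int alloc_count = 0;
};

class GpuBufferAllocator {
 public:
  // Static routing: store a plain pointer, no std::function involved.
  explicit GpuBufferAllocator(BufferManager& mgr) : cached_(&mgr) {}

  // Dynamic routing: keep the getter and cache its current result.
  explicit GpuBufferAllocator(std::function<BufferManager&()> getter)
      : getter_(std::move(getter)), cached_(&getter_()) {}

  // Re-evaluates the getter; a no-op when the allocator was built from a reference.
  void RefreshBufferManager() {
    if (getter_) {
      cached_ = &getter_();
    }
  }

  // Alloc stays a simple pointer dereference; returns the manager's id for the demo.
  int Alloc() {
    ++cached_->alloc_count;
    return cached_->id;
  }

 private:
  std::function<BufferManager&()> getter_;  // empty for static routing
  BufferManager* cached_;                   // declared after getter_ on purpose
};
```

The trade-off is that the cached pointer is only as fresh as the last RefreshBufferManager call, so the caller has to refresh after any state change that affects routing.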
@qjia7 qjia7 marked this pull request as ready for review May 8, 2026 07:38
@qjia7 qjia7 requested review from Copilot, edgchen1 and guschmue May 8, 2026 10:05
Contributor

Copilot AI left a comment

Pull request overview

This PR extends WebGPU graph-capture support to safely handle multiple captured graphs (keyed by graph annotation ID) and adds a new end-to-end API to explicitly release captured graph resources (captured command lists + GPU buffers) when a generator/session is torn down. It aims to prevent cross-contamination between sequential generators/sessions that capture and replay different graphs.

Changes:

  • Introduces per-annotation-ID captured command storage and per-graph buffer managers in the WebGPU EP to isolate graph-capture state.
  • Adds a ReleaseGraph(graph_annotation_id) API through the ORT stack (EP base → plugin EP adapter → InferenceSession → C API).
  • Updates WebGPU allocators to support dynamically routing allocations to the active per-graph buffer manager using a cached-pointer + refresh pattern.

Reviewed changes

Copilot reviewed 17 out of 17 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h | Adds ReleaseGraph override to plugin EP wrapper. |
| onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc | Implements plugin EP ReleaseGraph with version/NULL checks and fallback. |
| onnxruntime/core/session/ort_version_check.h | Adds a comment explaining a constexpr/MSVC limitation. |
| onnxruntime/core/session/ort_apis.h | Declares OrtApis::SessionReleaseGraph. |
| onnxruntime/core/session/onnxruntime_c_api.cc | Implements SessionReleaseGraph and wires it into the exported OrtApi table. |
| onnxruntime/core/session/inference_session.h | Adds InferenceSession::ReleaseGraph API surface. |
| onnxruntime/core/session/inference_session.cc | Implements InferenceSession::ReleaseGraph forwarding to cached EP. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.h | Replaces single-graph capture state with per-graph maps/sets; adds ReleaseGraph; adds allocator routing state. |
| onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc | Implements per-graph capture/replay/release and allocator buffer-manager routing. |
| onnxruntime/core/providers/webgpu/ep/factory.cc | Switches adapter-path device allocator to use a buffer-manager getter lambda. |
| onnxruntime/core/providers/webgpu/ep/ep.h | Adds EP-adapter function pointer for ReleaseGraph. |
| onnxruntime/core/providers/webgpu/ep/ep.cc | Wires ReleaseGraph into the EP-adapter vtable. |
| onnxruntime/core/providers/webgpu/allocator.h | Adds getter-based allocator constructor and RefreshBufferManager. |
| onnxruntime/core/providers/webgpu/allocator.cc | Implements cached-pointer refresh logic and switches Alloc/Free to pointer dereferences. |
| include/onnxruntime/core/session/onnxruntime_ep_c_api.h | Extends OrtEp C API with ReleaseGraph. |
| include/onnxruntime/core/session/onnxruntime_c_api.h | Extends OrtApi C API with SessionReleaseGraph. |
| include/onnxruntime/core/framework/execution_provider.h | Adds default virtual IExecutionProvider::ReleaseGraph. |

Comments suppressed due to low confidence (2)

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc:786

  • The cached BufferManager refresh happens before m_current_graph_annotation_id is updated for this run. Since RefreshBufferManager() consults WebGpuExecutionProvider::BufferManager() (which uses m_current_graph_annotation_id), the allocator can cache the previous run’s per-graph buffer manager and route allocations to the wrong graph. Set m_current_graph_annotation_id (and any other state that BufferManager() depends on) before calling RefreshBufferManager().
      graph_buffer_mgr_active_ = true;
      if (default_gpu_allocator_) {
        default_gpu_allocator_->RefreshBufferManager();
      }

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc:793

  • IsGraphCaptureAllowed() is called before m_current_graph_annotation_id is set to graph_annotation_id, but IsGraphCaptureAllowed() reads m_current_graph_annotation_id to look up the per-graph run count. This means the capture decision can be made using the previous graph’s run count instead of the current graph’s. Update m_current_graph_annotation_id before this check (or change IsGraphCaptureAllowed to take graph_annotation_id).
    if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
      auto& commands = captured_graphs_[graph_annotation_id];
      context_.CaptureBegin(&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
    }
    m_current_graph_annotation_id = graph_annotation_id;
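
The second option mentioned in the comment above (have IsGraphCaptureAllowed take the graph ID instead of reading member state) could be sketched like this. CaptureGate, RecordRun and run_counts_ are hypothetical names; only IsGraphCaptureAllowed and min_num_runs_before_graph_capture_ appear in the PR, and the real method may differ.

```cpp
#include <cassert>
#include <unordered_map>

// Hypothetical helper that tracks run counts per graph annotation ID.
class CaptureGate {
 public:
  explicit CaptureGate(int min_num_runs_before_graph_capture)
      : min_num_runs_before_graph_capture_(min_num_runs_before_graph_capture) {}

  // Records one completed run for the given graph.
  void RecordRun(int graph_annotation_id) { ++run_counts_[graph_annotation_id]; }

  // Takes the ID explicitly, so the answer cannot depend on a stale
  // "current graph" member that has not been updated for this run yet.
  bool IsGraphCaptureAllowed(int graph_annotation_id) const {
    auto it = run_counts_.find(graph_annotation_id);
    const int runs = (it == run_counts_.end()) ? 0 : it->second;
    return runs >= min_num_runs_before_graph_capture_;
  }

 private:
  int min_num_runs_before_graph_capture_;
  std::unordered_map<int, int> run_counts_;
};
```

With an explicit parameter, a graph that has never run cannot inherit the run count of whichever graph ran before it.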


Comment thread onnxruntime/core/session/onnxruntime_c_api.cc
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Comment thread onnxruntime/core/providers/webgpu/ep/factory.cc Outdated
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
@guschmue guschmue added the ep:WebGPU ort-web webgpu provider label May 13, 2026
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
Comment thread onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc Outdated
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
@qjia7 qjia7 marked this pull request as draft May 14, 2026 10:00
Rename SessionReleaseGraph to SessionReleaseCapturedGraph across the full
API chain for clarity. Fix API versioning static_assert for version 26,
correct \since tag to 1.27, fix destructor to iterate captured_graphs_
for partial capture cleanup, wire SetDefaultGpuAllocator in adapter path
for RefreshBufferManager, add clarifying comment for plugin EP fallback,
and add TestReleaseCapturedGraph test in io_binding_test.cc.
Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/ep/ep.h Outdated
qjia7 added 4 commits May 15, 2026 09:52
Add output verification (GPU→CPU copy + VerifySingleOutput) after each
capture and replay run. Call ReleaseCapturedGraph at test end to clean up
the re-captured graph. Fix nodiscard warning on AddConfigEntry.
Use MatMul→Relu model so the intermediate tensor exercises the per-graph
buffer manager. Change input A between capture and replay runs to verify
replay re-executes rather than returning stale output.
Replace two-op model with MatMul→Relu→MatMul to create two intermediate
tensors that exercise the per-graph buffer manager more thoroughly.
Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/ep/ep.h Outdated
@qjia7 qjia7 marked this pull request as ready for review May 15, 2026 02:35
@qjia7 qjia7 requested a review from skottmckay May 16, 2026 02:27
Comment thread onnxruntime/test/providers/io_binding_test.cc Outdated
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/allocator.h Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Outdated
Comment thread onnxruntime/test/providers/graph_capture_test.cc Outdated
…in test

- Use generic wording in API doc comments (remove GPU-specific references)
- Rename m_current_graph_annotation_id to current_graph_annotation_id_
- Use .contains() instead of .count() > 0
- Use try_emplace to avoid double lookup in OnRunStart
- Remove allocator caching: single constructor with getter lambda,
  eliminating RefreshBufferManager, default_gpu_allocator_, and
  SetDefaultGpuAllocator
- Rewrite graph_capture_test.cc using public APIs (Model Editor,
  Ort::Session, Ort::IoBinding, CopyTensors, SessionReleaseCapturedGraph)
Contributor

@github-actions github-actions Bot left a comment

You can commit the suggested changes from lintrunner.

Comment thread onnxruntime/core/providers/webgpu/ep/factory.cc Outdated
Comment thread onnxruntime/core/providers/webgpu/webgpu_execution_provider.h Outdated
@qjia7 qjia7 requested a review from edgchen1 May 21, 2026 02:46
qjia7 and others added 6 commits May 21, 2026 11:01
The plugin build defines USE_WEBGPU but cannot instantiate the
WebGPU EP at runtime, causing TestReleaseCapturedGraph to fail.
Guard the test with !defined(ORT_USE_EP_API_ADAPTERS) so it only
compiles for built-in WebGPU EP builds.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add WebGpuPluginRegistration RAII helper and AppendWebGpuEp abstraction
so TestReleaseCapturedGraph works in plugin builds (ORT_USE_EP_API_ADAPTERS)
by dynamically registering the WebGPU plugin library and using the V2
session options API with EpDevice discovery.
Replace ORT_UNUSED_PARAMETER(env) with (void)env to avoid dependency
on ORT internal macros not included in this test file.
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
…quested

The C++ WebGPU EP ApplyFlashAttention does not implement CPU EP semantics
for past key/value: when no present_key/present_value outputs are requested,
past inputs should be ignored (matching CPU EP behavior where
past_sequence_length=0 is used in this case).

The non-flash ApplyAttention already handled this correctly at lines 603-604.
This fix adds the same check in multihead_attention.cc before dispatching to
either path, ensuring consistent behavior.

This fixes the failing CI test: "[webgpu]MultiHeadAttention - MultiHeadAttention
Basic, one head and head-size=4 with pastKey and pastValue" which expected
output [9,10,11,12] but got ~[17,18,19,20] because past values dominated
attention when flash attention was used.

Agent-Logs-Url: https://github.com/microsoft/onnxruntime/sessions/2e159834-89db-46b3-989b-e6effe3d08a1

Co-authored-by: qjia7 <4221210+qjia7@users.noreply.github.com>
Replace GPU-specific terms with generic descriptions per review
feedback, keeping GPU commands as an example.
@qjia7 qjia7 requested a review from edgchen1 May 22, 2026 01:41
Comment thread include/onnxruntime/core/session/onnxruntime_c_api.h Outdated
Comment thread onnxruntime/core/providers/webgpu/ep/factory.cc Outdated
Comment thread onnxruntime/test/providers/graph_capture_test.cc Outdated
*
* \since Version 1.27.
*/
ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
Contributor

do we need to account for calls that may happen concurrently with Ort::Session::Run()? in general (though not for the WebGPU EP), Run() itself may be called concurrently.

Contributor Author

Thread safety for concurrent ReleaseCapturedGraph + Run() is an EP implementation detail rather than a C API contract concern. For WebGPU EP, Run() itself is not concurrent (single GPU queue), so this is not a realistic scenario. For EPs that do support concurrent Run() (e.g., CUDA), the EP would need to handle synchronization internally in its ReleaseCapturedGraph implementation — same as other EP callbacks. No other session-mutating APIs (e.g., SetEpDynamicOptions) document thread-safety constraints at the C API level, so adding one here would be inconsistent.

Contributor

if an EP reports that ConcurrentRunSupported() is false, ORT will lock a mutex during the call to Ort::Session::Run() to ensure that only one actual run is happening at a time. this doesn't prevent users from calling Run() concurrently.

it may not make sense for an application to make concurrent calls to Run() + ReleaseCapturedGraph() if the WebGPU EP is the only EP to consider, but an application might support more than just the WebGPU EP.

I don't have a good answer to this now, but I think it may be worth some thought. if we provide no synchronization in ORT or EP code, I think the assumption should at least be documented somewhere.

Contributor Author

Done. ReleaseCapturedGraph now acquires session_mutex_ before delegating to the EP. For EPs like WebGPU where ConcurrentRunSupported() returns false, this provides mutual exclusion with Run() since Run() also acquires session_mutex_ in that case. For EPs that support concurrent runs, Run() does not acquire session_mutex_, so additional synchronization would need to be handled at the EP level if needed.
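
The locking behaviour described in this reply can be sketched as below, under stated assumptions. Only session_mutex_, Run, ReleaseCapturedGraph and ConcurrentRunSupported are names from the thread; the Session class, its members and the set of captured IDs are hypothetical simplifications of InferenceSession.

```cpp
#include <cassert>
#include <mutex>
#include <set>

// Hypothetical, heavily simplified session.
class Session {
 public:
  explicit Session(bool concurrent_run_supported)
      : concurrent_run_supported_(concurrent_run_supported) {}

  // Run() serializes only when the EP does not support concurrent runs.
  void Run(int graph_annotation_id) {
    std::unique_lock<std::mutex> lock(session_mutex_, std::defer_lock);
    if (!concurrent_run_supported_) {
      lock.lock();
    }
    captured_graph_ids_.insert(graph_annotation_id);  // pretend the graph is captured
  }

  // ReleaseCapturedGraph() always takes the mutex before touching graph state.
  // Returns true if a captured graph with that ID existed.
  bool ReleaseCapturedGraph(int graph_annotation_id) {
    std::lock_guard<std::mutex> lock(session_mutex_);
    return captured_graph_ids_.erase(graph_annotation_id) > 0;
  }

 private:
  bool concurrent_run_supported_;
  std::mutex session_mutex_;
  std::set<int> captured_graph_ids_;
};
```

In this sketch, as in the reply, the mutex only gives mutual exclusion with Run() when concurrent runs are not supported; an EP that allows concurrent runs would still need its own synchronization.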

…gths

- Add Session::ReleaseCapturedGraph() C++ wrapper and use it in test
- Replace raw C API CopyTensors calls with Env::CopyTensor in test
- Unify doc wording across all locations to use generic "resources"
- Fix lines exceeding 120 chars in allocator, factory, and EP header
Copilot finished work on behalf of qjia7 May 22, 2026 04:55
The test created its own Env and registered the WebGPU plugin DLL under
a separate name. On cleanup, UnregisterExecutionProviderLibrary unloaded
the DLL, leaving the global test infrastructure with dangling pointers.
Subsequent tests crashed with SEH 0xc0000005 in ParseEpConfig.

Use the global ort_env which already has the plugin registered by the
test infrastructure, avoiding the duplicate registration and unload.
Contributor Author

qjia7 commented May 23, 2026

@skottmckay @edgchen1 All checks have passed. Please take another look when you have time, thanks.

Comment on lines +109 to +114
if (present_key == nullptr && present_value == nullptr) {
  past_key = nullptr;
  past_value = nullptr;
  parameters.past_sequence_length_ = 0;
  parameters.total_sequence_length_ = parameters.kv_sequence_length_;
}
Contributor

was this change meant to be a part of this PR? it seems unrelated.

Contributor Author

Yes, this is intentional. It fixes the CI test [webgpu]MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue. The fix matches CPU EP semantics: when no present_key/present_value output is requested, past_key/past_value should be ignored. See commit d0369a5 for details.

static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change");
static_assert(offsetof(OrtApi, SetPerSessionThreadPoolCallbacks) / sizeof(void*) == 418, "Size of version 25 API cannot change");
// no additions in version 26
static_assert(offsetof(OrtApi, SessionReleaseCapturedGraph) / sizeof(void*) == 421, "Size of version 27 API cannot change");
Contributor

don't need to do this now as 1.27 is not finalized yet.

Contributor Author

Done. Removed the static_assert.

ASSERT_FALSE(webgpu_devices.empty()) << "No WebGPU EP device found after plugin registration";
session_options.AppendExecutionProvider_V2(env, webgpu_devices, provider_options);
#else
(void)env;
Contributor

Suggested change
- (void)env;
+ static_cast<void>(env);

Contributor Author

Done.

ORT_ENFORCE(key.length() >= prefix.length() && key.substr(0, prefix.length()) == prefix,
"Config key \"", key, "\" does not start with expected prefix \"", prefix, "\"");
return full_key + prefix.length();
if (normalized_key.rfind(prefix, 0) == 0) {
Contributor

we can use std::string::starts_with in C++20

Contributor Author

Done. Replaced rfind(prefix, 0) == 0 with starts_with(prefix) at both locations.
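
A small sketch of the prefix-stripping step with the C++20 member function, falling back to the older rfind idiom when the standard library does not provide starts_with. StripPrefix is a hypothetical helper name, not a function from this PR.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper: returns `key` with `prefix` removed if present.
std::string StripPrefix(std::string key, const std::string& prefix) {
#if defined(__cpp_lib_starts_ends_with)
  // C++20: states the intent directly.
  const bool has_prefix = key.starts_with(prefix);
#else
  // Pre-C++20 idiom: a match found when searching backwards from index 0
  // can only be a match at the very start of the string.
  const bool has_prefix = key.rfind(prefix, 0) == 0;
#endif
  if (has_prefix) {
    key.erase(0, prefix.length());
  }
  return key;
}
```

Both branches behave the same for this purpose; starts_with is simply easier to read.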

if (normalized_key.rfind(prefix, 0) == 0) {
normalized_key.erase(0, prefix.length());
}
ORT_ENFORCE(normalized_config_options.AddConfigEntry(normalized_key.c_str(), value.c_str()).IsOK());
Contributor

nit: ORT_ENFORCE(status.IsOK()) -> ORT_THROW_IF_ERROR(status). this will also include status error information.

Contributor Author

Done. Replaced all three ORT_ENFORCE(...IsOK()) calls with ORT_THROW_IF_ERROR(...).

…ixes

- Acquire session_mutex_ in ReleaseCapturedGraph to prevent concurrent
  access with Run()
- Remove premature version 27 static_assert
- Use static_cast<void> instead of C-style void cast
- Use std::string::starts_with (C++20) instead of rfind workaround
- Use ORT_THROW_IF_ERROR instead of ORT_ENFORCE for better error info
@qjia7 qjia7 requested a review from edgchen1 May 27, 2026 03:10
Labels

ep:WebGPU ort-web webgpu provider

6 participants