Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
94a2ca4
Per-graph buffer manager for WebGPU multi-graph capture
qjia7 Apr 29, 2026
bc5cfdb
Fix lintrunner trailing whitespace in factory.cc
qjia7 May 7, 2026
d24af03
Replace SetBufferManager with std::function getter and cached pointer…
qjia7 May 8, 2026
96ce10b
Make buffer manager getter nullable for static allocators
qjia7 May 8, 2026
93e3db6
Remove trailing blank line in constructor body
qjia7 May 8, 2026
cd24850
Address PR #28260 review comments
qjia7 May 15, 2026
9d65cd2
Verify output correctness and cleanup in TestReleaseCapturedGraph
qjia7 May 15, 2026
7417a7a
Improve TestReleaseCapturedGraph with two-op model and input variation
qjia7 May 15, 2026
ff5a405
Use three-op model (MatMul+Relu+MatMul) in TestReleaseCapturedGraph
qjia7 May 15, 2026
92f9fd7
Fix alignment of ReleaseCapturedGraphImpl parameter in ep.h
qjia7 May 15, 2026
2acc028
Fix lintrunner alignment in ep.h
qjia7 May 15, 2026
d9288df
Merge branch 'main' into per-graph-buffer-manager-webgpu
qjia7 May 15, 2026
09b39c9
Address PR review: fix API ordering, test isolation, and annotation I…
qjia7 May 19, 2026
f171fc8
Revert unrelated changes to io_binding_test.cc and ort_version_check.h
qjia7 May 19, 2026
09900dd
Restore Doxygen group closing marker in onnxruntime_c_api.h
qjia7 May 19, 2026
d8226af
Remove extra Doxygen group closing marker before SessionReleaseCaptur…
qjia7 May 19, 2026
f6db1b7
Fix version number in ReleaseCapturedGraph comment (26 -> 27)
qjia7 May 19, 2026
60f1ed1
Address edgchen1 PR review comments
qjia7 May 20, 2026
ea0e3a6
Address edgchen1 round 3 review: simplify allocator, use public APIs …
qjia7 May 21, 2026
086811b
Fix clang-format lint issues in factory.cc and webgpu_execution_provi…
qjia7 May 21, 2026
c805958
Add missing #include <numeric> for std::iota in graph_capture_test
qjia7 May 21, 2026
f9e5ffc
Combine graph annotation guards in OnRunStart and reuse iterator to a…
qjia7 May 21, 2026
b293976
Wrap BufferManager in lambda for plugin EP shared allocator in Create…
qjia7 May 21, 2026
0f61e15
Exclude graph capture test from plugin EP builds
qjia7 May 21, 2026
aa63475
Support graph capture test in both built-in and plugin WebGPU EP builds
qjia7 May 21, 2026
f5992e9
Fix undeclared ORT_UNUSED_PARAMETER in graph capture test
qjia7 May 21, 2026
d0369a5
Fix WebGPU EP MHA to ignore past key/value when no present outputs re…
Copilot May 22, 2026
e7cfaf0
Use generic wording in SessionReleaseCapturedGraph API doc
qjia7 May 22, 2026
bb9ebdd
Address PR review: add C++ API, use generic doc wording, fix line len…
qjia7 May 22, 2026
f0f3b07
Fix WebGPU plugin config key normalization in test provider helper
Copilot May 22, 2026
3d9df67
Use global ort_env in graph capture test to avoid DLL unload
qjia7 May 22, 2026
673ed1f
Address review comments: mutex for ReleaseCapturedGraph, code style f…
qjia7 May 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ class IExecutionProvider {
return Status::OK();
}

/**
Release a previously captured graph and its associated resources.
Called when the caller no longer needs the captured graph for the given annotation ID.
*/
virtual common::Status ReleaseCapturedGraph(int /*graph_annotation_id*/) {
return Status::OK();
}

/**
Get the node assignment validation policy for graph capture.
When graph capture is enabled, ORT validates that nodes are assigned to EPs
Expand Down
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7471,6 +7471,21 @@ struct OrtApi {
* \see OrtApi::SetSessionExecutionMode
*/
ORT_API2_STATUS(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out);

/** \brief Release a previously captured graph and its associated resources.
*
* When graph capture is enabled, the EP records information during initial runs (e.g., GPU commands)
* and replays them on subsequent runs. This function releases the captured resources for a specific
* graph annotation ID, freeing memory.
*
* \param[in] session The OrtSession instance.
* \param[in] graph_annotation_id The annotation ID of the captured graph to release.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.27.
*/
ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to account for calls that may happen concurrently with Ort::Session::Run()? in general (though not for the WebGPU EP), Run() itself may be called concurrently.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thread safety for concurrent ReleaseCapturedGraph + Run() is an EP implementation detail rather than a C API contract concern. For WebGPU EP, Run() itself is not concurrent (single GPU queue), so this is not a realistic scenario. For EPs that do support concurrent Run() (e.g., CUDA), the EP would need to handle synchronization internally in its ReleaseCapturedGraph implementation — same as other EP callbacks. No other session-mutating APIs (e.g., SetEpDynamicOptions) document thread-safety constraints at the C API level, so adding one here would be inconsistent.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if an EP reports that ConcurrentRunSupported() is false, ORT will lock a mutex during the call to Ort::Session::Run() to ensure that only one actual run is happening at a time. this doesn't prevent users from calling Run() concurrently.

it may not make sense for an application to make concurrent calls to Run() + ReleaseCapturedGraph() if the WebGPU EP is the only EP to consider, but an application might support more than just the WebGPU EP.

I don't have a good answer to this now, but I think it may be worth some thought. if we provide no synchronization in ORT or EP code, I think the assumption should at least be documented somewhere.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. ReleaseCapturedGraph now acquires session_mutex_ before delegating to the EP. For EPs like WebGPU where ConcurrentRunSupported() returns false, this provides mutual exclusion with Run() since Run() also acquires session_mutex_ in that case. For EPs that support concurrent runs, Run() does not acquire session_mutex_, so additional synchronization would need to be handled at the EP level if needed.

};

/*
Expand Down
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2005,6 +2005,14 @@ struct SessionImpl : ConstSessionImpl<T> {

void FinalizeModelEditorSession(const Model& model, const SessionOptions& options,
OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr);

/** \brief Release a previously captured graph.
*
* Wraps OrtApi::SessionReleaseCapturedGraph
*
* \param[in] graph_annotation_id The annotation ID of the captured graph to release.
*/
void ReleaseCapturedGraph(int graph_annotation_id);
};

} // namespace detail
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,11 @@ inline void SessionImpl<T>::FinalizeModelEditorSession(const Model& model, const
}
#endif // #if !defined(ORT_MINIMAL_BUILD)

template <typename T>
inline void SessionImpl<T>::ReleaseCapturedGraph(int graph_annotation_id) {
ThrowOnError(GetApi().SessionReleaseCapturedGraph(this->p_, graph_annotation_id));
}

} // namespace detail

inline SessionOptions::SessionOptions() {
Expand Down
17 changes: 17 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,23 @@ struct OrtEp {
* \since Version 1.27.
*/
ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr);

/** \brief Release a previously captured graph and its associated resources.
*
* Called when the caller no longer needs the captured graph for the given annotation ID.
* This allows the EP to free buffers and other resources tied to this graph.
*
* \param[in] this_ptr The EP instance.
* \param[in] graph_annotation_id The annotation ID of the graph to release.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
Comment thread
qjia7 marked this conversation as resolved.
*
* \note Implementation of this function is optional. If set to NULL, ORT assumes
* no captured graph release is needed and treats it as a no-op success.
*
* \since Version 1.27.
*/
ORT_API2_STATUS(ReleaseCapturedGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,16 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext&
TensorShape output_qk_shape(output_qk_dims);
Tensor* output_qk = context.Output(3, output_qk_shape);

// Match CPU EP semantics: when no present_key/present_value output is requested,
// ignore past_key/past_value. The CPU EP sets past_sequence_length=0 in this case,
// effectively treating the input as if there is no KV cache.
if (present_key == nullptr && present_value == nullptr) {
past_key = nullptr;
past_value = nullptr;
parameters.past_sequence_length_ = 0;
parameters.total_sequence_length_ = parameters.kv_sequence_length_;
}
Comment on lines +109 to +114
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

was this change meant to be a part of this PR? it seems unrelated.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is intentional. It fixes the CI test [webgpu]MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue. The fix matches CPU EP semantics: when no present_key/present_value output is requested, past_key/past_value should be ignored. See commit d0369a5 for details.


if (output_qk == nullptr && // Flash attention does not output QK scores
CanApplyFlashAttention(parameters, context)) {
if (bias != nullptr) {
Expand Down
12 changes: 7 additions & 5 deletions onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
namespace onnxruntime {
namespace webgpu {

GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator)
GpuBufferAllocator::GpuBufferAllocator(
std::function<const BufferManager&()> buffer_manager_getter,
bool is_read_only_allocator)
: IAllocator(
OrtMemoryInfo(WEBGPU_BUFFER,
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
: OrtAllocatorType::OrtDeviceAllocator,
WebGpuDevice,
OrtMemTypeDefault)),
buffer_manager_{buffer_manager},
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
buffer_manager_getter_{std::move(buffer_manager_getter)},

Check warning on line 20 in onnxruntime/core/providers/webgpu/allocator.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/allocator.cc:20: Add #include <utility> for move [build/include_what_you_use] [4]
mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} {
}

void* GpuBufferAllocator::Alloc(size_t size) {
Expand All @@ -29,12 +31,12 @@
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;

return buffer_manager_.Create(size, usage);
return buffer_manager_getter_().Create(size, usage);
}

void GpuBufferAllocator::Free(void* p) {
if (p != nullptr) {
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
buffer_manager_getter_().Release(static_cast<WGPUBuffer>(p));
stats_.num_allocs--;
}
}
Expand Down
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <functional>

#include "core/framework/allocator.h"
#include "core/framework/ortdevice.h"

Expand All @@ -18,15 +20,18 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,

class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
// Calls buffer_manager_getter on every Alloc/Free to obtain the current
// BufferManager. This allows the EP to route allocations to different
// buffer managers (e.g., per-graph) without explicit refresh calls.
GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator);

virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
void GetStats(AllocatorStats* stats) override;

private:
AllocatorStats stats_;
const BufferManager& buffer_manager_;
std::function<const BufferManager&()> buffer_manager_getter_;
bool mapped_at_creation_;
};

Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Ep::Ep(std::unique_ptr<IExecutionProvider> impl, Factory& factory, const OrtLogg
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
IsGraphCaptured = IsGraphCapturedImpl;
ReplayGraph = ReplayGraphImpl;
ReleaseCapturedGraph = ReleaseCapturedGraphImpl;
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
}

Expand Down Expand Up @@ -279,6 +280,18 @@ OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph
EXCEPTION_TO_RETURNED_STATUS_END
}

OrtStatus* ORT_API_CALL Ep::ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
auto* ep = static_cast<Ep*>(this_ptr);
auto status = ep->EpImpl()->ReleaseCapturedGraph(graph_annotation_id);
if (!status.IsOK()) {
return Api().ort.CreateStatus(static_cast<OrtErrorCode>(status.Code()),
status.ErrorMessage().c_str());
}
return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}

OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl(
_In_ const OrtEp* this_ptr) noexcept {
auto* ep = static_cast<const Ep*>(this_ptr);
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Ep : public onnxruntime::ep::adapter::Ep {
static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

static OrtStatus* ORT_API_CALL ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
_In_ const OrtEp* this_ptr) noexcept;

Expand Down
21 changes: 17 additions & 4 deletions onnxruntime/core/providers/webgpu/ep/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/framework/config_options.h"
#include "core/providers/webgpu/webgpu_provider_factory_creator.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"
#include "core/providers/webgpu/webgpu_context.h"
#include "core/providers/webgpu/allocator.h"

Expand Down Expand Up @@ -134,10 +135,17 @@
static_cast<WebGpuExecutionProvider*>(webgpu_ep.get())->SetEpLogger(logger);
auto factory = static_cast<Factory*>(this_ptr);
const int context_id = webgpu_ep->GetDeviceId();
auto* webgpu_ep_ptr = static_cast<WebGpuExecutionProvider*>(webgpu_ep.get());
auto device_alloc = std::make_shared<webgpu::GpuBufferAllocator>(
[webgpu_ep_ptr]() -> const webgpu::BufferManager& { return webgpu_ep_ptr->BufferManager(); }, false);
Ep::Config webgpu_ep_config{
CPUAllocator::DefaultInstance(), // CPU allocator
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
CPUAllocator::DefaultInstance(), // CPU allocator
device_alloc, // default device allocator
std::make_shared<webgpu::GpuBufferAllocator>(
[context_id]() -> const webgpu::BufferManager& {
return WebGpuContextFactory::GetContext(context_id).InitializerBufferManager();
},
true), // initializer device allocator
};
*ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config);
return nullptr;
Expand Down Expand Up @@ -165,7 +173,12 @@

*allocator = new onnxruntime::ep::adapter::Allocator(memory_info,
[](const OrtMemoryInfo&) -> AllocatorPtr {
return std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::DefaultContext().BufferManager(), false);
return std::make_shared<webgpu::GpuBufferAllocator>(

Check warning on line 176 in onnxruntime/core/providers/webgpu/ep/factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_shared<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/ep/factory.cc:176: Add #include <memory> for make_shared<> [build/include_what_you_use] [4]
[]() -> const webgpu::BufferManager& {
return WebGpuContextFactory::DefaultContext()
.BufferManager();
},
false);
});
return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
Expand Down
Loading
Loading