Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
94a2ca4
Per-graph buffer manager for WebGPU multi-graph capture
qjia7 Apr 29, 2026
bc5cfdb
Fix lintrunner trailing whitespace in factory.cc
qjia7 May 7, 2026
d24af03
Replace SetBufferManager with std::function getter and cached pointer…
qjia7 May 8, 2026
96ce10b
Make buffer manager getter nullable for static allocators
qjia7 May 8, 2026
93e3db6
Remove trailing blank line in constructor body
qjia7 May 8, 2026
cd24850
Address PR #28260 review comments
qjia7 May 15, 2026
9d65cd2
Verify output correctness and cleanup in TestReleaseCapturedGraph
qjia7 May 15, 2026
7417a7a
Improve TestReleaseCapturedGraph with two-op model and input variation
qjia7 May 15, 2026
ff5a405
Use three-op model (MatMul+Relu+MatMul) in TestReleaseCapturedGraph
qjia7 May 15, 2026
92f9fd7
Fix alignment of ReleaseCapturedGraphImpl parameter in ep.h
qjia7 May 15, 2026
2acc028
Fix lintrunner alignment in ep.h
qjia7 May 15, 2026
d9288df
Merge branch 'main' into per-graph-buffer-manager-webgpu
qjia7 May 15, 2026
09b39c9
Address PR review: fix API ordering, test isolation, and annotation I…
qjia7 May 19, 2026
f171fc8
Revert unrelated changes to io_binding_test.cc and ort_version_check.h
qjia7 May 19, 2026
09900dd
Restore Doxygen group closing marker in onnxruntime_c_api.h
qjia7 May 19, 2026
d8226af
Remove extra Doxygen group closing marker before SessionReleaseCaptur…
qjia7 May 19, 2026
f6db1b7
Fix version number in ReleaseCapturedGraph comment (26 -> 27)
qjia7 May 19, 2026
60f1ed1
Address edgchen1 PR review comments
qjia7 May 20, 2026
ea0e3a6
Address edgchen1 round 3 review: simplify allocator, use public APIs …
qjia7 May 21, 2026
086811b
Fix clang-format lint issues in factory.cc and webgpu_execution_provi…
qjia7 May 21, 2026
c805958
Add missing #include <numeric> for std::iota in graph_capture_test
qjia7 May 21, 2026
f9e5ffc
Combine graph annotation guards in OnRunStart and reuse iterator to a…
qjia7 May 21, 2026
b293976
Wrap BufferManager in lambda for plugin EP shared allocator in Create…
qjia7 May 21, 2026
0f61e15
Exclude graph capture test from plugin EP builds
qjia7 May 21, 2026
aa63475
Support graph capture test in both built-in and plugin WebGPU EP builds
qjia7 May 21, 2026
f5992e9
Fix undeclared ORT_UNUSED_PARAMETER in graph capture test
qjia7 May 21, 2026
d0369a5
Fix WebGPU EP MHA to ignore past key/value when no present outputs re…
Copilot May 22, 2026
e7cfaf0
Use generic wording in SessionReleaseCapturedGraph API doc
qjia7 May 22, 2026
bb9ebdd
Address PR review: add C++ API, use generic doc wording, fix line len…
qjia7 May 22, 2026
f0f3b07
Fix WebGPU plugin config key normalization in test provider helper
Copilot May 22, 2026
3d9df67
Use global ort_env in graph capture test to avoid DLL unload
qjia7 May 22, 2026
673ed1f
Address review comments: mutex for ReleaseCapturedGraph, code style f…
qjia7 May 27, 2026
991c777
Make ReleaseCapturedGraph locking conditional to match Run() pattern
qjia7 May 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ class IExecutionProvider {
return Status::OK();
}

/**
Release a previously captured graph and its associated resources.
Called when the caller no longer needs the captured graph for the given annotation ID.

The default implementation ignores the annotation ID, releases nothing, and reports success.
NOTE(review): presumably execution providers that implement graph capture override this to
free their per-graph state -- confirm against the overriding providers.

@param graph_annotation_id Annotation ID of the captured graph to release (unused by this default).
@return Status::OK() unconditionally in this default implementation.
*/
virtual common::Status ReleaseCapturedGraph(int /*graph_annotation_id*/) {
return Status::OK();
}

/**
Get the node assignment validation policy for graph capture.
When graph capture is enabled, ORT validates that nodes are assigned to EPs
Expand Down
15 changes: 15 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -7471,6 +7471,21 @@ struct OrtApi {
* \see OrtApi::SetSessionExecutionMode
*/
ORT_API2_STATUS(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out);

/** \brief Release a previously captured graph and its associated GPU resources.
*
* When graph capture is enabled, the EP records GPU commands during initial runs and replays them
Comment thread
qjia7 marked this conversation as resolved.
Outdated
* on subsequent runs. This function releases the captured commands and associated GPU buffers
* for a specific graph annotation ID, freeing GPU memory.
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
*
* \param[in] session The OrtSession instance.
* \param[in] graph_annotation_id The annotation ID of the captured graph to release.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
Comment thread
qjia7 marked this conversation as resolved.
Outdated
*
* \since Version 1.27.
*/
ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
Comment thread
qjia7 marked this conversation as resolved.
};

/*
Expand Down
14 changes: 14 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2567,6 +2567,20 @@ struct OrtEp {
* \since Version 1.27.
*/
ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr);

/** \brief Release a previously captured graph and its associated resources.
*
* Called when the caller no longer needs the captured graph for the given annotation ID.
* This allows the EP to free GPU buffers and other resources tied to this graph.
*
* \param[in] this_ptr The EP instance.
* \param[in] graph_annotation_id The annotation ID of the graph to release.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
Comment thread
qjia7 marked this conversation as resolved.
*
* \since Version 1.27.
*/
ORT_API2_STATUS(ReleaseCapturedGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down
25 changes: 22 additions & 3 deletions onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,29 @@
: OrtAllocatorType::OrtDeviceAllocator,
WebGpuDevice,
OrtMemTypeDefault)),
buffer_manager_{buffer_manager},
buffer_manager_getter_{},
buffer_manager_{&buffer_manager},
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
}

GpuBufferAllocator::GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator)
: IAllocator(
OrtMemoryInfo(WEBGPU_BUFFER,
is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator
: OrtAllocatorType::OrtDeviceAllocator,
WebGpuDevice,
OrtMemTypeDefault)),
buffer_manager_getter_{std::move(buffer_manager_getter)},

Check warning on line 30 in onnxruntime/core/providers/webgpu/allocator.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for move [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webgpu/allocator.cc:30: Add #include <utility> for move [build/include_what_you_use] [4]
buffer_manager_{&buffer_manager_getter_()},
mapped_at_creation_{is_read_only_allocator && buffer_manager_->SupportsUMA()} {
}

// Re-resolves the cached BufferManager pointer by invoking the stored getter.
// When no getter was supplied (the allocator was built from a direct
// BufferManager reference) the cached pointer is left untouched.
void GpuBufferAllocator::RefreshBufferManager() {
  if (!buffer_manager_getter_) {
    // Empty getter: there is nothing to re-query.
    return;
  }

  buffer_manager_ = &buffer_manager_getter_();
}

void* GpuBufferAllocator::Alloc(size_t size) {
if (size == 0) {
return nullptr;
Expand All @@ -29,12 +48,12 @@
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;

return buffer_manager_.Create(size, usage);
return buffer_manager_->Create(size, usage);
}

void GpuBufferAllocator::Free(void* p) {
if (p != nullptr) {
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
buffer_manager_->Release(static_cast<WGPUBuffer>(p));
stats_.num_allocs--;
}
}
Expand Down
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

#pragma once

#include <functional>

#include "core/framework/allocator.h"
#include "core/framework/ortdevice.h"

Expand All @@ -19,14 +21,20 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU,
class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator);
GpuBufferAllocator(std::function<const BufferManager&()> buffer_manager_getter, bool is_read_only_allocator);

// Re-reads the buffer manager from the getter and caches it.
// No-op if constructed with a direct BufferManager reference (no getter).
void RefreshBufferManager();

virtual void* Alloc(size_t size) override;
virtual void Free(void* p) override;
void GetStats(AllocatorStats* stats) override;

private:
AllocatorStats stats_;
const BufferManager& buffer_manager_;
std::function<const BufferManager&()> buffer_manager_getter_; // may be empty
const BufferManager* buffer_manager_; // cached from getter, or direct pointer
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
bool mapped_at_creation_;
};

Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Ep::Ep(std::unique_ptr<IExecutionProvider> impl, Factory& factory, const OrtLogg
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
IsGraphCaptured = IsGraphCapturedImpl;
ReplayGraph = ReplayGraphImpl;
ReleaseCapturedGraph = ReleaseCapturedGraphImpl;
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
}

Expand Down Expand Up @@ -279,6 +280,18 @@ OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph
EXCEPTION_TO_RETURNED_STATUS_END
}

// OrtEp::ReleaseCapturedGraph entry point: forwards the request to the wrapped
// execution provider and converts its Status into an OrtStatus*.
// Returns nullptr on success; otherwise a newly created OrtStatus carrying the
// provider's error code and message.
OrtStatus* ORT_API_CALL Ep::ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
  EXCEPTION_TO_RETURNED_STATUS_BEGIN
  auto release_status = static_cast<Ep*>(this_ptr)->EpImpl()->ReleaseCapturedGraph(graph_annotation_id);
  if (release_status.IsOK()) {
    return nullptr;
  }
  return Api().ort.CreateStatus(static_cast<OrtErrorCode>(release_status.Code()),
                                release_status.ErrorMessage().c_str());
  EXCEPTION_TO_RETURNED_STATUS_END
}

OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl(
_In_ const OrtEp* this_ptr) noexcept {
auto* ep = static_cast<const Ep*>(this_ptr);
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/webgpu/ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ class Ep : public onnxruntime::ep::adapter::Ep {
static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

static OrtStatus* ORT_API_CALL ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr,
_In_ int graph_annotation_id) noexcept;

static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
_In_ const OrtEp* this_ptr) noexcept;

Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/webgpu/ep/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/framework/config_options.h"
#include "core/providers/webgpu/webgpu_provider_factory_creator.h"
#include "core/providers/webgpu/webgpu_execution_provider.h"
#include "core/providers/webgpu/webgpu_context.h"
#include "core/providers/webgpu/allocator.h"

Expand Down Expand Up @@ -134,9 +135,13 @@ OrtStatus* ORT_API_CALL Factory::CreateEpImpl(
static_cast<WebGpuExecutionProvider*>(webgpu_ep.get())->SetEpLogger(logger);
auto factory = static_cast<Factory*>(this_ptr);
const int context_id = webgpu_ep->GetDeviceId();
auto* webgpu_ep_ptr = static_cast<WebGpuExecutionProvider*>(webgpu_ep.get());
auto device_alloc = std::make_shared<webgpu::GpuBufferAllocator>(
[webgpu_ep_ptr]() -> const webgpu::BufferManager& { return webgpu_ep_ptr->BufferManager(); }, false);
webgpu_ep_ptr->SetDefaultGpuAllocator(static_cast<webgpu::GpuBufferAllocator*>(device_alloc.get()));
Ep::Config webgpu_ep_config{
CPUAllocator::DefaultInstance(), // CPU allocator
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator
device_alloc, // default device allocator
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
};
*ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config);
Expand Down
103 changes: 77 additions & 26 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -578,16 +578,6 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
enable_int64_{config.enable_graph_capture || config.enable_int64},
multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset},
prepack_allocator_{std::make_shared<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), false)} {
// If graph capture is enabled, create a dedicated buffer manager for graph mode
if (enable_graph_capture_) {
// Create buffer manager for graph capture mode with appropriate cache modes
graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
context_,
webgpu::BufferCacheMode::Graph,
webgpu::BufferCacheMode::GraphSimple,
webgpu::BufferCacheMode::Disabled);
}

if (config.enable_pix_capture) {
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
// set pix frame generator
Expand All @@ -599,11 +589,14 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
}

std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
auto device_allocator = std::make_unique<webgpu::GpuBufferAllocator>(
[this]() -> const webgpu::BufferManager& { return BufferManager(); }, false);
default_gpu_allocator_ = device_allocator.get();
return {
// allocator for initializers
std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
// default allocator
std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
std::move(device_allocator),
};
}

Expand Down Expand Up @@ -733,11 +726,20 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
}

WebGpuExecutionProvider::~WebGpuExecutionProvider() {
// Release all resources associated with the captured graph
if (!captured_commands_.empty()) {
context_.ReleaseGraphResources(captured_commands_);
// Release all captured graphs (both fully captured and partially captured) and their associated resources.
// Use captured_graphs_ keys to also cover partially captured graphs that have GPU command handles
// but never completed capture (i.e., CaptureBegin was called but CaptureEnd was not).
std::vector<int> graph_ids;
graph_ids.reserve(captured_graphs_.size());
for (const auto& [id, _] : captured_graphs_) {
graph_ids.push_back(id);
}
for (int id : graph_ids) {
(void)ReleaseCapturedGraph(id);
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
}
// The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr
// Release any per-graph buffer managers for graphs that had buffer managers created
// but no entries in captured_graphs_ (edge case cleanup)
per_graph_buffer_mgrs_.clear();
Comment thread
qjia7 marked this conversation as resolved.

WebGpuContextFactory::ReleaseContext(context_id_);
}
Expand Down Expand Up @@ -772,10 +774,27 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
*graph_annotation_str);
}

m_current_graph_annotation_id = graph_annotation_id;
Comment thread
edgchen1 marked this conversation as resolved.
Outdated

// Create a per-graph buffer manager on first encounter of each annotation ID
if (graph_annotation_id != -1) {
if (per_graph_buffer_mgrs_.find(graph_annotation_id) == per_graph_buffer_mgrs_.end()) {
per_graph_buffer_mgrs_[graph_annotation_id] = webgpu::BufferManagerFactory::Create(
context_,
webgpu::BufferCacheMode::Graph,
webgpu::BufferCacheMode::GraphSimple,
webgpu::BufferCacheMode::Disabled);
}
graph_buffer_mgr_active_ = true;
if (default_gpu_allocator_) {
default_gpu_allocator_->RefreshBufferManager();
}
}

if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
auto& commands = captured_graphs_[graph_annotation_id];
context_.CaptureBegin(&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
}
m_current_graph_annotation_id = graph_annotation_id;
}

return Status::OK();
Expand All @@ -787,7 +806,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) {
if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) {
context_.CaptureEnd();
is_graph_captured_ = true;
captured_graph_ids_.insert(m_current_graph_annotation_id);
ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id));
} else {
IncrementRegularRunCountBeforeGraphCapture();
Expand All @@ -808,6 +827,12 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
}
#endif // ENABLE_PIX_FOR_WEBGPU_EP

// Reset buffer manager routing after run completes
graph_buffer_mgr_active_ = false;
if (default_gpu_allocator_) {
default_gpu_allocator_->RefreshBufferManager();
}

if (context_.ValidationMode() >= ValidationMode::Basic) {
return context_.PopErrorScope();
} else {
Expand All @@ -820,7 +845,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
}

bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return is_graph_captured_ && graph_annotation_id != -1;
return graph_annotation_id != -1 && captured_graph_ids_.count(graph_annotation_id) > 0;
Comment thread
edgchen1 marked this conversation as resolved.
Outdated
}

Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
Expand All @@ -829,27 +854,53 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
if (session_profiler_ && session_profiler_->Enabled()) {
context_.StartProfiling();
}
context_.Replay(captured_commands_, *graph_buffer_mgr_);
context_.Replay(captured_graphs_.at(graph_annotation_id), *per_graph_buffer_mgrs_.at(graph_annotation_id));
if (session_profiler_ && session_profiler_->Enabled()) {
// Session-level profiling: collect into profiler's own events storage.
context_.CollectProfilingData(session_profiler_->GpuEvents());
}
return Status::OK();
}

// Drops every piece of per-graph state kept for `graph_annotation_id`:
// the recorded commands, the "captured" marker, the per-graph buffer manager
// and the pre-capture run counter. IDs with no recorded state are tolerated;
// the function always returns Status::OK().
//
// NOTE(review): the default GPU allocator caches a raw BufferManager pointer
// that is only re-resolved via RefreshBufferManager(). If this were invoked
// while the released graph's buffer manager is the cached one (i.e. between
// OnRunStart and OnRunEnd for the same ID), that pointer would dangle --
// confirm callers only release graphs between runs.
Status WebGpuExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) {
  // Recorded commands: hand any non-empty command list back to the context
  // before forgetting the entry.
  if (auto captured = captured_graphs_.find(graph_annotation_id); captured != captured_graphs_.end()) {
    auto& commands = captured->second;
    if (!commands.empty()) {
      context_.ReleaseGraphResources(commands);
    }
    captured_graphs_.erase(captured);
  }

  // The graph no longer counts as captured.
  captured_graph_ids_.erase(graph_annotation_id);

  // Erasing the map entry drops the per-graph buffer manager held for this ID.
  per_graph_buffer_mgrs_.erase(graph_annotation_id);

  // Forget how many regular runs were counted for this ID.
  graph_id_to_run_count_.erase(graph_annotation_id);

  return Status::OK();
}

webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const {
if (graph_buffer_mgr_) {
return *graph_buffer_mgr_;
} else {
return context_.BufferManager();
if (graph_buffer_mgr_active_) {
auto it = per_graph_buffer_mgrs_.find(m_current_graph_annotation_id);
if (it != per_graph_buffer_mgrs_.end()) {
return *it->second;
}
}
return context_.BufferManager();
}

bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
auto it = graph_id_to_run_count_.find(m_current_graph_annotation_id);
int run_count = (it != graph_id_to_run_count_.end()) ? it->second : 0;
return run_count >= min_num_runs_before_graph_capture_;
}

void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
++graph_id_to_run_count_[m_current_graph_annotation_id];
}
} // namespace onnxruntime
Loading
Loading