Skip to content

Commit 1b08292

Browse files
committed
Per-graph buffer manager for WebGPU multi-graph capture
Replace the single shared graph_buffer_mgr_ with per-annotation-ID buffer managers so that multiple generators with different prompts each get isolated buffer caches. This prevents cross-generator buffer corruption when graph capture is enabled. Changes: - Add per_graph_buffer_mgrs_ map keyed by graph annotation ID - Route BufferManager() to the correct per-graph instance during capture/replay - Change GpuBufferAllocator from const BufferManager& to pointer with SetBufferManager() for dynamic routing - Add ReleaseGraph API through the full stack (EP base class, C API, InferenceSession, plugin EP, EP adapter) so callers can free captured graph resources when a generator is destroyed - WebGPU EP ReleaseGraph cleans up captured commands, buffer manager, and tracking state for the given annotation ID
1 parent 0992717 commit 1b08292

17 files changed

Lines changed: 209 additions & 33 deletions

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,14 @@ class IExecutionProvider {
292292
return Status::OK();
293293
}
294294

295+
/**
296+
Release a previously captured graph and its associated resources.
297+
Called when the caller no longer needs the captured graph for the given annotation ID.
298+
*/
299+
virtual common::Status ReleaseGraph(int /*graph_annotation_id*/) {
300+
return Status::OK();
301+
}
302+
295303
/**
296304
Get the node assignment validation policy for graph capture.
297305
When graph capture is enabled, ORT validates that nodes are assigned to EPs

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7435,6 +7435,21 @@ struct OrtApi {
74357435
ORT_API2_STATUS(SetPerSessionThreadPoolCallbacks, _Inout_ OrtEnv* env,
74367436
_In_ const OrtThreadPoolCallbacksConfig* config);
74377437

7438+
/** \brief Release a previously captured graph and its associated GPU resources.
7439+
*
7440+
* When graph capture is enabled, the EP records GPU commands during initial runs and replays them
7441+
* on subsequent runs. This function releases the captured commands and associated GPU buffers
7442+
* for a specific graph annotation ID, freeing GPU memory.
7443+
*
7444+
* \param[in] session The OrtSession instance.
7445+
* \param[in] graph_annotation_id The annotation ID of the captured graph to release.
7446+
*
7447+
* \snippet{doc} snippets.dox OrtStatus Return Value
7448+
*
7449+
* \since Version 1.26.
7450+
*/
7451+
ORT_API2_STATUS(SessionReleaseGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
7452+
74387453
/// @}
74397454
};
74407455

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,20 @@ struct OrtEp {
25142514
*/
25152515
ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
25162516

2517+
/** \brief Release a previously captured graph and its associated resources.
2518+
*
2519+
* Called when the caller no longer needs the captured graph for the given annotation ID.
2520+
* This allows the EP to free GPU buffers and other resources tied to this graph.
2521+
*
2522+
* \param[in] this_ptr The EP instance.
2523+
* \param[in] graph_annotation_id The annotation ID of the graph to release.
2524+
*
2525+
* \snippet{doc} snippets.dox OrtStatus Return Value
2526+
*
2527+
* \since Version 1.26.
2528+
*/
2529+
ORT_API2_STATUS(ReleaseGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
2530+
25172531
/** \brief Get the node assignment validation policy for graph capture.
25182532
*
25192533
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool
1515
: OrtAllocatorType::OrtDeviceAllocator,
1616
WebGpuDevice,
1717
OrtMemTypeDefault)),
18-
buffer_manager_{buffer_manager},
18+
buffer_manager_{&buffer_manager},
1919
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
2020
}
2121

@@ -29,12 +29,12 @@ void* GpuBufferAllocator::Alloc(size_t size) {
2929
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
3030
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;
3131

32-
return buffer_manager_.Create(size, usage);
32+
return buffer_manager_->Create(size, usage);
3333
}
3434

3535
void GpuBufferAllocator::Free(void* p) {
3636
if (p != nullptr) {
37-
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
37+
buffer_manager_->Release(static_cast<WGPUBuffer>(p));
3838
stats_.num_allocs--;
3939
}
4040
}
@@ -43,5 +43,9 @@ void GpuBufferAllocator::GetStats(AllocatorStats* stats) {
4343
*stats = stats_;
4444
}
4545

46+
void GpuBufferAllocator::SetBufferManager(const BufferManager& buffer_manager) {
47+
buffer_manager_ = &buffer_manager;
48+
}
49+
4650
} // namespace webgpu
4751
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@ class GpuBufferAllocator : public IAllocator {
2424
virtual void Free(void* p) override;
2525
void GetStats(AllocatorStats* stats) override;
2626

27+
// Update the buffer manager used for allocations (for per-graph buffer isolation)
28+
void SetBufferManager(const BufferManager& buffer_manager);
29+
2730
private:
2831
AllocatorStats stats_;
29-
const BufferManager& buffer_manager_;
32+
const BufferManager* buffer_manager_;
3033
bool mapped_at_creation_;
3134
};
3235

onnxruntime/core/providers/webgpu/ep/ep.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Ep::Ep(std::unique_ptr<IExecutionProvider> impl, Factory& factory, const OrtLogg
4343
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
4444
IsGraphCaptured = IsGraphCapturedImpl;
4545
ReplayGraph = ReplayGraphImpl;
46+
ReleaseGraph = ReleaseGraphImpl;
4647
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
4748
}
4849

@@ -279,6 +280,18 @@ OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph
279280
EXCEPTION_TO_RETURNED_STATUS_END
280281
}
281282

283+
OrtStatus* ORT_API_CALL Ep::ReleaseGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
284+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
285+
auto* ep = static_cast<Ep*>(this_ptr);
286+
auto status = ep->EpImpl()->ReleaseGraph(graph_annotation_id);
287+
if (!status.IsOK()) {
288+
return Api().ort.CreateStatus(static_cast<OrtErrorCode>(status.Code()),
289+
status.ErrorMessage().c_str());
290+
}
291+
return nullptr;
292+
EXCEPTION_TO_RETURNED_STATUS_END
293+
}
294+
282295
OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl(
283296
_In_ const OrtEp* this_ptr) noexcept {
284297
auto* ep = static_cast<const Ep*>(this_ptr);

onnxruntime/core/providers/webgpu/ep/ep.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ class Ep : public onnxruntime::ep::adapter::Ep {
7575
static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr,
7676
_In_ int graph_annotation_id) noexcept;
7777

78+
static OrtStatus* ORT_API_CALL ReleaseGraphImpl(_In_ OrtEp* this_ptr,
79+
_In_ int graph_annotation_id) noexcept;
80+
7881
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
7982
_In_ const OrtEp* this_ptr) noexcept;
8083

onnxruntime/core/providers/webgpu/ep/factory.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/framework/execution_provider.h"
1111
#include "core/framework/config_options.h"
1212
#include "core/providers/webgpu/webgpu_provider_factory_creator.h"
13+
#include "core/providers/webgpu/webgpu_execution_provider.h"
1314
#include "core/providers/webgpu/webgpu_context.h"
1415
#include "core/providers/webgpu/allocator.h"
1516

@@ -134,10 +135,13 @@ OrtStatus* ORT_API_CALL Factory::CreateEpImpl(
134135
static_cast<WebGpuExecutionProvider*>(webgpu_ep.get())->SetEpLogger(logger);
135136
auto factory = static_cast<Factory*>(this_ptr);
136137
const int context_id = webgpu_ep->GetDeviceId();
138+
auto* webgpu_ep_ptr = static_cast<WebGpuExecutionProvider*>(webgpu_ep.get());
139+
auto device_alloc = std::make_shared<webgpu::GpuBufferAllocator>(webgpu_ep_ptr->BufferManager(), false);
140+
webgpu_ep_ptr->SetDeviceAllocator(device_alloc.get());
137141
Ep::Config webgpu_ep_config{
138142
CPUAllocator::DefaultInstance(), // CPU allocator
139-
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator
140-
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
143+
device_alloc, // default device allocator
144+
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
141145
};
142146
*ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config);
143147
return nullptr;

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
578578
// If graph capture is enabled, create a dedicated buffer manager for graph mode
579579
if (enable_graph_capture_) {
580580
// Create buffer manager for graph capture mode with appropriate cache modes
581-
graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
581+
graph_default_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
582582
context_,
583583
webgpu::BufferCacheMode::Graph,
584584
webgpu::BufferCacheMode::GraphSimple,
@@ -596,11 +596,13 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
596596
}
597597

598598
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
599+
auto default_alloc = std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false);
600+
default_gpu_allocator_ = default_alloc.get();
599601
return {
600602
// allocator for initializers
601603
std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
602604
// default allocator
603-
std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
605+
std::move(default_alloc),
604606
};
605607
}
606608

@@ -725,11 +727,17 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
725727
}
726728

727729
WebGpuExecutionProvider::~WebGpuExecutionProvider() {
728-
// Release all resources associated with the captured graph
729-
if (!captured_commands_.empty()) {
730-
context_.ReleaseGraphResources(captured_commands_);
730+
// Release all captured graphs and their associated resources
731+
std::vector<int> graph_ids;
732+
graph_ids.reserve(captured_graph_ids_.size());
733+
for (int id : captured_graph_ids_) {
734+
graph_ids.push_back(id);
731735
}
732-
// The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr
736+
for (int id : graph_ids) {
737+
(void)ReleaseGraph(id);
738+
}
739+
// Also release any per-graph buffer managers for graphs that were never fully captured
740+
per_graph_buffer_mgrs_.clear();
733741

734742
WebGpuContextFactory::ReleaseContext(context_id_);
735743
}
@@ -764,8 +772,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
764772
*graph_annotation_str);
765773
}
766774

775+
// Create a per-graph buffer manager on first encounter of each annotation ID
776+
if (graph_annotation_id != -1) {
777+
if (per_graph_buffer_mgrs_.find(graph_annotation_id) == per_graph_buffer_mgrs_.end()) {
778+
per_graph_buffer_mgrs_[graph_annotation_id] = webgpu::BufferManagerFactory::Create(
779+
context_,
780+
webgpu::BufferCacheMode::Graph,
781+
webgpu::BufferCacheMode::GraphSimple,
782+
webgpu::BufferCacheMode::Disabled);
783+
}
784+
// Route allocator to this graph's buffer manager
785+
if (default_gpu_allocator_) {
786+
default_gpu_allocator_->SetBufferManager(*per_graph_buffer_mgrs_[graph_annotation_id]);
787+
}
788+
}
789+
767790
if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
768-
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
791+
auto& commands = captured_graphs_[graph_annotation_id];
792+
context_.CaptureBegin(&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
769793
}
770794
m_current_graph_annotation_id = graph_annotation_id;
771795
}
@@ -779,7 +803,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
779803
if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) {
780804
if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) {
781805
context_.CaptureEnd();
782-
is_graph_captured_ = true;
806+
captured_graph_ids_.insert(m_current_graph_annotation_id);
783807
ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id));
784808
} else {
785809
IncrementRegularRunCountBeforeGraphCapture();
@@ -800,6 +824,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
800824
}
801825
#endif // ENABLE_PIX_FOR_WEBGPU_EP
802826

827+
// Reset allocator to default buffer manager after run completes
828+
if (default_gpu_allocator_ && graph_default_buffer_mgr_) {
829+
default_gpu_allocator_->SetBufferManager(*graph_default_buffer_mgr_);
830+
}
831+
803832
if (context_.ValidationMode() >= ValidationMode::Basic) {
804833
return context_.PopErrorScope();
805834
} else {
@@ -812,7 +841,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
812841
}
813842

814843
bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
815-
return is_graph_captured_ && graph_annotation_id != -1;
844+
return graph_annotation_id != -1 && captured_graph_ids_.count(graph_annotation_id) > 0;
816845
}
817846

818847
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
@@ -821,27 +850,59 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
821850
if (session_profiler_ && session_profiler_->Enabled()) {
822851
context_.StartProfiling();
823852
}
824-
context_.Replay(captured_commands_, *graph_buffer_mgr_);
853+
context_.Replay(captured_graphs_.at(graph_annotation_id), *per_graph_buffer_mgrs_.at(graph_annotation_id));
825854
if (session_profiler_ && session_profiler_->Enabled()) {
826855
// Session-level profiling: collect into profiler's own events storage.
827856
context_.CollectProfilingData(session_profiler_->GpuEvents());
828857
}
829858
return Status::OK();
830859
}
831860

861+
Status WebGpuExecutionProvider::ReleaseGraph(int graph_annotation_id) {
862+
// Release captured commands
863+
auto cmd_it = captured_graphs_.find(graph_annotation_id);
864+
if (cmd_it != captured_graphs_.end()) {
865+
if (!cmd_it->second.empty()) {
866+
context_.ReleaseGraphResources(cmd_it->second);
867+
}
868+
captured_graphs_.erase(cmd_it);
869+
}
870+
871+
// Remove from captured set
872+
captured_graph_ids_.erase(graph_annotation_id);
873+
874+
// Release per-graph buffer manager (destroys cached buffers)
875+
per_graph_buffer_mgrs_.erase(graph_annotation_id);
876+
877+
// Clean up run count tracking
878+
graph_id_to_run_count_.erase(graph_annotation_id);
879+
880+
return Status::OK();
881+
}
882+
832883
webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const {
833-
if (graph_buffer_mgr_) {
834-
return *graph_buffer_mgr_;
884+
// Use per-graph buffer manager if one exists for the current annotation ID
885+
if (m_current_graph_annotation_id != 0 && m_current_graph_annotation_id != -1) {
886+
auto it = per_graph_buffer_mgrs_.find(m_current_graph_annotation_id);
887+
if (it != per_graph_buffer_mgrs_.end()) {
888+
return *it->second;
889+
}
890+
}
891+
// Fall back to default graph buffer manager (warmup runs) or context buffer manager
892+
if (graph_default_buffer_mgr_) {
893+
return *graph_default_buffer_mgr_;
835894
} else {
836895
return context_.BufferManager();
837896
}
838897
}
839898

840899
bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const {
841-
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
900+
auto it = graph_id_to_run_count_.find(m_current_graph_annotation_id);
901+
int run_count = (it != graph_id_to_run_count_.end()) ? it->second : 0;
902+
return run_count >= min_num_runs_before_cuda_graph_capture_;
842903
}
843904

844905
void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
845-
++regular_run_count_before_graph_capture_;
906+
++graph_id_to_run_count_[m_current_graph_annotation_id];
846907
}
847908
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <span>
88
#include <string>
99
#include <memory>
10+
#include <unordered_map>
11+
#include <unordered_set>
1012
#include <vector>
1113

1214
#include "core/framework/execution_provider.h"
@@ -97,11 +99,14 @@ class WebGpuExecutionProvider : public IExecutionProvider {
9799
bool IsGraphCaptureEnabled() const override;
98100
bool IsGraphCaptured(int graph_annotation_id) const override;
99101
Status ReplayGraph(int graph_annotation_id) override;
102+
Status ReleaseGraph(int graph_annotation_id) override;
100103
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
101104
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
102105
}
103106
webgpu::BufferManager& BufferManager() const;
104107
AllocatorPtr PrepackAllocator() const { return prepack_allocator_; }
108+
// Set the device allocator pointer so we can call SetBufferManager on it during OnRunStart/OnRunEnd
109+
void SetDeviceAllocator(webgpu::GpuBufferAllocator* allocator) { default_gpu_allocator_ = allocator; }
105110
std::span<const std::string> GetForceCpuNodeNames() const { return force_cpu_node_names_; }
106111
uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; }
107112

@@ -126,24 +131,35 @@ class WebGpuExecutionProvider : public IExecutionProvider {
126131
bool enable_graph_capture_ = false;
127132
bool enable_int64_ = false;
128133
uint32_t multi_rotary_cache_concat_offset_ = 0;
129-
bool is_graph_captured_ = false;
130-
int regular_run_count_before_graph_capture_ = 0;
134+
std::unordered_map<int, int> graph_id_to_run_count_;
131135
const int min_num_runs_before_cuda_graph_capture_ = 1; // Required regular runs before graph capture for any necessary allocations.
132136
int m_current_graph_annotation_id = 0;
133137

134138
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
135139
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
136140
#endif // ENABLE_PIX_FOR_WEBGPU_EP
137141

138-
// Buffer manager specifically for graph capture mode
139-
std::unique_ptr<webgpu::BufferManager> graph_buffer_mgr_ = nullptr;
142+
// Default buffer manager for graph capture mode (used during warmup runs
143+
// and as the stable reference target for GpuBufferAllocator)
144+
std::unique_ptr<webgpu::BufferManager> graph_default_buffer_mgr_ = nullptr;
140145

141-
// Store captured commands directly in the EP instead of in WebGpuContext
142-
std::vector<webgpu::CapturedCommandInfo> captured_commands_;
146+
// Per-graph buffer managers keyed by annotation ID.
147+
// Each captured graph gets its own buffer manager so that buffer caches
148+
// are isolated between different generators.
149+
std::unordered_map<int, std::unique_ptr<webgpu::BufferManager>> per_graph_buffer_mgrs_;
150+
151+
// Store captured commands per graph annotation ID
152+
std::unordered_map<int, std::vector<webgpu::CapturedCommandInfo>> captured_graphs_;
153+
// Track which graph annotation IDs have completed capture
154+
std::unordered_set<int> captured_graph_ids_;
143155

144156
// Allocator for prepacked weights (uses buffers without mapping)
145157
AllocatorPtr prepack_allocator_;
146158

159+
// Raw pointer to the default GPU allocator (owned by the framework via CreatePreferredAllocators)
160+
// Used to swap the buffer manager for per-graph isolation
161+
webgpu::GpuBufferAllocator* default_gpu_allocator_ = nullptr;
162+
147163
#if defined(ORT_USE_EP_API_ADAPTERS)
148164
std::unique_ptr<onnxruntime::ep::adapter::Logger> ep_logger_;
149165
#endif

0 commit comments

Comments
 (0)