Skip to content

Commit 94a2ca4

Browse files
committed
Per-graph buffer manager for WebGPU multi-graph capture
Replace the single shared graph_buffer_mgr_ with per-annotation-ID buffer managers so that multiple generators with different prompts each get an isolated buffer cache. This prevents cross-generator buffer corruption when graph capture is enabled.

Changes:
- Add a per_graph_buffer_mgrs_ map keyed by graph annotation ID.
- Route BufferManager() to the correct per-graph instance during capture/replay.
- Change GpuBufferAllocator from holding a const BufferManager& to holding a pointer, with SetBufferManager() for dynamic routing.
- Add a ReleaseGraph API through the full stack (EP base class, C API, InferenceSession, plugin EP, EP adapter) so callers can free captured graph resources when a generator is destroyed.
- The WebGPU EP's ReleaseGraph cleans up the captured commands, buffer manager, and tracking state for the given annotation ID.
1 parent e3c34da commit 94a2ca4

17 files changed

Lines changed: 203 additions & 26 deletions

include/onnxruntime/core/framework/execution_provider.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,14 @@ class IExecutionProvider {
292292
return Status::OK();
293293
}
294294

295+
/**
296+
Release a previously captured graph and its associated resources.
297+
Called when the caller no longer needs the captured graph for the given annotation ID.
298+
*/
299+
virtual common::Status ReleaseGraph(int /*graph_annotation_id*/) {
300+
return Status::OK();
301+
}
302+
295303
/**
296304
Get the node assignment validation policy for graph capture.
297305
When graph capture is enabled, ORT validates that nodes are assigned to EPs

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7435,6 +7435,21 @@ struct OrtApi {
74357435
ORT_API2_STATUS(SetPerSessionThreadPoolCallbacks, _Inout_ OrtEnv* env,
74367436
_In_ const OrtThreadPoolCallbacksConfig* config);
74377437

7438+
/** \brief Release a previously captured graph and its associated GPU resources.
7439+
*
7440+
* When graph capture is enabled, the EP records GPU commands during initial runs and replays them
7441+
* on subsequent runs. This function releases the captured commands and associated GPU buffers
7442+
* for a specific graph annotation ID, freeing GPU memory.
7443+
*
7444+
* \param[in] session The OrtSession instance.
7445+
* \param[in] graph_annotation_id The annotation ID of the captured graph to release.
7446+
*
7447+
* \snippet{doc} snippets.dox OrtStatus Return Value
7448+
*
7449+
* \since Version 1.26.
7450+
*/
7451+
ORT_API2_STATUS(SessionReleaseGraph, _In_ OrtSession* session, _In_ int graph_annotation_id);
7452+
74387453
/// @}
74397454
};
74407455

include/onnxruntime/core/session/onnxruntime_ep_c_api.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2514,6 +2514,20 @@ struct OrtEp {
25142514
*/
25152515
ORT_API2_STATUS(ReplayGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
25162516

2517+
/** \brief Release a previously captured graph and its associated resources.
2518+
*
2519+
* Called when the caller no longer needs the captured graph for the given annotation ID.
2520+
* This allows the EP to free GPU buffers and other resources tied to this graph.
2521+
*
2522+
* \param[in] this_ptr The EP instance.
2523+
* \param[in] graph_annotation_id The annotation ID of the graph to release.
2524+
*
2525+
* \snippet{doc} snippets.dox OrtStatus Return Value
2526+
*
2527+
* \since Version 1.26.
2528+
*/
2529+
ORT_API2_STATUS(ReleaseGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id);
2530+
25172531
/** \brief Get the node assignment validation policy for graph capture.
25182532
*
25192533
* When graph capture is enabled, ORT validates that nodes are assigned to EPs in a way that is

onnxruntime/core/providers/webgpu/allocator.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool
1515
: OrtAllocatorType::OrtDeviceAllocator,
1616
WebGpuDevice,
1717
OrtMemTypeDefault)),
18-
buffer_manager_{buffer_manager},
18+
buffer_manager_{&buffer_manager},
1919
mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} {
2020
}
2121

@@ -29,12 +29,12 @@ void* GpuBufferAllocator::Alloc(size_t size) {
2929
wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite
3030
: wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect;
3131

32-
return buffer_manager_.Create(size, usage);
32+
return buffer_manager_->Create(size, usage);
3333
}
3434

3535
void GpuBufferAllocator::Free(void* p) {
3636
if (p != nullptr) {
37-
buffer_manager_.Release(static_cast<WGPUBuffer>(p));
37+
buffer_manager_->Release(static_cast<WGPUBuffer>(p));
3838
stats_.num_allocs--;
3939
}
4040
}
@@ -43,5 +43,9 @@ void GpuBufferAllocator::GetStats(AllocatorStats* stats) {
4343
*stats = stats_;
4444
}
4545

46+
void GpuBufferAllocator::SetBufferManager(const BufferManager& buffer_manager) {
47+
buffer_manager_ = &buffer_manager;
48+
}
49+
4650
} // namespace webgpu
4751
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/allocator.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@ class GpuBufferAllocator : public IAllocator {
2424
virtual void Free(void* p) override;
2525
void GetStats(AllocatorStats* stats) override;
2626

27+
// Update the buffer manager used for allocations (for per-graph buffer isolation)
28+
void SetBufferManager(const BufferManager& buffer_manager);
29+
2730
private:
2831
AllocatorStats stats_;
29-
const BufferManager& buffer_manager_;
32+
const BufferManager* buffer_manager_;
3033
bool mapped_at_creation_;
3134
};
3235

onnxruntime/core/providers/webgpu/ep/ep.cc

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ Ep::Ep(std::unique_ptr<IExecutionProvider> impl, Factory& factory, const OrtLogg
4343
IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl;
4444
IsGraphCaptured = IsGraphCapturedImpl;
4545
ReplayGraph = ReplayGraphImpl;
46+
ReleaseGraph = ReleaseGraphImpl;
4647
GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl;
4748
}
4849

@@ -279,6 +280,18 @@ OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph
279280
EXCEPTION_TO_RETURNED_STATUS_END
280281
}
281282

283+
OrtStatus* ORT_API_CALL Ep::ReleaseGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept {
284+
EXCEPTION_TO_RETURNED_STATUS_BEGIN
285+
auto* ep = static_cast<Ep*>(this_ptr);
286+
auto status = ep->EpImpl()->ReleaseGraph(graph_annotation_id);
287+
if (!status.IsOK()) {
288+
return Api().ort.CreateStatus(static_cast<OrtErrorCode>(status.Code()),
289+
status.ErrorMessage().c_str());
290+
}
291+
return nullptr;
292+
EXCEPTION_TO_RETURNED_STATUS_END
293+
}
294+
282295
OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl(
283296
_In_ const OrtEp* this_ptr) noexcept {
284297
auto* ep = static_cast<const Ep*>(this_ptr);

onnxruntime/core/providers/webgpu/ep/ep.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ class Ep : public onnxruntime::ep::adapter::Ep {
7575
static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr,
7676
_In_ int graph_annotation_id) noexcept;
7777

78+
static OrtStatus* ORT_API_CALL ReleaseGraphImpl(_In_ OrtEp* this_ptr,
79+
_In_ int graph_annotation_id) noexcept;
80+
7881
static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl(
7982
_In_ const OrtEp* this_ptr) noexcept;
8083

onnxruntime/core/providers/webgpu/ep/factory.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "core/framework/execution_provider.h"
1111
#include "core/framework/config_options.h"
1212
#include "core/providers/webgpu/webgpu_provider_factory_creator.h"
13+
#include "core/providers/webgpu/webgpu_execution_provider.h"
1314
#include "core/providers/webgpu/webgpu_context.h"
1415
#include "core/providers/webgpu/allocator.h"
1516

@@ -134,10 +135,13 @@ OrtStatus* ORT_API_CALL Factory::CreateEpImpl(
134135
static_cast<WebGpuExecutionProvider*>(webgpu_ep.get())->SetEpLogger(logger);
135136
auto factory = static_cast<Factory*>(this_ptr);
136137
const int context_id = webgpu_ep->GetDeviceId();
138+
auto* webgpu_ep_ptr = static_cast<WebGpuExecutionProvider*>(webgpu_ep.get());
139+
auto device_alloc = std::make_shared<webgpu::GpuBufferAllocator>(webgpu_ep_ptr->BufferManager(), false);
140+
webgpu_ep_ptr->SetDeviceAllocator(device_alloc.get());
137141
Ep::Config webgpu_ep_config{
138142
CPUAllocator::DefaultInstance(), // CPU allocator
139-
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator
140-
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
143+
device_alloc, // default device allocator
144+
std::make_shared<webgpu::GpuBufferAllocator>(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator
141145
};
142146
*ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config);
143147
return nullptr;

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 75 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
581581
// If graph capture is enabled, create a dedicated buffer manager for graph mode
582582
if (enable_graph_capture_) {
583583
// Create buffer manager for graph capture mode with appropriate cache modes
584-
graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
584+
graph_default_buffer_mgr_ = webgpu::BufferManagerFactory::Create(
585585
context_,
586586
webgpu::BufferCacheMode::Graph,
587587
webgpu::BufferCacheMode::GraphSimple,
@@ -599,11 +599,13 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
599599
}
600600

601601
std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
602+
auto default_alloc = std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false);
603+
default_gpu_allocator_ = default_alloc.get();
602604
return {
603605
// allocator for initializers
604606
std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager(), true),
605607
// default allocator
606-
std::make_unique<webgpu::GpuBufferAllocator>(BufferManager(), false),
608+
std::move(default_alloc),
607609
};
608610
}
609611

@@ -733,11 +735,17 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
733735
}
734736

735737
WebGpuExecutionProvider::~WebGpuExecutionProvider() {
736-
// Release all resources associated with the captured graph
737-
if (!captured_commands_.empty()) {
738-
context_.ReleaseGraphResources(captured_commands_);
738+
// Release all captured graphs and their associated resources
739+
std::vector<int> graph_ids;
740+
graph_ids.reserve(captured_graph_ids_.size());
741+
for (int id : captured_graph_ids_) {
742+
graph_ids.push_back(id);
739743
}
740-
// The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr
744+
for (int id : graph_ids) {
745+
(void)ReleaseGraph(id);
746+
}
747+
// Also release any per-graph buffer managers for graphs that were never fully captured
748+
per_graph_buffer_mgrs_.clear();
741749

742750
WebGpuContextFactory::ReleaseContext(context_id_);
743751
}
@@ -772,8 +780,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
772780
*graph_annotation_str);
773781
}
774782

783+
// Create a per-graph buffer manager on first encounter of each annotation ID
784+
if (graph_annotation_id != -1) {
785+
if (per_graph_buffer_mgrs_.find(graph_annotation_id) == per_graph_buffer_mgrs_.end()) {
786+
per_graph_buffer_mgrs_[graph_annotation_id] = webgpu::BufferManagerFactory::Create(
787+
context_,
788+
webgpu::BufferCacheMode::Graph,
789+
webgpu::BufferCacheMode::GraphSimple,
790+
webgpu::BufferCacheMode::Disabled);
791+
}
792+
// Route allocator to this graph's buffer manager
793+
if (default_gpu_allocator_) {
794+
default_gpu_allocator_->SetBufferManager(*per_graph_buffer_mgrs_[graph_annotation_id]);
795+
}
796+
}
797+
775798
if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) {
776-
context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_);
799+
auto& commands = captured_graphs_[graph_annotation_id];
800+
context_.CaptureBegin(&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
777801
}
778802
m_current_graph_annotation_id = graph_annotation_id;
779803
}
@@ -787,7 +811,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
787811
if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) {
788812
if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) {
789813
context_.CaptureEnd();
790-
is_graph_captured_ = true;
814+
captured_graph_ids_.insert(m_current_graph_annotation_id);
791815
ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id));
792816
} else {
793817
IncrementRegularRunCountBeforeGraphCapture();
@@ -808,6 +832,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
808832
}
809833
#endif // ENABLE_PIX_FOR_WEBGPU_EP
810834

835+
// Reset allocator to default buffer manager after run completes
836+
if (default_gpu_allocator_ && graph_default_buffer_mgr_) {
837+
default_gpu_allocator_->SetBufferManager(*graph_default_buffer_mgr_);
838+
}
839+
811840
if (context_.ValidationMode() >= ValidationMode::Basic) {
812841
return context_.PopErrorScope();
813842
} else {
@@ -820,7 +849,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
820849
}
821850

822851
bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
823-
return is_graph_captured_ && graph_annotation_id != -1;
852+
return graph_annotation_id != -1 && captured_graph_ids_.count(graph_annotation_id) > 0;
824853
}
825854

826855
Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
@@ -829,27 +858,59 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
829858
if (session_profiler_ && session_profiler_->Enabled()) {
830859
context_.StartProfiling();
831860
}
832-
context_.Replay(captured_commands_, *graph_buffer_mgr_);
861+
context_.Replay(captured_graphs_.at(graph_annotation_id), *per_graph_buffer_mgrs_.at(graph_annotation_id));
833862
if (session_profiler_ && session_profiler_->Enabled()) {
834863
// Session-level profiling: collect into profiler's own events storage.
835864
context_.CollectProfilingData(session_profiler_->GpuEvents());
836865
}
837866
return Status::OK();
838867
}
839868

869+
Status WebGpuExecutionProvider::ReleaseGraph(int graph_annotation_id) {
870+
// Release captured commands
871+
auto cmd_it = captured_graphs_.find(graph_annotation_id);
872+
if (cmd_it != captured_graphs_.end()) {
873+
if (!cmd_it->second.empty()) {
874+
context_.ReleaseGraphResources(cmd_it->second);
875+
}
876+
captured_graphs_.erase(cmd_it);
877+
}
878+
879+
// Remove from captured set
880+
captured_graph_ids_.erase(graph_annotation_id);
881+
882+
// Release per-graph buffer manager (destroys cached buffers)
883+
per_graph_buffer_mgrs_.erase(graph_annotation_id);
884+
885+
// Clean up run count tracking
886+
graph_id_to_run_count_.erase(graph_annotation_id);
887+
888+
return Status::OK();
889+
}
890+
840891
webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const {
841-
if (graph_buffer_mgr_) {
842-
return *graph_buffer_mgr_;
892+
// Use per-graph buffer manager if one exists for the current annotation ID
893+
if (m_current_graph_annotation_id != 0 && m_current_graph_annotation_id != -1) {
894+
auto it = per_graph_buffer_mgrs_.find(m_current_graph_annotation_id);
895+
if (it != per_graph_buffer_mgrs_.end()) {
896+
return *it->second;
897+
}
898+
}
899+
// Fall back to default graph buffer manager (warmup runs) or context buffer manager
900+
if (graph_default_buffer_mgr_) {
901+
return *graph_default_buffer_mgr_;
843902
} else {
844903
return context_.BufferManager();
845904
}
846905
}
847906

848907
bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const {
849-
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
908+
auto it = graph_id_to_run_count_.find(m_current_graph_annotation_id);
909+
int run_count = (it != graph_id_to_run_count_.end()) ? it->second : 0;
910+
return run_count >= min_num_runs_before_cuda_graph_capture_;
850911
}
851912

852913
void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() {
853-
++regular_run_count_before_graph_capture_;
914+
++graph_id_to_run_count_[m_current_graph_annotation_id];
854915
}
855916
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
#include <span>
88
#include <string>
99
#include <memory>
10+
#include <unordered_map>
11+
#include <unordered_set>
1012
#include <vector>
1113

1214
#include "core/framework/execution_provider.h"
@@ -97,11 +99,14 @@ class WebGpuExecutionProvider : public IExecutionProvider {
9799
bool IsGraphCaptureEnabled() const override;
98100
bool IsGraphCaptured(int graph_annotation_id) const override;
99101
Status ReplayGraph(int graph_annotation_id) override;
102+
Status ReleaseGraph(int graph_annotation_id) override;
100103
OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override {
101104
return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES;
102105
}
103106
webgpu::BufferManager& BufferManager() const;
104107
AllocatorPtr PrepackAllocator() const { return prepack_allocator_; }
108+
// Set the device allocator pointer so we can call SetBufferManager on it during OnRunStart/OnRunEnd
109+
void SetDeviceAllocator(webgpu::GpuBufferAllocator* allocator) { default_gpu_allocator_ = allocator; }
105110
std::span<const std::string> GetForceCpuNodeNames() const { return force_cpu_node_names_; }
106111
uint32_t MultiRotaryCacheConcatOffset() const { return multi_rotary_cache_concat_offset_; }
107112

@@ -126,24 +131,35 @@ class WebGpuExecutionProvider : public IExecutionProvider {
126131
bool enable_graph_capture_ = false;
127132
bool enable_int64_ = false;
128133
uint32_t multi_rotary_cache_concat_offset_ = 0;
129-
bool is_graph_captured_ = false;
130-
int regular_run_count_before_graph_capture_ = 0;
134+
std::unordered_map<int, int> graph_id_to_run_count_;
131135
const int min_num_runs_before_cuda_graph_capture_ = 1; // Required regular runs before graph capture for any necessary allocations.
132136
int m_current_graph_annotation_id = 0;
133137

134138
#if defined(ENABLE_PIX_FOR_WEBGPU_EP)
135139
std::unique_ptr<WebGpuPIXFrameGenerator> pix_frame_generator_ = nullptr;
136140
#endif // ENABLE_PIX_FOR_WEBGPU_EP
137141

138-
// Buffer manager specifically for graph capture mode
139-
std::unique_ptr<webgpu::BufferManager> graph_buffer_mgr_ = nullptr;
142+
// Default buffer manager for graph capture mode (used during warmup runs
143+
// and as the stable reference target for GpuBufferAllocator)
144+
std::unique_ptr<webgpu::BufferManager> graph_default_buffer_mgr_ = nullptr;
140145

141-
// Store captured commands directly in the EP instead of in WebGpuContext
142-
std::vector<webgpu::CapturedCommandInfo> captured_commands_;
146+
// Per-graph buffer managers keyed by annotation ID.
147+
// Each captured graph gets its own buffer manager so that buffer caches
148+
// are isolated between different generators.
149+
std::unordered_map<int, std::unique_ptr<webgpu::BufferManager>> per_graph_buffer_mgrs_;
150+
151+
// Store captured commands per graph annotation ID
152+
std::unordered_map<int, std::vector<webgpu::CapturedCommandInfo>> captured_graphs_;
153+
// Track which graph annotation IDs have completed capture
154+
std::unordered_set<int> captured_graph_ids_;
143155

144156
// Allocator for prepacked weights (uses buffers without mapping)
145157
AllocatorPtr prepack_allocator_;
146158

159+
// Raw pointer to the default GPU allocator (owned by the framework via CreatePreferredAllocators)
160+
// Used to swap the buffer manager for per-graph isolation
161+
webgpu::GpuBufferAllocator* default_gpu_allocator_ = nullptr;
162+
147163
#if defined(ORT_USE_EP_API_ADAPTERS)
148164
std::unique_ptr<onnxruntime::ep::adapter::Logger> ep_logger_;
149165
#endif

0 commit comments

Comments
 (0)