@@ -581,7 +581,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
581581 // If graph capture is enabled, create a dedicated buffer manager for graph mode
582582 if (enable_graph_capture_) {
583583 // Create buffer manager for graph capture mode with appropriate cache modes
584- graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create (
584+ graph_default_buffer_mgr_ = webgpu::BufferManagerFactory::Create (
585585 context_,
586586 webgpu::BufferCacheMode::Graph,
587587 webgpu::BufferCacheMode::GraphSimple,
@@ -599,11 +599,13 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
599599}
600600
601601std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators () {
602+ auto default_alloc = std::make_unique<webgpu::GpuBufferAllocator>(BufferManager (), false );
603+ default_gpu_allocator_ = default_alloc.get ();
602604 return {
603605 // allocator for initializers
604606 std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager (), true ),
605607 // default allocator
606- std::make_unique<webgpu::GpuBufferAllocator>( BufferManager (), false ),
608+ std::move (default_alloc ),
607609 };
608610}
609611
@@ -733,11 +735,17 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
733735}
734736
735737WebGpuExecutionProvider::~WebGpuExecutionProvider () {
736- // Release all resources associated with the captured graph
737- if (!captured_commands_.empty ()) {
738- context_.ReleaseGraphResources (captured_commands_);
738+ // Release all captured graphs and their associated resources
739+ std::vector<int > graph_ids;
740+ graph_ids.reserve (captured_graph_ids_.size ());
741+ for (int id : captured_graph_ids_) {
742+ graph_ids.push_back (id);
739743 }
740- // The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr
744+ for (int id : graph_ids) {
745+ (void )ReleaseGraph (id);
746+ }
747+ // Also release any per-graph buffer managers for graphs that were never fully captured
748+ per_graph_buffer_mgrs_.clear ();
741749
742750 WebGpuContextFactory::ReleaseContext (context_id_);
743751}
@@ -772,8 +780,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
772780 *graph_annotation_str);
773781 }
774782
783+ // Create a per-graph buffer manager on first encounter of each annotation ID
784+ if (graph_annotation_id != -1 ) {
785+ if (per_graph_buffer_mgrs_.find (graph_annotation_id) == per_graph_buffer_mgrs_.end ()) {
786+ per_graph_buffer_mgrs_[graph_annotation_id] = webgpu::BufferManagerFactory::Create (
787+ context_,
788+ webgpu::BufferCacheMode::Graph,
789+ webgpu::BufferCacheMode::GraphSimple,
790+ webgpu::BufferCacheMode::Disabled);
791+ }
792+ // Route allocator to this graph's buffer manager
793+ if (default_gpu_allocator_) {
794+ default_gpu_allocator_->SetBufferManager (*per_graph_buffer_mgrs_[graph_annotation_id]);
795+ }
796+ }
797+
775798 if (graph_annotation_id != -1 && IsGraphCaptureAllowed () && !IsGraphCaptured (graph_annotation_id)) {
776- context_.CaptureBegin (&captured_commands_, *graph_buffer_mgr_);
799+ auto & commands = captured_graphs_[graph_annotation_id];
800+ context_.CaptureBegin (&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
777801 }
778802 m_current_graph_annotation_id = graph_annotation_id;
779803 }
@@ -787,7 +811,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
787811 if (IsGraphCaptureEnabled () && !IsGraphCaptured (m_current_graph_annotation_id)) {
788812 if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed ()) {
789813 context_.CaptureEnd ();
790- is_graph_captured_ = true ;
814+ captured_graph_ids_. insert (m_current_graph_annotation_id) ;
791815 ORT_RETURN_IF_ERROR (ReplayGraph (m_current_graph_annotation_id));
792816 } else {
793817 IncrementRegularRunCountBeforeGraphCapture ();
@@ -808,6 +832,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
808832 }
809833#endif // ENABLE_PIX_FOR_WEBGPU_EP
810834
835+ // Reset allocator to default buffer manager after run completes
836+ if (default_gpu_allocator_ && graph_default_buffer_mgr_) {
837+ default_gpu_allocator_->SetBufferManager (*graph_default_buffer_mgr_);
838+ }
839+
811840 if (context_.ValidationMode () >= ValidationMode::Basic) {
812841 return context_.PopErrorScope ();
813842 } else {
@@ -820,7 +849,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
820849}
821850
822851bool WebGpuExecutionProvider::IsGraphCaptured (int graph_annotation_id) const {
823- return is_graph_captured_ && graph_annotation_id != -1 ;
852+ return graph_annotation_id != -1 && captured_graph_ids_. count (graph_annotation_id) > 0 ;
824853}
825854
826855Status WebGpuExecutionProvider::ReplayGraph (int graph_annotation_id) {
@@ -829,27 +858,59 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
829858 if (session_profiler_ && session_profiler_->Enabled ()) {
830859 context_.StartProfiling ();
831860 }
832- context_.Replay (captured_commands_ , *graph_buffer_mgr_ );
861+ context_.Replay (captured_graphs_. at (graph_annotation_id) , *per_graph_buffer_mgrs_. at (graph_annotation_id) );
833862 if (session_profiler_ && session_profiler_->Enabled ()) {
834863 // Session-level profiling: collect into profiler's own events storage.
835864 context_.CollectProfilingData (session_profiler_->GpuEvents ());
836865 }
837866 return Status::OK ();
838867}
839868
869+ Status WebGpuExecutionProvider::ReleaseGraph (int graph_annotation_id) {
870+ // Release captured commands
871+ auto cmd_it = captured_graphs_.find (graph_annotation_id);
872+ if (cmd_it != captured_graphs_.end ()) {
873+ if (!cmd_it->second .empty ()) {
874+ context_.ReleaseGraphResources (cmd_it->second );
875+ }
876+ captured_graphs_.erase (cmd_it);
877+ }
878+
879+ // Remove from captured set
880+ captured_graph_ids_.erase (graph_annotation_id);
881+
882+ // Release per-graph buffer manager (destroys cached buffers)
883+ per_graph_buffer_mgrs_.erase (graph_annotation_id);
884+
885+ // Clean up run count tracking
886+ graph_id_to_run_count_.erase (graph_annotation_id);
887+
888+ return Status::OK ();
889+ }
890+
840891webgpu::BufferManager& WebGpuExecutionProvider::BufferManager () const {
841- if (graph_buffer_mgr_) {
842- return *graph_buffer_mgr_;
892+ // Use per-graph buffer manager if one exists for the current annotation ID
893+ if (m_current_graph_annotation_id != 0 && m_current_graph_annotation_id != -1 ) {
894+ auto it = per_graph_buffer_mgrs_.find (m_current_graph_annotation_id);
895+ if (it != per_graph_buffer_mgrs_.end ()) {
896+ return *it->second ;
897+ }
898+ }
899+ // Fall back to default graph buffer manager (warmup runs) or context buffer manager
900+ if (graph_default_buffer_mgr_) {
901+ return *graph_default_buffer_mgr_;
843902 } else {
844903 return context_.BufferManager ();
845904 }
846905}
847906
848907bool WebGpuExecutionProvider::IsGraphCaptureAllowed () const {
849- return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
908+ auto it = graph_id_to_run_count_.find (m_current_graph_annotation_id);
909+ int run_count = (it != graph_id_to_run_count_.end ()) ? it->second : 0 ;
910+ return run_count >= min_num_runs_before_cuda_graph_capture_;
850911}
851912
852913void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture () {
853- ++regular_run_count_before_graph_capture_ ;
914+ ++graph_id_to_run_count_[m_current_graph_annotation_id] ;
854915}
855916} // namespace onnxruntime
0 commit comments