@@ -578,7 +578,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
578578 // If graph capture is enabled, create a dedicated buffer manager for graph mode
579579 if (enable_graph_capture_) {
580580 // Create buffer manager for graph capture mode with appropriate cache modes
581- graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create (
581+ graph_default_buffer_mgr_ = webgpu::BufferManagerFactory::Create (
582582 context_,
583583 webgpu::BufferCacheMode::Graph,
584584 webgpu::BufferCacheMode::GraphSimple,
@@ -596,11 +596,13 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
596596}
597597
598598std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators () {
599+ auto default_alloc = std::make_unique<webgpu::GpuBufferAllocator>(BufferManager (), false );
600+ default_gpu_allocator_ = default_alloc.get ();
599601 return {
600602 // allocator for initializers
601603 std::make_unique<webgpu::GpuBufferAllocator>(context_.InitializerBufferManager (), true ),
602604 // default allocator
603- std::make_unique<webgpu::GpuBufferAllocator>( BufferManager (), false ),
605+ std::move (default_alloc ),
604606 };
605607}
606608
@@ -725,11 +727,17 @@ std::optional<bool> WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s
725727}
726728
727729WebGpuExecutionProvider::~WebGpuExecutionProvider () {
728- // Release all resources associated with the captured graph
729- if (!captured_commands_.empty ()) {
730- context_.ReleaseGraphResources (captured_commands_);
730+ // Release all captured graphs and their associated resources
731+ std::vector<int > graph_ids;
732+ graph_ids.reserve (captured_graph_ids_.size ());
733+ for (int id : captured_graph_ids_) {
734+ graph_ids.push_back (id);
731735 }
732- // The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr
736+ for (int id : graph_ids) {
737+ (void )ReleaseGraph (id);
738+ }
739+ // Also release any per-graph buffer managers for graphs that were never fully captured
740+ per_graph_buffer_mgrs_.clear ();
733741
734742 WebGpuContextFactory::ReleaseContext (context_id_);
735743}
@@ -764,8 +772,24 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op
764772 *graph_annotation_str);
765773 }
766774
775+ // Create a per-graph buffer manager on first encounter of each annotation ID
776+ if (graph_annotation_id != -1 ) {
777+ if (per_graph_buffer_mgrs_.find (graph_annotation_id) == per_graph_buffer_mgrs_.end ()) {
778+ per_graph_buffer_mgrs_[graph_annotation_id] = webgpu::BufferManagerFactory::Create (
779+ context_,
780+ webgpu::BufferCacheMode::Graph,
781+ webgpu::BufferCacheMode::GraphSimple,
782+ webgpu::BufferCacheMode::Disabled);
783+ }
784+ // Route allocator to this graph's buffer manager
785+ if (default_gpu_allocator_) {
786+ default_gpu_allocator_->SetBufferManager (*per_graph_buffer_mgrs_[graph_annotation_id]);
787+ }
788+ }
789+
767790 if (graph_annotation_id != -1 && IsGraphCaptureAllowed () && !IsGraphCaptured (graph_annotation_id)) {
768- context_.CaptureBegin (&captured_commands_, *graph_buffer_mgr_);
791+ auto & commands = captured_graphs_[graph_annotation_id];
792+ context_.CaptureBegin (&commands, *per_graph_buffer_mgrs_[graph_annotation_id]);
769793 }
770794 m_current_graph_annotation_id = graph_annotation_id;
771795 }
@@ -779,7 +803,7 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
779803 if (IsGraphCaptureEnabled () && !IsGraphCaptured (m_current_graph_annotation_id)) {
780804 if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed ()) {
781805 context_.CaptureEnd ();
782- is_graph_captured_ = true ;
806+ captured_graph_ids_. insert (m_current_graph_annotation_id) ;
783807 ORT_RETURN_IF_ERROR (ReplayGraph (m_current_graph_annotation_id));
784808 } else {
785809 IncrementRegularRunCountBeforeGraphCapture ();
@@ -800,6 +824,11 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti
800824 }
801825#endif // ENABLE_PIX_FOR_WEBGPU_EP
802826
827+ // Reset allocator to default buffer manager after run completes
828+ if (default_gpu_allocator_ && graph_default_buffer_mgr_) {
829+ default_gpu_allocator_->SetBufferManager (*graph_default_buffer_mgr_);
830+ }
831+
803832 if (context_.ValidationMode () >= ValidationMode::Basic) {
804833 return context_.PopErrorScope ();
805834 } else {
@@ -812,7 +841,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const {
812841}
813842
814843bool WebGpuExecutionProvider::IsGraphCaptured (int graph_annotation_id) const {
815- return is_graph_captured_ && graph_annotation_id != -1 ;
844+ return graph_annotation_id != -1 && captured_graph_ids_. count (graph_annotation_id) > 0 ;
816845}
817846
818847Status WebGpuExecutionProvider::ReplayGraph (int graph_annotation_id) {
@@ -821,27 +850,59 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) {
821850 if (session_profiler_ && session_profiler_->Enabled ()) {
822851 context_.StartProfiling ();
823852 }
824- context_.Replay (captured_commands_ , *graph_buffer_mgr_ );
853+ context_.Replay (captured_graphs_. at (graph_annotation_id) , *per_graph_buffer_mgrs_. at (graph_annotation_id) );
825854 if (session_profiler_ && session_profiler_->Enabled ()) {
826855 // Session-level profiling: collect into profiler's own events storage.
827856 context_.CollectProfilingData (session_profiler_->GpuEvents ());
828857 }
829858 return Status::OK ();
830859}
831860
861+ Status WebGpuExecutionProvider::ReleaseGraph (int graph_annotation_id) {
862+ // Release captured commands
863+ auto cmd_it = captured_graphs_.find (graph_annotation_id);
864+ if (cmd_it != captured_graphs_.end ()) {
865+ if (!cmd_it->second .empty ()) {
866+ context_.ReleaseGraphResources (cmd_it->second );
867+ }
868+ captured_graphs_.erase (cmd_it);
869+ }
870+
871+ // Remove from captured set
872+ captured_graph_ids_.erase (graph_annotation_id);
873+
874+ // Release per-graph buffer manager (destroys cached buffers)
875+ per_graph_buffer_mgrs_.erase (graph_annotation_id);
876+
877+ // Clean up run count tracking
878+ graph_id_to_run_count_.erase (graph_annotation_id);
879+
880+ return Status::OK ();
881+ }
882+
832883webgpu::BufferManager& WebGpuExecutionProvider::BufferManager () const {
833- if (graph_buffer_mgr_) {
834- return *graph_buffer_mgr_;
884+ // Use per-graph buffer manager if one exists for the current annotation ID
885+ if (m_current_graph_annotation_id != 0 && m_current_graph_annotation_id != -1 ) {
886+ auto it = per_graph_buffer_mgrs_.find (m_current_graph_annotation_id);
887+ if (it != per_graph_buffer_mgrs_.end ()) {
888+ return *it->second ;
889+ }
890+ }
891+ // Fall back to default graph buffer manager (warmup runs) or context buffer manager
892+ if (graph_default_buffer_mgr_) {
893+ return *graph_default_buffer_mgr_;
835894 } else {
836895 return context_.BufferManager ();
837896 }
838897}
839898
840899bool WebGpuExecutionProvider::IsGraphCaptureAllowed () const {
841- return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
900+ auto it = graph_id_to_run_count_.find (m_current_graph_annotation_id);
901+ int run_count = (it != graph_id_to_run_count_.end ()) ? it->second : 0 ;
902+ return run_count >= min_num_runs_before_cuda_graph_capture_;
842903}
843904
844905void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture () {
845- ++regular_run_count_before_graph_capture_ ;
906+ ++graph_id_to_run_count_[m_current_graph_annotation_id] ;
846907}
847908} // namespace onnxruntime
0 commit comments