diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 7395d823fd2be..6f9b99918c0e2 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -292,6 +292,18 @@ class IExecutionProvider { return Status::OK(); } + /** + Release a previously captured graph and its associated resources. + Called when the caller no longer needs the captured graph for the given annotation ID. + + Thread safety: For EPs where ConcurrentRunSupported() returns true, this method may be + called concurrently with Run(). The EP is responsible for its own synchronization in + that case. For non-concurrent EPs, the session serializes calls via session_mutex_. + */ + virtual common::Status ReleaseCapturedGraph(int /*graph_annotation_id*/) { + return Status::OK(); + } + /** Get the node assignment validation policy for graph capture. When graph capture is enabled, ORT validates that nodes are assigned to EPs diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index f8cb2d6ad46de..d90189496af42 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -7471,6 +7471,21 @@ struct OrtApi { * \see OrtApi::SetSessionExecutionMode */ ORT_API2_STATUS(GetSessionExecutionMode, _In_ const OrtSessionOptions* options, _Out_ ExecutionMode* out); + + /** \brief Release a previously captured graph and its associated resources. + * + * When graph capture is enabled, the EP records information during initial runs (e.g., GPU commands) + * and replays them on subsequent runs. This function releases the captured resources for a specific + * graph annotation ID, freeing memory. + * + * \param[in] session The OrtSession instance. + * \param[in] graph_annotation_id The annotation ID of the captured graph to release. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.27. + */ + ORT_API2_STATUS(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 9bc9be34514ee..42eeac19da377 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2005,6 +2005,14 @@ struct SessionImpl : ConstSessionImpl { void FinalizeModelEditorSession(const Model& model, const SessionOptions& options, OrtPrepackedWeightsContainer* prepacked_weights_container = nullptr); + + /** \brief Release a previously captured graph. + * + * Wraps OrtApi::SessionReleaseCapturedGraph + * + * \param[in] graph_annotation_id The annotation ID of the captured graph to release. + */ + void ReleaseCapturedGraph(int graph_annotation_id); }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 99c606161c812..61bc31736f5b5 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2107,6 +2107,11 @@ inline void SessionImpl::FinalizeModelEditorSession(const Model& model, const } #endif // #if !defined(ORT_MINIMAL_BUILD) +template +inline void SessionImpl::ReleaseCapturedGraph(int graph_annotation_id) { + ThrowOnError(GetApi().SessionReleaseCapturedGraph(this->p_, graph_annotation_id)); +} + } // namespace detail inline SessionOptions::SessionOptions() { diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 76fb7ce93b600..b816528f1f2ba 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2567,6 +2567,28 @@ struct OrtEp { * \since Version 1.27. */ ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr); + + /** \brief Release a previously captured graph and its associated resources. + * + * Called when the caller no longer needs the captured graph for the given annotation ID. + * This allows the EP to free buffers and other resources tied to this graph. + * + * \param[in] this_ptr The EP instance. + * \param[in] graph_annotation_id The annotation ID of the graph to release. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If set to NULL, ORT assumes + * no captured graph release is needed and treats it as a no-op success. + * + * \note Thread safety: For EPs that support concurrent Run() calls, this method may be + * called concurrently with Run(). The EP is responsible for ensuring thread safety + * of its own state in that case. For non-concurrent EPs, the session serializes + * calls via its internal mutex. + * + * \since Version 1.27. + */ + ORT_API2_STATUS(ReleaseCapturedGraph, _In_ OrtEp* this_ptr, _In_ int graph_annotation_id); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc index 2890afae02ab9..8645d0751b65f 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/multihead_attention.cc @@ -103,6 +103,16 @@ Status MultiHeadAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& TensorShape output_qk_shape(output_qk_dims); Tensor* output_qk = context.Output(3, output_qk_shape); + // Match CPU EP semantics: when no present_key/present_value output is requested, + // ignore past_key/past_value. The CPU EP sets past_sequence_length=0 in this case, + // effectively treating the input as if there is no KV cache. + if (present_key == nullptr && present_value == nullptr) { + past_key = nullptr; + past_value = nullptr; + parameters.past_sequence_length_ = 0; + parameters.total_sequence_length_ = parameters.kv_sequence_length_; + } + if (output_qk == nullptr && // Flash attention does not output QK scores CanApplyFlashAttention(parameters, context)) { if (bias != nullptr) { diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index 3e1b87821fe2f..af00d864b6087 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -8,15 +8,17 @@ namespace onnxruntime { namespace webgpu { -GpuBufferAllocator::GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator) +GpuBufferAllocator::GpuBufferAllocator( + std::function buffer_manager_getter, + bool is_read_only_allocator) : IAllocator( OrtMemoryInfo(WEBGPU_BUFFER, is_read_only_allocator ? OrtAllocatorType::OrtReadOnlyAllocator : OrtAllocatorType::OrtDeviceAllocator, WebGpuDevice, OrtMemTypeDefault)), - buffer_manager_{buffer_manager}, - mapped_at_creation_{is_read_only_allocator && buffer_manager.SupportsUMA()} { + buffer_manager_getter_{std::move(buffer_manager_getter)}, + mapped_at_creation_{is_read_only_allocator && buffer_manager_getter_().SupportsUMA()} { } void* GpuBufferAllocator::Alloc(size_t size) { @@ -29,12 +31,12 @@ void* GpuBufferAllocator::Alloc(size_t size) { wgpu::BufferUsage usage = mapped_at_creation_ ? wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite : wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Indirect; - return buffer_manager_.Create(size, usage); + return buffer_manager_getter_().Create(size, usage); } void GpuBufferAllocator::Free(void* p) { if (p != nullptr) { - buffer_manager_.Release(static_cast(p)); + buffer_manager_getter_().Release(static_cast(p)); stats_.num_allocs--; } } diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index 74b3d669fcf3b..670d8f9f1694d 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/framework/allocator.h" #include "core/framework/ortdevice.h" @@ -18,7 +20,10 @@ inline constexpr OrtDevice WebGpuDevice{OrtDevice::GPU, class GpuBufferAllocator : public IAllocator { public: - GpuBufferAllocator(const BufferManager& buffer_manager, bool is_read_only_allocator); + // Calls buffer_manager_getter on every Alloc/Free to obtain the current + // BufferManager. This allows the EP to route allocations to different + // buffer managers (e.g., per-graph) without explicit refresh calls. + GpuBufferAllocator(std::function buffer_manager_getter, bool is_read_only_allocator); virtual void* Alloc(size_t size) override; virtual void Free(void* p) override; @@ -26,7 +31,7 @@ class GpuBufferAllocator : public IAllocator { private: AllocatorStats stats_; - const BufferManager& buffer_manager_; + std::function buffer_manager_getter_; bool mapped_at_creation_; }; diff --git a/onnxruntime/core/providers/webgpu/ep/ep.cc b/onnxruntime/core/providers/webgpu/ep/ep.cc index e924caed2d086..24ff784a3a9a6 100644 --- a/onnxruntime/core/providers/webgpu/ep/ep.cc +++ b/onnxruntime/core/providers/webgpu/ep/ep.cc @@ -43,6 +43,7 @@ Ep::Ep(std::unique_ptr impl, Factory& factory, const OrtLogg IsGraphCaptureEnabled = IsGraphCaptureEnabledImpl; IsGraphCaptured = IsGraphCapturedImpl; ReplayGraph = ReplayGraphImpl; + ReleaseCapturedGraph = ReleaseCapturedGraphImpl; GetGraphCaptureNodeAssignmentPolicy = GetGraphCaptureNodeAssignmentPolicyImpl; } @@ -279,6 +280,18 @@ OrtStatus* ORT_API_CALL Ep::ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph EXCEPTION_TO_RETURNED_STATUS_END } +OrtStatus* ORT_API_CALL Ep::ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->ReleaseCapturedGraph(graph_annotation_id); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END +} + OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL Ep::GetGraphCaptureNodeAssignmentPolicyImpl( _In_ const OrtEp* this_ptr) noexcept { auto* ep = static_cast(this_ptr); diff --git a/onnxruntime/core/providers/webgpu/ep/ep.h b/onnxruntime/core/providers/webgpu/ep/ep.h index 1b7c1b6057f45..29f23ed66c39f 100644 --- a/onnxruntime/core/providers/webgpu/ep/ep.h +++ b/onnxruntime/core/providers/webgpu/ep/ep.h @@ -75,6 +75,9 @@ class Ep : public onnxruntime::ep::adapter::Ep { static OrtStatus* ORT_API_CALL ReplayGraphImpl(_In_ OrtEp* this_ptr, _In_ int graph_annotation_id) noexcept; + static OrtStatus* ORT_API_CALL ReleaseCapturedGraphImpl(_In_ OrtEp* this_ptr, + _In_ int graph_annotation_id) noexcept; + static OrtGraphCaptureNodeAssignmentPolicy ORT_API_CALL GetGraphCaptureNodeAssignmentPolicyImpl( _In_ const OrtEp* this_ptr) noexcept; diff --git a/onnxruntime/core/providers/webgpu/ep/factory.cc b/onnxruntime/core/providers/webgpu/ep/factory.cc index 6d8e8724f72d9..2b812f4419a91 100644 --- a/onnxruntime/core/providers/webgpu/ep/factory.cc +++ b/onnxruntime/core/providers/webgpu/ep/factory.cc @@ -10,6 +10,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/config_options.h" #include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/allocator.h" @@ -134,10 +135,17 @@ OrtStatus* ORT_API_CALL Factory::CreateEpImpl( static_cast(webgpu_ep.get())->SetEpLogger(logger); auto factory = static_cast(this_ptr); const int context_id = webgpu_ep->GetDeviceId(); + auto* webgpu_ep_ptr = static_cast(webgpu_ep.get()); + auto device_alloc = std::make_shared( + [webgpu_ep_ptr]() -> const webgpu::BufferManager& { return webgpu_ep_ptr->BufferManager(); }, false); Ep::Config webgpu_ep_config{ - CPUAllocator::DefaultInstance(), // CPU allocator - std::make_shared(WebGpuContextFactory::GetContext(context_id).BufferManager(), false), // default device allocator - std::make_shared(WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(), true), // initializer device allocator + CPUAllocator::DefaultInstance(), // CPU allocator + device_alloc, // default device allocator + std::make_shared( + [context_id]() -> const webgpu::BufferManager& { + return WebGpuContextFactory::GetContext(context_id).InitializerBufferManager(); + }, + true), // initializer device allocator }; *ep = new Ep(std::move(webgpu_ep), *factory, *logger, webgpu_ep_config); return nullptr; @@ -165,7 +173,12 @@ OrtStatus* ORT_API_CALL Factory::CreateAllocatorImpl( *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, [](const OrtMemoryInfo&) -> AllocatorPtr { - return std::make_shared(WebGpuContextFactory::DefaultContext().BufferManager(), false); + return std::make_shared( + []() -> const webgpu::BufferManager& { + return WebGpuContextFactory::DefaultContext() + .BufferManager(); + }, + false); }); return nullptr; EXCEPTION_TO_RETURNED_STATUS_END diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index d1cde04277938..50b322c328d6f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -577,17 +577,8 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, // enable_int64_ is always true when enable_graph_capture_ is true enable_int64_{config.enable_graph_capture || config.enable_int64}, multi_rotary_cache_concat_offset_{config.multi_rotary_cache_concat_offset}, - prepack_allocator_{std::make_shared(context_.InitializerBufferManager(), false)} { - // If graph capture is enabled, create a dedicated buffer manager for graph mode - if (enable_graph_capture_) { - // Create buffer manager for graph capture mode with appropriate cache modes - graph_buffer_mgr_ = webgpu::BufferManagerFactory::Create( - context_, - webgpu::BufferCacheMode::Graph, - webgpu::BufferCacheMode::GraphSimple, - webgpu::BufferCacheMode::Disabled); - } - + prepack_allocator_{std::make_shared( + [this]() -> const webgpu::BufferManager& { return context_.InitializerBufferManager(); }, false)} { if (config.enable_pix_capture) { #if defined(ENABLE_PIX_FOR_WEBGPU_EP) // set pix frame generator @@ -599,11 +590,14 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, } std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { + auto device_allocator = std::make_unique( + [this]() -> const webgpu::BufferManager& { return BufferManager(); }, false); return { // allocator for initializers - std::make_unique(context_.InitializerBufferManager(), true), + std::make_unique( + [this]() -> const webgpu::BufferManager& { return context_.InitializerBufferManager(); }, true), // default allocator - std::make_unique(BufferManager(), false), + std::move(device_allocator), }; } @@ -733,11 +727,23 @@ std::optional WebGpuExecutionProvider::ShouldConvertDataLayoutForOp(std::s } WebGpuExecutionProvider::~WebGpuExecutionProvider() { - // Release all resources associated with the captured graph - if (!captured_commands_.empty()) { - context_.ReleaseGraphResources(captured_commands_); + // Release all captured graphs (both fully captured and partially captured) and their associated resources. + // Use captured_graphs_ keys to also cover partially captured graphs that have GPU command handles + // but never completed capture (i.e., CaptureBegin was called but CaptureEnd was not). + std::vector graph_ids; + graph_ids.reserve(captured_graphs_.size()); + for (const auto& [id, _] : captured_graphs_) { + graph_ids.push_back(id); } - // The graph_buffer_mgr_ will be automatically cleaned up by unique_ptr + for (int id : graph_ids) { + auto status = ReleaseCapturedGraph(id); + if (!status.IsOK()) { + LOGS(*GetLogger(), WARNING) << "Failed to release captured graph " << id << ": " << status.ErrorMessage(); + } + } + // Release any per-graph buffer managers for graphs that had buffer managers created + // but no entries in captured_graphs_ (edge case cleanup) + per_graph_buffer_mgrs_.clear(); WebGpuContextFactory::ReleaseContext(context_id_); } @@ -772,10 +778,25 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op *graph_annotation_str); } - if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { - context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + current_graph_annotation_id_ = graph_annotation_id; + + // Create a per-graph buffer manager on first encounter of each annotation ID + if (graph_annotation_id != -1) { + auto [it, inserted] = per_graph_buffer_mgrs_.try_emplace(graph_annotation_id, nullptr); + if (inserted) { + it->second = webgpu::BufferManagerFactory::Create( + context_, + webgpu::BufferCacheMode::Graph, + webgpu::BufferCacheMode::GraphSimple, + webgpu::BufferCacheMode::Disabled); + } + graph_buffer_mgr_active_ = true; + + if (IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + auto& commands = captured_graphs_[graph_annotation_id]; + context_.CaptureBegin(&commands, *it->second); + } } - m_current_graph_annotation_id = graph_annotation_id; } return Status::OK(); @@ -784,11 +805,11 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_op Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& run_options) { context_.Flush(BufferManager()); - if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) { - if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(current_graph_annotation_id_)) { + if (current_graph_annotation_id_ != -1 && IsGraphCaptureAllowed()) { context_.CaptureEnd(); - is_graph_captured_ = true; - ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id)); + captured_graph_ids_.insert(current_graph_annotation_id_); + ORT_RETURN_IF_ERROR(ReplayGraph(current_graph_annotation_id_)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -808,6 +829,9 @@ Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxrunti } #endif // ENABLE_PIX_FOR_WEBGPU_EP + // Reset buffer manager routing after run completes + graph_buffer_mgr_active_ = false; + if (context_.ValidationMode() >= ValidationMode::Basic) { return context_.PopErrorScope(); } else { @@ -820,7 +844,7 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { } bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { - return is_graph_captured_ && graph_annotation_id != -1; + return graph_annotation_id != -1 && captured_graph_ids_.contains(graph_annotation_id); } Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { @@ -829,7 +853,7 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { if (session_profiler_ && session_profiler_->Enabled()) { context_.StartProfiling(); } - context_.Replay(captured_commands_, *graph_buffer_mgr_); + context_.Replay(captured_graphs_.at(graph_annotation_id), *per_graph_buffer_mgrs_.at(graph_annotation_id)); if (session_profiler_ && session_profiler_->Enabled()) { // Session-level profiling: collect into profiler's own events storage. context_.CollectProfilingData(session_profiler_->GpuEvents()); @@ -837,19 +861,45 @@ Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { return Status::OK(); } +Status WebGpuExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) { + // Release captured commands + auto cmd_it = captured_graphs_.find(graph_annotation_id); + if (cmd_it != captured_graphs_.end()) { + if (!cmd_it->second.empty()) { + context_.ReleaseGraphResources(cmd_it->second); + } + captured_graphs_.erase(cmd_it); + } + + // Remove from captured set + captured_graph_ids_.erase(graph_annotation_id); + + // Release per-graph buffer manager (destroys cached buffers) + per_graph_buffer_mgrs_.erase(graph_annotation_id); + + // Clean up run count tracking + graph_id_to_run_count_.erase(graph_annotation_id); + + return Status::OK(); +} + webgpu::BufferManager& WebGpuExecutionProvider::BufferManager() const { - if (graph_buffer_mgr_) { - return *graph_buffer_mgr_; - } else { - return context_.BufferManager(); + if (graph_buffer_mgr_active_) { + auto it = per_graph_buffer_mgrs_.find(current_graph_annotation_id_); + if (it != per_graph_buffer_mgrs_.end()) { + return *it->second; + } } + return context_.BufferManager(); } bool WebGpuExecutionProvider::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; + auto it = graph_id_to_run_count_.find(current_graph_annotation_id_); + int run_count = (it != graph_id_to_run_count_.end()) ? it->second : 0; + return run_count >= min_num_runs_before_graph_capture_; } void WebGpuExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { - ++regular_run_count_before_graph_capture_; + ++graph_id_to_run_count_[current_graph_annotation_id_]; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index d1e2231dbba6f..92d1ebfd36c79 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include "core/framework/execution_provider.h" @@ -97,6 +99,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; Status ReplayGraph(int graph_annotation_id) override; + Status ReleaseCapturedGraph(int graph_annotation_id) override; OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override { return OrtGraphCaptureNodeAssignmentPolicy_ALLOW_CPU_FOR_SHAPES; } @@ -124,22 +127,27 @@ class WebGpuExecutionProvider : public IExecutionProvider { DataLayout preferred_data_layout_; std::vector force_cpu_node_names_; bool enable_graph_capture_ = false; + bool graph_buffer_mgr_active_ = false; bool enable_int64_ = false; uint32_t multi_rotary_cache_concat_offset_ = 0; - bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = 0; - const int min_num_runs_before_cuda_graph_capture_ = 1; // Required regular runs before graph capture for any necessary allocations. - int m_current_graph_annotation_id = 0; + std::unordered_map graph_id_to_run_count_; + // Required regular runs before graph capture for any necessary allocations. + const int min_num_runs_before_graph_capture_ = 0; + int current_graph_annotation_id_ = 0; #if defined(ENABLE_PIX_FOR_WEBGPU_EP) std::unique_ptr pix_frame_generator_ = nullptr; #endif // ENABLE_PIX_FOR_WEBGPU_EP - // Buffer manager specifically for graph capture mode - std::unique_ptr graph_buffer_mgr_ = nullptr; + // Per-graph buffer managers keyed by annotation ID. + // Each captured graph gets its own buffer manager so that buffer caches + // are isolated between different generators. + std::unordered_map> per_graph_buffer_mgrs_; - // Store captured commands directly in the EP instead of in WebGpuContext - std::vector captured_commands_; + // Store captured commands per graph annotation ID + std::unordered_map> captured_graphs_; + // Track which graph annotation IDs have completed capture + std::unordered_set captured_graph_ids_; // Allocator for prepacked weights (uses buffers without mapping) AllocatorPtr prepack_allocator_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c4b10a212450f..8f65c6fa698ef 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3820,6 +3820,18 @@ common::Status InferenceSession::Run(IOBinding& io_binding) { return Run(run_options, io_binding); } +common::Status InferenceSession::ReleaseCapturedGraph(int graph_annotation_id) { + // Acquire session_mutex_ only when concurrent run is not supported, matching the + // locking pattern in Run(). For concurrent EPs the mutex is not held by Run(), + // so acquiring it here would not synchronize with in-flight runs; those EPs are + // responsible for their own thread safety in ReleaseCapturedGraph. + std::optional> lock; + if (!is_concurrent_run_supported_) { + lock.emplace(session_mutex_); + } + return cached_execution_provider_for_graph_replay_.ReleaseCapturedGraph(graph_annotation_id); +} + template void InferenceSession::StartProfiling(const std::basic_string& file_prefix) { std::basic_ostringstream ss; diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index d79cb1a6424f0..ffde672dec618 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -414,6 +414,12 @@ class InferenceSession { [[nodiscard]] virtual common::Status Run(const RunOptions& run_options, IOBinding& io_binding); [[nodiscard]] common::Status Run(IOBinding& io_binding); + /** + * Release a previously captured graph and its associated resources. + * @param graph_annotation_id The annotation ID of the captured graph to release. + */ + [[nodiscard]] common::Status ReleaseCapturedGraph(int graph_annotation_id); + #ifdef ENABLE_TRAINING /** * Partially run a pre-loaded and pre-intialized model. @@ -1087,6 +1093,13 @@ class InferenceSession { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cached EP instance for graph replay is not set yet before calling ReplayGraph()"); } + Status ReleaseCapturedGraph(int graph_annotation_id) { + if (cached_execution_provider_for_graph_replay_) { + return cached_execution_provider_for_graph_replay_->ReleaseCapturedGraph(graph_annotation_id); + } + return Status::OK(); + } + const std::string& Type() const { ORT_ENFORCE(cached_execution_provider_for_graph_replay_ != nullptr, "No EP registered for graph replay yet"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 536c5e7ae3eda..549334564a1cf 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -4365,6 +4365,13 @@ ORT_API_STATUS_IMPL(OrtApis::SetPerSessionThreadPoolCallbacks, _Inout_ OrtEnv* o API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id) { + API_IMPL_BEGIN + auto* inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session); + return ToOrtStatus(inference_session->ReleaseCapturedGraph(graph_annotation_id)); + API_IMPL_END +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -4918,6 +4925,7 @@ static constexpr OrtApi ort_api_1_to_27 = { // End of Version 26 - DO NOT MODIFY ABOVE (see above text for more information) &OrtApis::GetMemPatternEnabled, &OrtApis::GetSessionExecutionMode, + &OrtApis::SessionReleaseCapturedGraph, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 00ec258cef91e..adccfe09bc3f7 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -431,6 +431,8 @@ ORT_API_STATUS_IMPL(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* ORT_API_STATUS_IMPL(SetPerSessionThreadPoolCallbacks, _Inout_ OrtEnv* ort_env, _In_ const OrtThreadPoolCallbacksConfig* config); +ORT_API_STATUS_IMPL(SessionReleaseCapturedGraph, _In_ OrtSession* session, _In_ int graph_annotation_id); + ORT_API_STATUS_IMPL(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); ORT_API_STATUS_IMPL(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 0a7c05158e0b4..1e8839b33ab5b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1034,6 +1034,16 @@ Status PluginExecutionProvider::ReplayGraph(int graph_annotation_id) { return ToStatusAndRelease(ort_ep_->ReplayGraph(ort_ep_.get(), graph_annotation_id)); } +Status PluginExecutionProvider::ReleaseCapturedGraph(int graph_annotation_id) { + // For plugin EPs that don't implement ReleaseCapturedGraph (version < 27 or null function pointer), + // fall back to the base class no-op implementation. This is intentional: the request is silently + // ignored since the plugin EP doesn't support explicit graph resource release. + if (ort_ep_->ort_version_supported < 27 || ort_ep_->ReleaseCapturedGraph == nullptr) { + return Base::ReleaseCapturedGraph(graph_annotation_id); + } + return ToStatusAndRelease(ort_ep_->ReleaseCapturedGraph(ort_ep_.get(), graph_annotation_id)); +} + OrtGraphCaptureNodeAssignmentPolicy PluginExecutionProvider::GetGraphCaptureNodeAssignmentPolicy() const { if (ort_ep_->ort_version_supported < 26 || ort_ep_->GetGraphCaptureNodeAssignmentPolicy == nullptr) { return OrtGraphCaptureNodeAssignmentPolicy_ALL_NODES_ON_EP; diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 48e769ccafaa3..5c79f6500fe6f 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -149,6 +149,7 @@ class PluginExecutionProvider : public IExecutionProvider { bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured(int graph_annotation_id) const override; common::Status ReplayGraph(int graph_annotation_id) override; + common::Status ReleaseCapturedGraph(int graph_annotation_id) override; OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override; private: diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 80b638314bad9..075128140fb9f 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1380,6 +1380,49 @@ TEST(PluginExecutionProviderTest, ReplayGraph) { } } +TEST(PluginExecutionProviderTest, ReleaseCapturedGraph) { + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); + + { + // NULL function pointer should return OK (default behavior). + ort_ep->ReleaseCapturedGraph = nullptr; + ASSERT_STATUS_OK(ep->ReleaseCapturedGraph(0)); + } + + { + // Non-NULL implementation returning OK. + auto release_ok = [](OrtEp* /*this_ptr*/, int /*graph_annotation_id*/) noexcept -> ::OrtStatus* { + return nullptr; + }; + ort_ep->ReleaseCapturedGraph = release_ok; + ASSERT_STATUS_OK(ep->ReleaseCapturedGraph(0)); + } + + { + // Non-NULL implementation returning an error. + auto release_fail = [](OrtEp* this_ptr, int /*graph_annotation_id*/) noexcept -> ::OrtStatus* { + auto* test_ort_ep = static_cast(this_ptr); + return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "Release captured graph failed"); + }; + ort_ep->ReleaseCapturedGraph = release_fail; + auto status = ep->ReleaseCapturedGraph(0); + ASSERT_FALSE(status.IsOK()); + ASSERT_THAT(status.ErrorMessage(), ::testing::HasSubstr("Release captured graph failed")); + } + + { + // Backward compatibility: version < 27 should return OK even if function pointer is set. + auto release_fail = [](OrtEp* this_ptr, int /*graph_annotation_id*/) noexcept -> ::OrtStatus* { + auto* test_ort_ep = static_cast(this_ptr); + return test_ort_ep->ort_api->CreateStatus(OrtErrorCode::ORT_FAIL, "Should not be called"); + }; + ort_ep->ReleaseCapturedGraph = release_fail; + ort_ep->ort_version_supported = 26; + ASSERT_STATUS_OK(ep->ReleaseCapturedGraph(0)); + ort_ep->ort_version_supported = ORT_API_VERSION; // Restore. + } +} + TEST(PluginExecutionProviderTest, GetGraphCaptureNodeAssignmentPolicy) { auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(); diff --git a/onnxruntime/test/providers/graph_capture_test.cc b/onnxruntime/test/providers/graph_capture_test.cc new file mode 100644 index 0000000000000..012ab571f3550 --- /dev/null +++ b/onnxruntime/test/providers/graph_capture_test.cc @@ -0,0 +1,245 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef USE_WEBGPU + +#include + +#include "gtest/gtest.h" + +#include "core/graph/constants.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_run_options_config_keys.h" + +using namespace Ort; + +extern std::unique_ptr ort_env; + +namespace { + +// Append WebGPU EP to session options, handling both built-in and plugin builds. +void AppendWebGpuEp(Env& env, SessionOptions& session_options, + const std::unordered_map& provider_options) { +#if defined(ORT_USE_EP_API_ADAPTERS) + // Plugin build: find the WebGPU EpDevice and use V2 API + auto ep_devices = env.GetEpDevices(); + std::vector webgpu_devices; + for (const auto& device : ep_devices) { + if (std::string(device.EpName()) == onnxruntime::kWebGpuExecutionProvider) { + webgpu_devices.push_back(device); + break; + } + } + ASSERT_FALSE(webgpu_devices.empty()) << "No WebGPU EP device found after plugin registration"; + session_options.AppendExecutionProvider_V2(env, webgpu_devices, provider_options); +#else + static_cast(env); + session_options.AppendExecutionProvider("WebGPU", provider_options); +#endif +} + +// Build a model: Y = MatMul(Relu(MatMul(A, B)), C) +// All shapes are unspecified (free dimensions) to keep it simple. +static Model CreateMatMulReluMatMulModel() { + Graph graph; + + // Inputs: A[3x4], B[4x3], C[3x2] — float tensors + std::vector dims_a_shape = {3, 4}; + std::vector dims_b_shape = {4, 3}; + std::vector dims_c_shape = {3, 2}; + std::vector dims_y_shape = {3, 2}; + + TensorTypeAndShapeInfo a_info(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims_a_shape); + TensorTypeAndShapeInfo b_info(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims_b_shape); + TensorTypeAndShapeInfo c_info(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims_c_shape); + TensorTypeAndShapeInfo y_info(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, dims_y_shape); + + auto a_type = TypeInfo::CreateTensorInfo(a_info.GetConst()); + auto b_type = TypeInfo::CreateTensorInfo(b_info.GetConst()); + auto c_type = TypeInfo::CreateTensorInfo(c_info.GetConst()); + auto y_type = TypeInfo::CreateTensorInfo(y_info.GetConst()); + + std::vector inputs; + inputs.emplace_back("A", a_type.GetConst()); + inputs.emplace_back("B", b_type.GetConst()); + inputs.emplace_back("C", c_type.GetConst()); + + std::vector outputs; + outputs.emplace_back("Y", y_type.GetConst()); + + graph.SetInputs(inputs); + graph.SetOutputs(outputs); + + // MatMul(A, B) -> T1 + Node matmul1("MatMul", onnxruntime::kOnnxDomain, "matmul1", {"A", "B"}, {"T1"}); + graph.AddNode(matmul1); + + // Relu(T1) -> T2 + Node relu("Relu", onnxruntime::kOnnxDomain, "relu", {"T1"}, {"T2"}); + graph.AddNode(relu); + + // MatMul(T2, C) -> Y + Node matmul2("MatMul", onnxruntime::kOnnxDomain, "matmul2", {"T2", "C"}, {"Y"}); + graph.AddNode(matmul2); + + std::vector opsets{{onnxruntime::kOnnxDomain, 13}}; + Model model(opsets); + model.AddGraph(graph); + return model; +} + +TEST(GraphCaptureTests, TestReleaseCapturedGraph) { + Env& env = *ort_env; + + // Create session with WebGPU EP and graph capture enabled + SessionOptions session_options; + session_options.DisableMemPattern(); + std::unordered_map provider_options; + provider_options["enableGraphCapture"] = "1"; + AppendWebGpuEp(env, session_options, provider_options); + + auto model = CreateMatMulReluMatMulModel(); + Session session(env, model, session_options); + + // Get GPU allocator from session + MemoryInfo gpu_mem_info("WebGPU_Buffer", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); + Allocator gpu_allocator(session, gpu_mem_info); + + // Model: Y = MatMul(Relu(MatMul(A[3x4], B[4x3])), C[3x2]) => Y[3x2] + std::vector dims_a = {3, 4}; + std::vector dims_b = {4, 3}; + std::vector dims_c = {3, 2}; + std::vector dims_y = {3, 2}; + + // Input set 1 + std::vector values_a1(12); + std::iota(values_a1.begin(), values_a1.end(), 0.0f); // 0..11 + std::vector values_b(12); + std::iota(values_b.begin(), values_b.end(), 0.0f); // 0..11 + std::vector values_c = {1.0f, 0.0f, 0.0f, 1.0f, 1.0f, 1.0f}; + + // Input set 2 + std::vector values_a2 = {12.0f, 11.0f, 10.0f, 9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; + + MemoryInfo cpu_mem_info = MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + // Pre-compute expected outputs on CPU: + // T1 = MatMul(A, B), T2 = Relu(T1), Y = MatMul(T2, C) + auto compute_expected = [&](const std::vector& a) -> std::vector { + // T1 = A[3x4] * B[4x3] + std::vector t1(9, 0.0f); + for (int i = 0; i < 3; ++i) + for (int j = 0; j < 3; ++j) + for (int k = 0; k < 4; ++k) + t1[i * 3 + j] += a[i * 4 + k] * values_b[k * 3 + j]; + + // T2 = Relu(T1) + std::vector t2(9); + for (int i = 0; i < 9; ++i) + t2[i] = std::max(0.0f, t1[i]); + + // Y = T2[3x3] * C[3x2] + std::vector y(6, 0.0f); + for (int i = 0; i < 3; ++i) + for (int j = 0; j < 2; ++j) + for (int k = 0; k < 3; ++k) + y[i * 2 + j] += t2[i * 3 + k] * values_c[k * 2 + j]; + + return y; + }; + + auto expected_y1 = compute_expected(values_a1); + auto expected_y2 = compute_expected(values_a2); + + // Allocate GPU tensors + Value gpu_a = Value::CreateTensor(gpu_allocator, dims_a.data(), dims_a.size()); + Value gpu_b = Value::CreateTensor(gpu_allocator, dims_b.data(), dims_b.size()); + Value gpu_c = Value::CreateTensor(gpu_allocator, dims_c.data(), dims_c.size()); + Value gpu_y = Value::CreateTensor(gpu_allocator, dims_y.data(), dims_y.size()); + + // Helper: copy CPU tensor to GPU tensor via Env::CopyTensor + auto copy_to_gpu = [&](float* data, size_t count, const int64_t* shape, size_t shape_len, Value& gpu_tensor) { + Value cpu_tensor = Value::CreateTensor(cpu_mem_info, data, count, shape, shape_len); + auto status = env.CopyTensor(cpu_tensor, gpu_tensor, nullptr); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + }; + + // Upload initial inputs (B and C are constant across all runs) + copy_to_gpu(values_a1.data(), values_a1.size(), dims_a.data(), dims_a.size(), gpu_a); + copy_to_gpu(values_b.data(), values_b.size(), dims_b.data(), dims_b.size(), gpu_b); + copy_to_gpu(values_c.data(), values_c.size(), dims_c.data(), dims_c.size(), gpu_c); + + // Set up IoBinding + IoBinding io_binding(session); + io_binding.BindInput("A", gpu_a); + io_binding.BindInput("B", gpu_b); + io_binding.BindInput("C", gpu_c); + io_binding.BindOutput("Y", gpu_y); + io_binding.SynchronizeInputs(); + + // Helper: verify GPU output matches expected values + auto verify_output = [&](const std::vector& expected) { + std::vector result(expected.size(), 0.0f); + Value cpu_result = Value::CreateTensor(cpu_mem_info, result.data(), result.size(), + dims_y.data(), dims_y.size()); + auto status = env.CopyTensor(gpu_y, cpu_result, nullptr); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + for (size_t i = 0; i < expected.size(); ++i) { + ASSERT_FLOAT_EQ(result[i], expected[i]) << "Mismatch at index " << i; + } + }; + + // Helper: upload new A values to GPU + auto set_input_a = [&](std::vector& values) { + copy_to_gpu(values.data(), values.size(), dims_a.data(), dims_a.size(), gpu_a); + }; + + // Helper: run with a given annotation ID + auto run_with_id = [&](const char* id) { + RunOptions run_options; + run_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, id); + session.Run(run_options, io_binding); + }; + + // Capture graph for annotation ID 1 (input set 1) + run_with_id("1"); + verify_output(expected_y1); + + // Replay ID 1 + run_with_id("1"); + verify_output(expected_y1); + + // Capture graph for annotation ID 2 (input set 2) + set_input_a(values_a2); + run_with_id("2"); + verify_output(expected_y2); + + // Replay ID 2 + run_with_id("2"); + verify_output(expected_y2); + + // Replay ID 1 again (cross-ID isolation) + set_input_a(values_a1); + run_with_id("1"); + verify_output(expected_y1); + + // Release ID 1 using the C++ API + session.ReleaseCapturedGraph(1); + + // Replay ID 2 (unaffected by ID 1 release) + set_input_a(values_a2); + run_with_id("2"); + verify_output(expected_y2); + + // Re-capture ID 1 after release + run_with_id("1"); + verify_output(expected_y2); + + // Release both using the C++ API + session.ReleaseCapturedGraph(1); + session.ReleaseCapturedGraph(2); +} + +} // namespace + +#endif // USE_WEBGPU diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 19311c99c6f6f..1c559bb1c01ed 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -284,28 +284,26 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) // Helper to strip the EP prefix from config entry keys when building as a plugin EP. // The full key is like "ep.webgpuexecutionprovider.storageBufferCacheMode", and the // config entry expects just "storageBufferCacheMode" in the EP API build. - // Returns a pointer into the original string, so the result is valid as long as the input is. - auto strip_ep_prefix = [](const char* full_key) -> const char* { + auto normalize_config_key = [](const char* key) -> std::string { #if defined(ORT_USE_EP_API_ADAPTERS) - std::string_view key{full_key}; + std::string normalized_key = key; std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(kWebGpuExecutionProvider); - ORT_ENFORCE(key.length() >= prefix.length() && key.substr(0, prefix.length()) == prefix, - "Config key \"", key, "\" does not start with expected prefix \"", prefix, "\""); - return full_key + prefix.length(); + if (normalized_key.starts_with(prefix)) { + normalized_key.erase(0, prefix.length()); + } + return normalized_key; #else - return full_key; + return key; #endif }; // Disable storage buffer cache - ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kStorageBufferCacheMode), - webgpu::options::kBufferCacheMode_Disabled) - .IsOK()); + ORT_THROW_IF_ERROR(config_options.AddConfigEntry(normalize_config_key(webgpu::options::kStorageBufferCacheMode).c_str(), + webgpu::options::kBufferCacheMode_Disabled)); if (!is_nhwc) { // Enable NCHW support - ORT_ENFORCE(config_options.AddConfigEntry(strip_ep_prefix(webgpu::options::kPreferredLayout), - webgpu::options::kPreferredLayout_NCHW) - .IsOK()); + ORT_THROW_IF_ERROR(config_options.AddConfigEntry(normalize_config_key(webgpu::options::kPreferredLayout).c_str(), + webgpu::options::kPreferredLayout_NCHW)); } return WebGpuExecutionProviderWithOptions(config_options); @@ -318,11 +316,21 @@ std::unique_ptr DefaultWebGpuExecutionProvider(bool is_nhwc) std::unique_ptr WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) { #if defined(USE_WEBGPU) #if defined(ORT_USE_EP_API_ADAPTERS) + ConfigOptions normalized_config_options{}; + const std::string prefix = OrtSessionOptions::GetProviderOptionPrefix(kWebGpuExecutionProvider); + for (const auto& [key, value] : config_options.GetConfigOptionsMap()) { + std::string normalized_key = key; + if (normalized_key.starts_with(prefix)) { + normalized_key.erase(0, prefix.length()); + } + ORT_THROW_IF_ERROR(normalized_config_options.AddConfigEntry(normalized_key.c_str(), value.c_str())); + } + auto ep_name = dynamic_plugin_ep_infra::GetEpName(); ORT_ENFORCE(ep_name == kWebGpuExecutionProvider, "Dynamic plugin EP is not the WebGPU EP. Expected \"", kWebGpuExecutionProvider, "\", got \"", ep_name.value_or(""), "\""); - return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options); + return dynamic_plugin_ep_infra::MakeEp(nullptr, &normalized_config_options); #else return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider(); #endif