From 9b784809a7e6bdca3c56843265e700f76c20b827 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 9 Apr 2026 12:26:58 -0700 Subject: [PATCH 01/14] CUDA EP ResourceAcountant integration --- .../core/session/onnxruntime_cxx_api.h | 53 +++ .../core/session/onnxruntime_cxx_inline.h | 39 ++ .../core/session/onnxruntime_ep_c_api.h | 172 ++++++++ .../core/providers/cuda/plugin/cuda_ep.cc | 69 +++- onnxruntime/core/session/abi_ep_types.h | 9 + onnxruntime/core/session/plugin_ep/ep_api.cc | 148 +++++++ onnxruntime/core/session/plugin_ep/ep_api.h | 15 + .../ep_plugin_provider_interfaces.cc | 40 +- .../framework/resource_accountant_test.cc | 346 +++++++++------- .../plugin/cuda_resource_partitioning_test.cc | 373 ++++++++++++++++++ 10 files changed, 1121 insertions(+), 143 deletions(-) create mode 100644 onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 9ae0814fb9dc1..3764dc85f683a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3725,5 +3725,58 @@ using UnownedSharedPrePackedWeightCache = ///< Wraps OrtEpApi::GetEnvConfigEntries() Ort::KeyValuePairs GetEnvConfigEntries(); + +/// \brief Non-owning C++ wrapper for resource budget queries on OrtEpGraphSupportInfo. +/// +/// Constructed from the OrtEpGraphSupportInfo* passed to OrtEp::GetCapability. +/// Provides convenient methods for resource-constrained node selection. +/// All costs and budgets use OrtResourceCount, the ABI-stable tagged union. +/// +/// Example use in a plugin EP's GetCapability implementation: +/// \code +/// OrtStatus* GetCapabilityImpl(OrtEp*, const OrtGraph* graph, +/// OrtEpGraphSupportInfo* info) noexcept { +/// Ort::ResourceBudget budget(info); +/// if (budget.HasBudget()) { +/// OrtResourceCount remaining = budget.GetBudget(); +/// OrtResourceCount consumed = budget.GetConsumedResources(); +/// for (const OrtNode* node : candidates) { +/// OrtResourceCount cost = budget.ComputeNodeCost(Ort::ConstNode{node}); +/// if (consumed.AsTotalBytes() + cost.AsTotalBytes() > remaining.AsTotalBytes()) { +/// budget.SignalStopAssignment(); +/// break; +/// } +/// budget.ReportAcceptedNodeCost(Ort::ConstNode{node}, cost); +/// } +/// } +/// } +/// \endcode +struct ResourceBudget { + explicit ResourceBudget(OrtEpGraphSupportInfo* info) : info_(info) {} + + /// Returns true if a resource budget is configured for this EP. + bool HasBudget() const; + + /// Returns the total resource budget. Only valid if HasBudget() is true. + OrtResourceCount GetBudget() const; + + /// Returns the amount of resources already consumed. + OrtResourceCount GetConsumedResources() const; + + /// Computes the estimated resource cost of the given node. + OrtResourceCount ComputeNodeCost(ConstNode node) const; + + /// Reports that the plugin accepted a node at the given cost. + void ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost); + + /// Returns true if stop has been signaled (by this or another EP). + bool IsStopIssued() const; + + /// Signals that this EP wants to stop receiving nodes. + void SignalStopAssignment(); + + private: + OrtEpGraphSupportInfo* info_; +}; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 152f548673729..35ba44effa86f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -4169,4 +4169,43 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema)); return OpSchema{schema}; } + +// ResourceBudget implementation +inline bool ResourceBudget::HasBudget() const { + bool has_budget = false; + ThrowOnError(GetEpApi().EpGraphSupportInfo_HasResourceBudget(info_, &has_budget)); + return has_budget; +} + +inline OrtResourceCount ResourceBudget::GetBudget() const { + OrtResourceCount budget = OrtResourceCount::None(); + ThrowOnError(GetEpApi().EpGraphSupportInfo_GetResourceBudget(info_, &budget)); + return budget; +} + +inline OrtResourceCount ResourceBudget::GetConsumedResources() const { + OrtResourceCount consumed = OrtResourceCount::None(); + ThrowOnError(GetEpApi().EpGraphSupportInfo_GetConsumedResources(info_, &consumed)); + return consumed; +} + +inline OrtResourceCount ResourceBudget::ComputeNodeCost(ConstNode node) const { + OrtResourceCount cost = OrtResourceCount::None(); + ThrowOnError(GetEpApi().EpGraphSupportInfo_ComputeNodeResourceCost(info_, node, &cost)); + return cost; +} + +inline void ResourceBudget::ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost) { + ThrowOnError(GetEpApi().EpGraphSupportInfo_ReportAcceptedNodeCost(info_, node, cost)); +} + +inline bool ResourceBudget::IsStopIssued() const { + bool stop = false; + ThrowOnError(GetEpApi().EpGraphSupportInfo_IsStopIssued(info_, &stop)); + return stop; +} + +inline void ResourceBudget::SignalStopAssignment() { + ThrowOnError(GetEpApi().EpGraphSupportInfo_SignalStopAssignment(info_)); +} } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 07d9ca19e766d..a68d250e1157b 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -946,6 +946,62 @@ struct OrtScanKernelHelper { _In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output); }; +/** + * \brief Discriminator for the resource count type stored in an OrtResourceCount. + * + * New resource accounting types can be added by appending new enum values. + * The OrtResourceCount union storage is large enough to hold all current and future types. + * + * \since Version 1.26. + */ +typedef enum OrtResourceCountKind { + OrtResourceCountKind_None = 0, ///< Unset / zero-cost sentinel. + OrtResourceCountKind_TotalBytes = 1, ///< Single size_t: total estimated bytes. +} OrtResourceCountKind; + +/** + * \brief ABI-stable tagged union representing a resource cost or budget. + * + * This struct is a C-safe variant that can be passed by value across the plugin DLL boundary. + * The `kind` field selects which union member is active. The `_storage` member reserves space + * for future resource types without changing the struct layout. + * + * Adding new resource types requires only: (a) a new OrtResourceCountKind enum value, + * (b) a new union member. No new C API functions are needed. + * + * \since Version 1.26. + */ +typedef struct OrtResourceCount { + OrtResourceCountKind kind; + uint32_t reserved_; ///< Alignment padding + future flags. Must be zero. + + union { + size_t total_bytes; ///< Active when kind == OrtResourceCountKind_TotalBytes. + uint8_t _storage[48]; ///< ABI reserve: all types must fit within 48 bytes. + } value; + +#ifdef __cplusplus + /// Create a zero-cost (None) resource count. + static OrtResourceCount None() { + OrtResourceCount rc{}; + return rc; + } + + /// Create a TotalBytes resource count. + static OrtResourceCount FromTotalBytes(size_t bytes) { + OrtResourceCount rc{}; + rc.kind = OrtResourceCountKind_TotalBytes; + rc.value.total_bytes = bytes; + return rc; + } + + /// Extract total_bytes. Caller must check kind == OrtResourceCountKind_TotalBytes first. + size_t AsTotalBytes() const { + return value.total_bytes; + } +#endif +} OrtResourceCount; + /** * \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider. * @@ -2010,6 +2066,122 @@ struct OrtEpApi { ORT_API2_STATUS(ProfilingEventsContainer_AddEvents, _In_ OrtProfilingEventsContainer* events_container, _In_reads_(num_events) const OrtProfilingEvent* const* events, _In_ size_t num_events); + + /** \brief Query whether resource accounting is active for this GetCapability call. + * + * Returns true if a resource accountant is attached to the given OrtEpGraphSupportInfo instance, + * meaning the session was configured with resource-constrained partitioning settings. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[out] has_budget Output parameter set to true if a resource budget is active. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* has_budget); + + /** \brief Get the total resource budget. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * If the accountant has no explicit threshold (e.g. auto-detection mode), + * the returned OrtResourceCount will have kind == OrtResourceCountKind_TotalBytes with + * value.total_bytes set to SIZE_MAX. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[out] budget Output parameter set to the total resource budget. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* budget); + + /** \brief Get the amount of resources already consumed from prior partitioning passes or previously assigned nodes. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[out] consumed Output parameter set to the consumed resource amount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* consumed); + + /** \brief Compute the estimated resource cost for a node. + * + * Uses pre-recorded memory statistics if available, otherwise estimates from initializer sizes + * and static output shapes with a safety multiplier. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[in] node The OrtNode for which to compute the resource cost. + * \param[out] cost Output parameter set to the estimated resource cost. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Out_ OrtResourceCount* cost); + + /** \brief Report that a node was accepted and its cost should be tracked. + * + * The cost is stored internally so the host can attach it to the IndexedSubGraph after + * GetCapability returns. This does NOT commit the cost to the accountant's consumed amount — + * that happens later during node assignment by the graph partitioner. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[in] node The OrtNode whose cost is being reported. + * \param[in] cost The cost (as returned by EpGraphSupportInfo_ComputeNodeResourceCost). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _In_ OrtResourceCount cost); + + /** \brief Query whether a previous GetCapability pass already signaled stop. + * + * Returns true if EpGraphSupportInfo_SignalStopAssignment was called in a prior pass + * (or by another mechanism). The plugin can use this to early-exit from GetCapability + * without re-evaluating nodes. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * \param[out] is_stopped Output parameter set to true if stop was previously signaled. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* is_stopped); + + /** \brief Signal that the EP wants to stop accepting further nodes due to budget exhaustion. + * + * After this call, the accountant's stop flag is set. Subsequent GetCapability calls for this EP + * will see EpGraphSupportInfo_IsStopIssued() returning true and can return early. + * + * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. + * + * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.26. + */ + ORT_API2_STATUS(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info); }; /** diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 76a676116dc8e..87d9de95bed49 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -10,6 +10,7 @@ #include "ep/get_capability_utils.h" #include +#include #include #include #include @@ -98,6 +99,18 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( auto* ep = static_cast(this_ptr); const OrtEpApi& ep_api = ep->factory_.GetEpApi(); + // Early exit if a previous GetCapability pass already signaled stop. + // This mirrors the in-tree CUDA EP's check at the top of GetCapability(). + Ort::ResourceBudget resource_budget(graph_support_info); + bool has_budget = resource_budget.HasBudget(); + if (has_budget && resource_budget.IsStopIssued()) { + Ort::Status log_status(Ort::GetApi().Logger_LogMessage( + &ep->logger_, ORT_LOGGING_LEVEL_WARNING, + "CUDA Plugin EP returning due to Stop Set", + ORT_FILE, __LINE__, __FUNCTION__)); + return nullptr; + } + Ort::ConstGraph graph{ort_graph}; std::vector all_nodes = graph.GetNodes(); @@ -144,13 +157,59 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( gsl::span(tentative_nodes.data(), tentative_nodes.size()), cpu_preferred_nodes)); - // Phase 3: Add final supported nodes (tentative minus CPU-preferred). + // Phase 3: Add final supported nodes (tentative minus CPU-preferred), + // respecting the optional resource budget. + // resource_budget and has_budget were computed at the top of this function. + size_t budget_bytes = std::numeric_limits::max(); + size_t consumed_bytes = 0; + if (has_budget) { + budget_bytes = resource_budget.GetBudget().AsTotalBytes(); + consumed_bytes = resource_budget.GetConsumedResources().AsTotalBytes(); + } + for (const OrtNode* ort_node : candidate_nodes) { - if (cpu_preferred_nodes.count(ort_node) == 0) { - Ort::ConstNode node{ort_node}; - RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( - graph_support_info, node)); + if (cpu_preferred_nodes.count(ort_node) != 0) { + continue; } + + // Previously assigned nodes (ep_name matched) are already accounted for. + Ort::ConstNode node{ort_node}; + bool previously_assigned = !node.GetEpName().empty(); + + if (has_budget && !previously_assigned) { + OrtResourceCount cost = resource_budget.ComputeNodeCost(node); + size_t cost_bytes = cost.AsTotalBytes(); + size_t would_be_consumed = consumed_bytes + cost_bytes; + + { + // Log per-node cost information (mirrors in-tree CUDA EP logging) + std::string msg = "CUDA Plugin EP Node: " + node.GetName() + + " Memory usage: " + std::to_string(cost_bytes) + + " would be consumed: " + std::to_string(would_be_consumed) + + " threshold: " + std::to_string(budget_bytes); + Ort::Status log_status(Ort::GetApi().Logger_LogMessage( + &ep->logger_, ORT_LOGGING_LEVEL_INFO, + msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + if (would_be_consumed > budget_bytes) { + { + std::string msg = "CUDA Plugin EP Halting assignment due to capacity threshold at node: " + + node.GetName(); + Ort::Status log_status(Ort::GetApi().Logger_LogMessage( + &ep->logger_, ORT_LOGGING_LEVEL_WARNING, + msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + resource_budget.SignalStopAssignment(); + break; // topological-order halt + } + + consumed_bytes = would_be_consumed; + resource_budget.ReportAcceptedNodeCost(node, cost); + } + + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( + graph_support_info, node)); } return nullptr; diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index deaadf7c67e6e..c7ae4704494ff 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -16,6 +16,7 @@ namespace onnxruntime { struct EpGraph; struct EpNode; +class IResourceAccountant; } // namespace onnxruntime /// @@ -50,4 +51,12 @@ struct OrtEpGraphSupportInfo { const onnxruntime::EpGraph& ort_graph; std::vector node_groupings; const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup; + + // Optional resource accountant for capacity-aware partitioning. + // Owned by the graph partitioner; lifetime exceeds this struct. + onnxruntime::IResourceAccountant* resource_accountant = nullptr; + + // Per-node costs reported by the plugin via EpGraphSupportInfo_ReportAcceptedNodeCost. + // Keyed by the OrtNode pointer (same lifetime as the EpGraph). + std::vector> accepted_node_costs; }; diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index d56f4299402b5..a997cf37a2f76 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -33,6 +33,7 @@ #include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" #include "core/session/utils.h" #include "core/common/profiler_common.h" +#include "core/framework/resource_accountant.h" #include "core/session/plugin_ep/ep_event_profiling.h" using namespace onnxruntime; @@ -1198,6 +1199,142 @@ ORT_API_STATUS_IMPL(ProfilingEventsContainer_AddEvents, API_IMPL_END } +// Resource accounting for capacity-aware partitioning + +namespace { +// Convert internal ResourceCount (std::variant) to the C-safe tagged union. +OrtResourceCount ToOrtResourceCount(const onnxruntime::ResourceCount& rc) { + return std::visit([](auto&& val) -> OrtResourceCount { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return OrtResourceCount::FromTotalBytes(val); + } + // Future variant members: add else-if branches here and return OrtResourceCount with appropriate kind. + }, + rc); +} + +// Convert the C-safe tagged union back to internal ResourceCount. +onnxruntime::ResourceCount FromOrtResourceCount(const OrtResourceCount& ort_rc) { + switch (ort_rc.kind) { + case OrtResourceCountKind_TotalBytes: + return onnxruntime::ResourceCount{ort_rc.value.total_bytes}; + default: + ORT_THROW("Unknown OrtResourceCountKind: ", static_cast(ort_rc.kind)); + } +} +} // namespace + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_HasResourceBudget, + _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* has_budget) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(has_budget == nullptr, ORT_INVALID_ARGUMENT, "has_budget must not be NULL"); + *has_budget = (graph_support_info->resource_accountant != nullptr); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetResourceBudget, + _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* budget) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(budget == nullptr, ORT_INVALID_ARGUMENT, "budget must not be NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + auto threshold = accountant->GetThreshold(); + if (threshold) { + *budget = ToOrtResourceCount(*threshold); + } else { + *budget = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); + } + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetConsumedResources, + _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* consumed) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(consumed == nullptr, ORT_INVALID_ARGUMENT, "consumed must not be NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + *consumed = ToOrtResourceCount(accountant->GetConsumedAmount()); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_ComputeNodeResourceCost, + _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Out_ OrtResourceCount* cost) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(node == nullptr, ORT_INVALID_ARGUMENT, "OrtNode is NULL"); + ORT_API_RETURN_IF(cost == nullptr, ORT_INVALID_ARGUMENT, "cost must not be NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + const auto* ep_node = onnxruntime::EpNode::ToInternal(node); + ORT_API_RETURN_IF(ep_node == nullptr, ORT_INVALID_ARGUMENT, "Invalid OrtNode variant"); + const onnxruntime::Node& internal_node = ep_node->GetInternalNode(); + + *cost = ToOrtResourceCount(accountant->ComputeResourceCount(internal_node)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_ReportAcceptedNodeCost, + _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _In_ OrtResourceCount cost) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(node == nullptr, ORT_INVALID_ARGUMENT, "OrtNode is NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + graph_support_info->accepted_node_costs.emplace_back(node, cost); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_IsStopIssued, + _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* is_stopped) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + ORT_API_RETURN_IF(is_stopped == nullptr, ORT_INVALID_ARGUMENT, "is_stopped must not be NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + *is_stopped = accountant->IsStopIssued(); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(EpGraphSupportInfo_SignalStopAssignment, + _In_ OrtEpGraphSupportInfo* graph_support_info) { + API_IMPL_BEGIN + ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, + "OrtEpGraphSupportInfo instance is NULL"); + auto* accountant = graph_support_info->resource_accountant; + ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + + accountant->SetStopAssignment(); + return nullptr; + API_IMPL_END +} + static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -1287,6 +1424,15 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::ProfilingEvent_GetArgValue, &OrtExecutionProviderApi::ProfilingEventsContainer_AddEvents, // End of Version 25 - DO NOT MODIFY ABOVE + + &OrtExecutionProviderApi::EpGraphSupportInfo_HasResourceBudget, + &OrtExecutionProviderApi::EpGraphSupportInfo_GetResourceBudget, + &OrtExecutionProviderApi::EpGraphSupportInfo_GetConsumedResources, + &OrtExecutionProviderApi::EpGraphSupportInfo_ComputeNodeResourceCost, + &OrtExecutionProviderApi::EpGraphSupportInfo_ReportAcceptedNodeCost, + &OrtExecutionProviderApi::EpGraphSupportInfo_IsStopIssued, + &OrtExecutionProviderApi::EpGraphSupportInfo_SignalStopAssignment, + // End of Version 26 - DO NOT MODIFY ABOVE }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned @@ -1298,6 +1444,8 @@ static_assert(offsetof(OrtEpApi, GetEnvConfigEntries) / sizeof(void*) == 49, "Size of version 24 API cannot change"); static_assert(offsetof(OrtEpApi, ProfilingEventsContainer_AddEvents) / sizeof(void*) == 72, "Size of version 25 API cannot change"); +static_assert(offsetof(OrtEpApi, EpGraphSupportInfo_SignalStopAssignment) / sizeof(void*) == 79, + "Size of version 26 API cannot change"); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 4608318f388ee..9e2c3affa333f 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -178,4 +178,19 @@ ORT_API_STATUS_IMPL(ProfilingEvent_GetDurationUs, _In_ const OrtProfilingEvent* _Out_ int64_t* out); ORT_API_STATUS_IMPL(ProfilingEvent_GetArgValue, _In_ const OrtProfilingEvent* event, _In_ const char* key, _Outptr_result_maybenull_ const char** out); + +// Resource accounting for capacity-aware partitioning +ORT_API_STATUS_IMPL(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* has_budget); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* budget); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ OrtResourceCount* consumed); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _Out_ OrtResourceCount* cost); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info, + _In_ const OrtNode* node, _In_ OrtResourceCount cost); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info, + _Out_ bool* is_stopped); +ORT_API_STATUS_IMPL(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 8a082a5392d6c..acdd2849f5285 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -13,6 +13,8 @@ #include "core/framework/error_code_helper.h" #include "core/framework/plugin_data_transfer.h" #include "core/framework/plugin_ep_stream.h" +#include "core/framework/resource_accountant.h" +#include "core/common/inlined_containers.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -225,7 +227,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) const { ORT_UNUSED_PARAMETER(graph_optimizer_registry); // TODO: Add support - ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs const logging::Logger& logger = GetEpLoggerOrDefault(); @@ -236,6 +237,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } OrtEpGraphSupportInfo api_graph_support_info(*ep_graph, kernel_lookup); + api_graph_support_info.resource_accountant = resource_accountant; Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { @@ -249,6 +251,18 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } + // Build a mapping from OrtNode* to accepted cost for resource accounting. + // The plugin reports accepted nodes and their costs via EpGraphSupportInfo_ReportAcceptedNodeCost. + // Costs are OrtResourceCount tagged unions that are converted back to internal ResourceCount + // (std::variant) when attaching to IndexedSubGraph. + InlinedHashMap node_cost_map; + if (resource_accountant != nullptr && !api_graph_support_info.accepted_node_costs.empty()) { + node_cost_map.reserve(api_graph_support_info.accepted_node_costs.size()); + for (const auto& [ort_node, cost] : api_graph_support_info.accepted_node_costs) { + node_cost_map[ort_node] = cost; + } + } + // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { // Skip this node grouping if any node has already been assigned to another EP. @@ -273,7 +287,29 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie auto indexed_sub_graph = std::make_unique(); - indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); + const NodeIndex node_index = node_grouping.nodes[0]->GetInternalNode().Index(); + indexed_sub_graph->nodes.push_back(node_index); + + // Attach resource accounting if the plugin reported a cost for this node. + if (resource_accountant != nullptr) { + const OrtNode* ort_node_key = static_cast(node_grouping.nodes[0]); + auto cost_it = node_cost_map.find(ort_node_key); + if (cost_it != node_cost_map.end()) { + indexed_sub_graph->SetAccountant(resource_accountant); + // Convert OrtResourceCount tagged union back to internal ResourceCount (std::variant). + const OrtResourceCount& ort_cost = cost_it->second; + switch (ort_cost.kind) { + case OrtResourceCountKind_TotalBytes: + indexed_sub_graph->AppendNodeCost(ResourceCount{ort_cost.value.total_bytes}); + break; + default: + LOGS(logger, WARNING) << "Unknown OrtResourceCountKind: " + << static_cast(ort_cost.kind) << "; skipping cost."; + break; + } + } + } + result.push_back(std::make_unique(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { if (node_grouping.nodes.empty()) { diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index a102fe4e7770b..b07b42d10ce1c 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -2,103 +2,47 @@ // Licensed under the MIT License. #include "core/framework/resource_accountant.h" +#include "core/framework/config_options.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/constants.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" #include "test/util/include/asserts.h" #include "test/util/include/test_environment.h" +#include +#include + namespace onnxruntime { namespace test { -// Test accountant mimicking SizeBasedStatsAccountant ad-hoc path: -// Uses pending/committed weight sets so that: -// - Within a GetCapability pass, shared weights are deduped -// - Across passes, only committed weights persist and pending are discarded -class TestDedupAccountant : public IResourceAccountant { - public: - TestDedupAccountant() = default; - - ResourceCount GetConsumedAmount() const override { - return consumed_; - } - - void AddConsumedAmount(const ResourceCount& amount) noexcept override { - if (std::holds_alternative(amount)) { - consumed_ += std::get(amount); - } - } +namespace { - void RemoveConsumedAmount(const ResourceCount& amount) noexcept override { - if (std::holds_alternative(amount)) { - consumed_ -= std::get(amount); - } - } - - ResourceCount ComputeResourceCount(const Node& node) override { - const auto* graph = node.GetContainingGraph(); - if (graph == nullptr) { - return static_cast(0); - } - - size_t total = 0; - for (const auto* input_def : node.InputDefs()) { - if (!input_def->Exists()) { - continue; - } - const auto& name = input_def->Name(); - constexpr bool check_outer_scope = true; - const auto* init = graph->GetInitializer(name, check_outer_scope); - if (init != nullptr) { - if (committed_weights_.count(name) > 0) { - continue; - } - if (pending_weights_.count(name) > 0) { - continue; - } - auto it = weight_sizes_.find(name); - if (it != weight_sizes_.end()) { - total += it->second; - } - pending_weights_.insert(name); - pending_weights_by_node_[node.Index()].insert(name); - } - } - return total; - } - - void ResetPendingWeights() override { - pending_weights_.clear(); - pending_weights_by_node_.clear(); - } - - void CommitWeightsForNode(NodeIndex node_index) override { - auto it = pending_weights_by_node_.find(node_index); - if (it != pending_weights_by_node_.end()) { - for (const auto& name : it->second) { - pending_weights_.erase(name); - } - committed_weights_.insert(it->second.begin(), it->second.end()); - pending_weights_by_node_.erase(it); - } - } - - void RegisterWeight(const std::string& name, size_t size) { - weight_sizes_[name] = size; - } +// Helper to extract size_t from ResourceCount variant. +size_t GetSizeT(const ResourceCount& rc) { + return std::get(rc); +} - size_t GetConsumedSizeT() const { return consumed_; } +// Helper to create a real SizeBasedStatsAccountant in ad-hoc mode (no stats file) via factory. +IResourceAccountant* CreateAdHocAccountant( + size_t limit_kb, + const std::filesystem::path& model_path, + std::optional& acc_map) { + ConfigOptions config; + std::string setting = std::to_string(limit_kb) + ","; + ORT_THROW_IF_ERROR(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + ORT_THROW_IF_ERROR(CreateAccountants(config, model_path, acc_map)); + ORT_ENFORCE(acc_map.has_value()); + auto it = acc_map->find(kCudaExecutionProvider); + ORT_ENFORCE(it != acc_map->end()); + return it->second.get(); +} - private: - size_t consumed_ = 0; - InlinedHashSet committed_weights_; - InlinedHashSet pending_weights_; - InlinedHashMap> pending_weights_by_node_; - InlinedHashMap weight_sizes_; -}; +} // namespace // Two Add nodes that share a single initializer weight_W. struct SharedWeightGraph { @@ -147,31 +91,36 @@ struct SharedWeightGraph { } }; -// Regression: AccountForAllNodes sums pre-stored per-node costs +// Ad-hoc path expected costs for SharedWeightGraph: +// weight_W = 250 floats = 1000 bytes, each output = 250 floats = 1000 bytes +// node_A: (1000 init + 1000 out) * 1.5 = 3000 +// node_B: (0 deduped + 1000 out) * 1.5 = 1500 + +// AccountForAllNodes sums pre-stored per-node costs // that already have correct within-pass weight deduplication. TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + std::optional acc_map; + auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(h.node_a->Index()); sub_graph.nodes.push_back(h.node_b->Index()); - sub_graph.SetAccountant(&accountant); + sub_graph.SetAccountant(accountant); - auto cost_a = accountant.ComputeResourceCount(*h.node_a); + auto cost_a = accountant->ComputeResourceCount(*h.node_a); sub_graph.AppendNodeCost(cost_a); - EXPECT_EQ(std::get(cost_a), size_t{1000}); + EXPECT_EQ(GetSizeT(cost_a), size_t{3000}); - auto cost_b = accountant.ComputeResourceCount(*h.node_b); + auto cost_b = accountant->ComputeResourceCount(*h.node_b); sub_graph.AppendNodeCost(cost_b); - EXPECT_EQ(std::get(cost_b), size_t{0}); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}); ASSERT_TRUE(sub_graph.IsAccountingEnabled()); sub_graph.AccountForAllNodes(); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "AccountForAllNodes should sum pre-stored costs (1000 + 0)"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{4500}) + << "AccountForAllNodes should sum pre-stored costs (3000 + 1500)"; } // Verifies that ResetPendingWeights + re-probe produces correct results. @@ -179,30 +128,33 @@ TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { // re-probing should see the full weight cost again since nothing was committed. TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + std::optional acc_map; + auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); // Probing pass populates pending weights - auto cost_a = accountant.ComputeResourceCount(*h.node_a); - EXPECT_EQ(std::get(cost_a), size_t{1000}); - auto cost_b = accountant.ComputeResourceCount(*h.node_b); - EXPECT_EQ(std::get(cost_b), size_t{0}); + auto cost_a = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(GetSizeT(cost_a), size_t{3000}); + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}); // Discard the pass (simulating capabilities.clear() before second GetCapability) - accountant.ResetPendingWeights(); + accountant->ResetPendingWeights(); // Re-probe: weight_W was never committed, so it should be counted again IndexedSubGraph sub_graph; sub_graph.nodes.push_back(h.node_a->Index()); - sub_graph.SetAccountant(&accountant); - auto recomputed_cost = accountant.ComputeResourceCount(*h.node_a); + sub_graph.SetAccountant(accountant); + auto recomputed_cost = accountant->ComputeResourceCount(*h.node_a); sub_graph.AccountForNode(h.node_a->Index(), recomputed_cost); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}) << "After ResetPendingWeights, re-probe should see full weight cost"; } // Each node has a unique initializer. AccountForAllNodes sums both. +// weight_1 = 100 floats = 400 bytes, weight_2 = 100 floats = 400 bytes, outputs = 400 bytes each +// node1: (400 init + 400 out) * 1.5 = 1200 +// node2: (400 init + 400 out) * 1.5 = 1200 TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { std::unordered_map dom; dom[kOnnxDomain] = 12; @@ -239,23 +191,22 @@ TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { auto& node2 = graph.AddNode("n2", "Add", "", {out1, w2}, {out2}); ASSERT_STATUS_OK(graph.Resolve()); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_1", 400); - accountant.RegisterWeight("weight_2", 600); + std::optional acc_map; + auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(node1.Index()); sub_graph.nodes.push_back(node2.Index()); - sub_graph.SetAccountant(&accountant); + sub_graph.SetAccountant(accountant); - sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node1)); - sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node2)); + sub_graph.AppendNodeCost(accountant->ComputeResourceCount(node1)); + sub_graph.AppendNodeCost(accountant->ComputeResourceCount(node2)); ASSERT_TRUE(sub_graph.IsAccountingEnabled()); sub_graph.AccountForAllNodes(); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "No shared weights: should sum all costs (400 + 600)"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{2400}) + << "No shared weights: should sum all costs (1200 + 1200)"; } // AccountForNode per-node and AccountForAllNodes bulk produce same result. @@ -263,64 +214,187 @@ TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { auto h = SharedWeightGraph::Create(); // Per-node path - TestDedupAccountant acc1; - acc1.RegisterWeight("weight_W", 1000); + std::optional acc_map1; + auto* acc1 = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map1); IndexedSubGraph sub1; sub1.nodes.push_back(h.node_a->Index()); sub1.nodes.push_back(h.node_b->Index()); - sub1.SetAccountant(&acc1); - sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_a)); - sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_b)); + sub1.SetAccountant(acc1); + sub1.AppendNodeCost(acc1->ComputeResourceCount(*h.node_a)); + sub1.AppendNodeCost(acc1->ComputeResourceCount(*h.node_b)); sub1.AccountForNode(0); sub1.AccountForNode(1); - size_t per_node = acc1.GetConsumedSizeT(); + size_t per_node = GetSizeT(acc1->GetConsumedAmount()); // Bulk path - TestDedupAccountant acc2; - acc2.RegisterWeight("weight_W", 1000); + std::optional acc_map2; + auto* acc2 = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map2); IndexedSubGraph sub2; sub2.nodes.push_back(h.node_a->Index()); sub2.nodes.push_back(h.node_b->Index()); - sub2.SetAccountant(&acc2); - sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_a)); - sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_b)); + sub2.SetAccountant(acc2); + sub2.AppendNodeCost(acc2->ComputeResourceCount(*h.node_a)); + sub2.AppendNodeCost(acc2->ComputeResourceCount(*h.node_b)); sub2.AccountForAllNodes(); - size_t bulk = acc2.GetConsumedSizeT(); + size_t bulk = GetSizeT(acc2->GetConsumedAmount()); EXPECT_EQ(per_node, bulk) << "Per-node and bulk should produce identical results"; - EXPECT_EQ(per_node, size_t{1000}); + EXPECT_EQ(per_node, size_t{4500}); } // Cross-subgraph dedup: EP1 commits node_A, EP2 probes node_B and // correctly sees weight_W as already accounted. +// node_A cost: 3000, node_B cost after commit: (0 + 1000) * 1.5 = 1500 TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + std::optional acc_map; + auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); // EP1 probes and commits node_A IndexedSubGraph sub1; sub1.nodes.push_back(h.node_a->Index()); - sub1.SetAccountant(&accountant); - sub1.AppendNodeCost(accountant.ComputeResourceCount(*h.node_a)); + sub1.SetAccountant(accountant); + sub1.AppendNodeCost(accountant->ComputeResourceCount(*h.node_a)); sub1.AccountForNode(0); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}); + accountant->CommitWeightsForNode(h.node_a->Index()); + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}); - // EP2 probes node_B: weight_W already committed - auto cost_b = accountant.ComputeResourceCount(*h.node_b); - EXPECT_EQ(std::get(cost_b), size_t{0}) - << "weight_W was committed by EP1, should be deduped for EP2"; + // Reset pending to simulate new GetCapability pass + accountant->ResetPendingWeights(); - // EP2 commits node_B with cost 0 + // EP2 probes node_B: weight_W already committed, only output counted + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}) + << "weight_W was committed by EP1, only output (1000 * 1.5) counted"; + + // EP2 commits node_B IndexedSubGraph sub2; sub2.nodes.push_back(h.node_b->Index()); - sub2.SetAccountant(&accountant); + sub2.SetAccountant(accountant); sub2.AppendNodeCost(cost_b); sub2.AccountForNode(0); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "Total should still be 1000 - weight_W counted once across both"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{4500}) + << "Total should be 3000 + 1500 - weight_W initializer counted once"; +} + +// --------------------------------------------------------------------------- +// Stats-based path and factory tests +// --------------------------------------------------------------------------- + +// Stats-based path: cost is sum of all NodeAllocationStats fields. +TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { + auto h = SharedWeightGraph::Create(); + + // Write a stats file with known costs + std::error_code ec; + auto stats_dir = std::filesystem::temp_directory_path(ec); + ASSERT_FALSE(ec) << ec.message(); + auto stats_path = stats_dir / "test_resource_accountant_stats.csv"; + + // Get the unique node names the accountant will look up + std::string name_a = IResourceAccountant::MakeUniqueNodeName(*h.node_a); + std::string name_b = IResourceAccountant::MakeUniqueNodeName(*h.node_b); + + { + std::ofstream ofs(stats_path); + ASSERT_TRUE(ofs.is_open()); + ofs << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n"; + // input_sizes=100, initializers=200, dynamic=300, temp=400 -> total=1000 + ofs << name_a << ",100,200,300,400\n"; + // input_sizes=50, initializers=0, dynamic=150, temp=0 -> total=200 + ofs << name_b << ",50,0,150,0\n"; + } + + // Factory expects stats file relative to model_path dir + ConfigOptions config; + std::string setting = "500," + stats_path.filename().string(); + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, stats_dir / "dummy_model.onnx", acc_map)); + ASSERT_TRUE(acc_map.has_value()); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + + auto cost_a = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(std::get(cost_a), size_t{1000}); + + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(std::get(cost_b), size_t{200}); + + // Threshold should be 500 KB = 512000 bytes + auto threshold = accountant->GetThreshold(); + ASSERT_TRUE(threshold.has_value()); + EXPECT_EQ(std::get(*threshold), size_t{500 * 1024}); + + std::error_code remove_ec; + std::filesystem::remove(stats_path, remove_ec); +} + +// Stats-based path returns 0 for unknown nodes. +TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { + auto h = SharedWeightGraph::Create(); + + std::error_code ec; + auto stats_dir = std::filesystem::temp_directory_path(ec); + ASSERT_FALSE(ec) << ec.message(); + auto stats_path = stats_dir / "test_resource_accountant_empty_stats.csv"; + + { + std::ofstream ofs(stats_path); + ASSERT_TRUE(ofs.is_open()); + ofs << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n"; + // No entries for our nodes + } + + ConfigOptions config; + std::string setting = "1000," + stats_path.filename().string(); + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, stats_dir / "dummy_model.onnx", acc_map)); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + + auto cost = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(std::get(cost), size_t{0}); + + std::error_code remove_ec; + std::filesystem::remove(stats_path, remove_ec); +} + +// Factory with no limit and no stats file creates accountant with no threshold. +TEST(RealAccountantTest, Factory_NoLimitNoStats) { + ConfigOptions config; + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, ",")); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, PathString(), acc_map)); + ASSERT_TRUE(acc_map.has_value()); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + EXPECT_FALSE(accountant->GetThreshold().has_value()); +} + +// Factory returns empty optional when no config is set. +TEST(RealAccountantTest, Factory_NoConfigReturnsEmpty) { + ConfigOptions config; + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, PathString(), acc_map)); + EXPECT_FALSE(acc_map.has_value()); +} + +// Factory rejects malformed config (missing comma). +TEST(RealAccountantTest, Factory_MalformedConfigReturnsError) { + ConfigOptions config; + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, "1000")); // missing comma + + std::optional acc_map; + auto status = CreateAccountants(config, PathString(), acc_map); + EXPECT_FALSE(status.IsOK()); } } // namespace test diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc new file mode 100644 index 0000000000000..3eb7c50f9c8b2 --- /dev/null +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -0,0 +1,373 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Integration tests for resource-constrained partitioning through the CUDA plugin EP. +// +// Two test levels: +// 1. OrtResourceCount struct tests — validate the C-safe tagged union. +// 2. Partitioning verification tests — use InferenceSessionWrapper to inspect +// per-node EP assignments after partitioning through the plugin EP. + +#if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/graph/model.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/inference_session.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/file_util.h" +#include "test/util/include/inference_session_wrapper.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { +namespace { + +constexpr const char* kResourcePartitioningRegistrationName = "CudaPluginResourceTest"; + +// Resolve the CUDA plugin EP shared library path. +std::filesystem::path GetCudaPluginLibraryPath() { + return GetSharedLibraryFileName(ORT_TSTR("onnxruntime_providers_cuda_plugin")); +} + +// RAII handle that registers/unregisters the CUDA plugin EP library. +class ScopedCudaPluginRegistration { + public: + ScopedCudaPluginRegistration(Ort::Env& env, const char* registration_name) + : env_(env), name_(registration_name) { + auto lib_path = GetCudaPluginLibraryPath(); + if (!std::filesystem::exists(lib_path)) { + available_ = false; + return; + } + env_.RegisterExecutionProviderLibrary(name_.c_str(), lib_path.c_str()); + available_ = true; + } + + ~ScopedCudaPluginRegistration() { + if (available_) { + try { + env_.UnregisterExecutionProviderLibrary(name_.c_str()); + } catch (...) { + } + } + } + + bool IsAvailable() const { return available_; } + + ScopedCudaPluginRegistration(const ScopedCudaPluginRegistration&) = delete; + ScopedCudaPluginRegistration& operator=(const ScopedCudaPluginRegistration&) = delete; + + private: + Ort::Env& env_; + std::string name_; + bool available_ = false; +}; + +// Find the CUDA plugin EP device after registration. +Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { + auto ep_devices = env.GetEpDevices(); + for (const auto& device : ep_devices) { + if (strcmp(device.EpName(), "CudaPluginExecutionProvider") == 0) { + return device; + } + } + return Ort::ConstEpDevice{nullptr}; +} + +// Get the internal OrtEnv* from the C++ Ort::Env wrapper. +// Ort::Env inherits Base which has operator OrtEnv*(). +OrtEnv& GetOrtEnv() { + OrtEnv* p = static_cast(*ort_env); + ORT_ENFORCE(p != nullptr); + return *p; +} + +} // namespace + +// --------------------------------------------------------------------------- +// OrtResourceCount struct tests +// --------------------------------------------------------------------------- + +TEST(OrtResourceCountTest, None_HasKindNone) { + OrtResourceCount rc = OrtResourceCount::None(); + EXPECT_EQ(rc.kind, OrtResourceCountKind_None); +} + +TEST(OrtResourceCountTest, FromTotalBytes_RoundTrips) { + constexpr size_t kTestValue = 42 * 1024 * 1024; // 42 MB + OrtResourceCount rc = OrtResourceCount::FromTotalBytes(kTestValue); + EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); + EXPECT_EQ(rc.AsTotalBytes(), kTestValue); +} + +TEST(OrtResourceCountTest, FromTotalBytes_MaxValue) { + OrtResourceCount rc = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); + EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); + EXPECT_EQ(rc.AsTotalBytes(), std::numeric_limits::max()); +} + +TEST(OrtResourceCountTest, FromTotalBytes_Zero) { + OrtResourceCount rc = OrtResourceCount::FromTotalBytes(0); + EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); + EXPECT_EQ(rc.AsTotalBytes(), size_t{0}); +} + +TEST(OrtResourceCountTest, CopySemantics) { + OrtResourceCount original = OrtResourceCount::FromTotalBytes(12345); + OrtResourceCount copy = original; + EXPECT_EQ(copy.kind, OrtResourceCountKind_TotalBytes); + EXPECT_EQ(copy.AsTotalBytes(), size_t{12345}); + copy.value.total_bytes = 99999; + EXPECT_EQ(original.AsTotalBytes(), size_t{12345}); +} + +TEST(OrtResourceCountTest, ReservedFieldIsZero) { + OrtResourceCount rc = OrtResourceCount::FromTotalBytes(100); + EXPECT_EQ(rc.reserved_, uint32_t{0}); +} + +// --------------------------------------------------------------------------- +// Lower-level partitioning tests that verify per-node EP assignments +// --------------------------------------------------------------------------- + +class CudaPluginPartitioningTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kResourcePartitioningRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + cudaDeviceSynchronize(); + } + + // Load a model through the CUDA plugin EP with the given resource budget, + // then call the verifier to inspect graph node assignments. + // + // Uses InferenceSessionWrapper + OrtSessionOptions + InitializeSession + // so that the plugin factory creates the EP and partitioning runs normally, + // but we can access the graph via wrapper.GetGraph(). + void LoadAndVerifyPartitioning(const ORTCHAR_T* model_path, + size_t budget_kb, + const std::function& verifier) { + OrtSessionOptions ort_options; + + // Create the plugin EP factory from the registered device. + const OrtEpDevice* device_ptr = static_cast(cuda_device_); + auto ep_devices_span = gsl::make_span(&device_ptr, 1); + + std::unique_ptr factory; + ASSERT_STATUS_OK(CreateIExecutionProviderFactoryForEpDevices( + GetOrtEnv().GetEnvironment(), ep_devices_span, factory)); + + ort_options.provider_factories.push_back(std::move(factory)); + + // Set resource partitioning budget if requested. + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + ASSERT_STATUS_OK(ort_options.value.config_options.AddConfigEntry( + "session.resource_cuda_partitioning_settings", config_value.c_str())); + } + + // Create the session wrapper — gives us access to the graph after partitioning. + InferenceSessionWrapper session(ort_options.value, GetOrtEnv().GetEnvironment()); + ASSERT_STATUS_OK(session.Load(model_path)); + + // InitializeSession iterates provider_factories, creates the plugin EP, + // registers it with the session, and calls session.Initialize() which + // runs graph partitioning (invoking plugin GetCapability). + OrtStatus* status = InitializeSession(&ort_options, session); + ASSERT_EQ(status, nullptr) << "InitializeSession failed"; + + verifier(session.GetGraph()); + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +// With no resource budget, all CUDA-supported nodes should be assigned to the plugin EP. +TEST_F(CudaPluginPartitioningTest, NoBudget_AllNodesCudaPlugin) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/mul_1.onnx"); + + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/0, [](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + // With no budget constraint, all nodes that the CUDA plugin supports + // should be assigned to it. The plugin EP type name may vary, so just + // verify it's NOT assigned to CPU. + EXPECT_NE(node.GetExecutionProviderType(), kCpuExecutionProvider) + << "Node " << node.Name() << " (" << node.OpType() + << ") unexpectedly assigned to CPU with no budget constraint"; + } + }); +} + +// With a very large budget, all nodes should still be on the plugin EP (same as no budget). +TEST_F(CudaPluginPartitioningTest, LargeBudget_AllNodesCudaPlugin) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/mul_1.onnx"); + + // 1 TB — effectively unlimited + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1024 * 1024 * 1024, [](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + EXPECT_NE(node.GetExecutionProviderType(), kCpuExecutionProvider) + << "Node " << node.Name() << " (" << node.OpType() + << ") unexpectedly assigned to CPU with large budget"; + } + }); +} + +// With a tiny budget (1 byte), nodes should be offloaded to CPU because +// the resource accountant will run out of budget. +TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { + // Use a model with multiple nodes so we can see some go to CPU. + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); + + // 1 byte budget — ad-hoc accountant will compute non-zero cost for any + // node with initializers or known output shapes, so nodes must be offloaded. + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1, [](const Graph& graph) { + bool has_cpu_node = false; + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == kCpuExecutionProvider) { + has_cpu_node = true; + break; + } + } + EXPECT_TRUE(has_cpu_node) + << "With a 1 KB budget, at least some nodes should be offloaded to CPU"; + }); +} + +// --------------------------------------------------------------------------- +// E2E tests (existing high-level session tests, kept for coverage) +// --------------------------------------------------------------------------- + +class CudaResourcePartitioningTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kResourcePartitioningRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + cudaDeviceSynchronize(); + } + + Ort::Session CreateSessionWithBudget(const ORTCHAR_T* model_path, + size_t budget_kb) { + Ort::SessionOptions so; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, {}); + + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + so.AddConfigEntry("session.resource_cuda_partitioning_settings", + config_value.c_str()); + } + + return Ort::Session(*ort_env, model_path, so); + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +TEST_F(CudaResourcePartitioningTest, NoBudget_SessionCreatesSuccessfully) { + auto model_path = ORT_TSTR("testdata/mul_1.onnx"); + ASSERT_NO_THROW(CreateSessionWithBudget(model_path, 0)); +} + +TEST_F(CudaResourcePartitioningTest, BudgetConstrained_ProducesValidOutput) { + auto model_path = ORT_TSTR("testdata/mul_1.onnx"); + Ort::Session session = CreateSessionWithBudget(model_path, 100); + + auto input_name = session.GetInputNameAllocated(0, Ort::AllocatorWithDefaultOptions()); + auto output_name = session.GetOutputNameAllocated(0, Ort::AllocatorWithDefaultOptions()); + + auto type_info = session.GetInputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + for (auto& dim : shape) { + if (dim < 0) dim = 1; + } + + size_t num_elements = 1; + for (auto dim : shape) { + num_elements *= static_cast(dim); + } + + std::vector input_data(num_elements, 2.0f); + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::Value input_tensor = Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), + shape.data(), shape.size()); + + const char* input_names[] = {input_name.get()}; + const char* output_names[] = {output_name.get()}; + + auto outputs = session.Run(Ort::RunOptions{nullptr}, + input_names, &input_tensor, 1, + output_names, 1); + + ASSERT_EQ(outputs.size(), size_t{1}); + ASSERT_TRUE(outputs[0].IsTensor()); + + auto* output_data = outputs[0].GetTensorData(); + auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape(); + size_t output_count = 1; + for (auto dim : output_shape) { + output_count *= static_cast(dim); + } + for (size_t i = 0; i < output_count; ++i) { + EXPECT_FALSE(std::isnan(output_data[i])) << "NaN at index " << i; + EXPECT_FALSE(std::isinf(output_data[i])) << "Inf at index " << i; + } +} + +} // namespace test +} // namespace onnxruntime + +#endif // defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) From 6bb392a13743a3c030faca749b2db48d358aa4be Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 9 Apr 2026 14:34:57 -0700 Subject: [PATCH 02/14] Address review comments --- onnxruntime/core/providers/cuda/plugin/cuda_ep.cc | 4 +++- .../cuda/plugin/cuda_resource_partitioning_test.cc | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index 87d9de95bed49..8a40332199931 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -15,6 +15,8 @@ #include #include +#include "core/common/safeint.h" + namespace onnxruntime { namespace cuda_plugin { @@ -179,7 +181,7 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( if (has_budget && !previously_assigned) { OrtResourceCount cost = resource_budget.ComputeNodeCost(node); size_t cost_bytes = cost.AsTotalBytes(); - size_t would_be_consumed = consumed_bytes + cost_bytes; + size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; { // Log per-node cost information (mirrors in-tree CUDA EP logging) diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index 3eb7c50f9c8b2..bb5547f7d7c3d 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -24,6 +24,7 @@ #include "core/graph/model.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" +#include "core/framework/error_code_helper.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/ort_env.h" #include "core/session/utils.h" @@ -207,7 +208,7 @@ class CudaPluginPartitioningTest : public ::testing::Test { // registers it with the session, and calls session.Initialize() which // runs graph partitioning (invoking plugin GetCapability). OrtStatus* status = InitializeSession(&ort_options, session); - ASSERT_EQ(status, nullptr) << "InitializeSession failed"; + ASSERT_STATUS_OK(ToStatusAndRelease(status)); verifier(session.GetGraph()); } @@ -236,8 +237,8 @@ TEST_F(CudaPluginPartitioningTest, NoBudget_AllNodesCudaPlugin) { TEST_F(CudaPluginPartitioningTest, LargeBudget_AllNodesCudaPlugin) { constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/mul_1.onnx"); - // 1 TB — effectively unlimited - LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1024 * 1024 * 1024, [](const Graph& graph) { + // 1 GB — effectively unlimited and safe across 32-bit/64-bit builds + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1024 * 1024, [](const Graph& graph) { for (const auto& node : graph.Nodes()) { EXPECT_NE(node.GetExecutionProviderType(), kCpuExecutionProvider) << "Node " << node.Name() << " (" << node.OpType() @@ -252,7 +253,7 @@ TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { // Use a model with multiple nodes so we can see some go to CPU. constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); - // 1 byte budget — ad-hoc accountant will compute non-zero cost for any + // 1 KB budget — ad-hoc accountant will compute non-zero cost for any // node with initializers or known output shapes, so nodes must be offloaded. LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1, [](const Graph& graph) { bool has_cpu_node = false; From 7f387a0e75529c42b828f3d7a57dae5686169e72 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 9 Apr 2026 16:16:18 -0700 Subject: [PATCH 03/14] Address review and build issues --- onnxruntime/core/session/plugin_ep/ep_api.cc | 10 +--------- .../cuda/plugin/cuda_resource_partitioning_test.cc | 5 ++--- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index a997cf37a2f76..689b651554dd6 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -1214,15 +1215,6 @@ OrtResourceCount ToOrtResourceCount(const onnxruntime::ResourceCount& rc) { rc); } -// Convert the C-safe tagged union back to internal ResourceCount. -onnxruntime::ResourceCount FromOrtResourceCount(const OrtResourceCount& ort_rc) { - switch (ort_rc.kind) { - case OrtResourceCountKind_TotalBytes: - return onnxruntime::ResourceCount{ort_rc.value.total_bytes}; - default: - ORT_THROW("Unknown OrtResourceCountKind: ", static_cast(ort_rc.kind)); - } -} } // namespace ORT_API_STATUS_IMPL(EpGraphSupportInfo_HasResourceBudget, diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index bb5547f7d7c3d..718ad027cd10a 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -11,6 +11,7 @@ #if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) #include +#include #include #include #include @@ -93,9 +94,7 @@ Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { // Get the internal OrtEnv* from the C++ Ort::Env wrapper. // Ort::Env inherits Base which has operator OrtEnv*(). OrtEnv& GetOrtEnv() { - OrtEnv* p = static_cast(*ort_env); - ORT_ENFORCE(p != nullptr); - return *p; + return *static_cast(*ort_env); } } // namespace From fc399f05b36cdb78a65d6415f617d8e16cca725a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 9 Apr 2026 16:41:33 -0700 Subject: [PATCH 04/14] Address review issues --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- onnxruntime/core/session/plugin_ep/ep_api.cc | 3 +++ .../providers/cuda/plugin/cuda_resource_partitioning_test.cc | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 3764dc85f683a..ea013efc84c76 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3742,7 +3742,7 @@ Ort::KeyValuePairs GetEnvConfigEntries(); /// OrtResourceCount consumed = budget.GetConsumedResources(); /// for (const OrtNode* node : candidates) { /// OrtResourceCount cost = budget.ComputeNodeCost(Ort::ConstNode{node}); -/// if (consumed.AsTotalBytes() + cost.AsTotalBytes() > remaining.AsTotalBytes()) { +/// if (cost.AsTotalBytes() > remaining.AsTotalBytes() - consumed.AsTotalBytes()) { /// budget.SignalStopAssignment(); /// break; /// } diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 689b651554dd6..e4bc95d844c23 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1293,6 +1293,9 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_ReportAcceptedNodeCost, ORT_API_RETURN_IF(node == nullptr, ORT_INVALID_ARGUMENT, "OrtNode is NULL"); auto* accountant = graph_support_info->resource_accountant; ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); + ORT_API_RETURN_IF(cost.reserved_ != 0, ORT_INVALID_ARGUMENT, "OrtResourceCount reserved_ field must be zero"); + ORT_API_RETURN_IF(cost.kind != OrtResourceCountKind_None && cost.kind != OrtResourceCountKind_TotalBytes, + ORT_INVALID_ARGUMENT, "Unsupported OrtResourceCountKind value"); graph_support_info->accepted_node_costs.emplace_back(node, cost); return nullptr; diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index 718ad027cd10a..e539ed132083f 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -246,7 +246,7 @@ TEST_F(CudaPluginPartitioningTest, LargeBudget_AllNodesCudaPlugin) { }); } -// With a tiny budget (1 byte), nodes should be offloaded to CPU because +// With a tiny budget (1 KB), nodes should be offloaded to CPU because // the resource accountant will run out of budget. TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { // Use a model with multiple nodes so we can see some go to CPU. From a96cbd2f3a68b49a519b311784a4bd22f120fc32 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 9 Apr 2026 18:04:04 -0700 Subject: [PATCH 05/14] Address review comments --- .../ep_plugin_provider_interfaces.cc | 2 ++ .../framework/resource_accountant_test.cc | 19 ++++++++++++++++--- .../plugin/cuda_resource_partitioning_test.cc | 5 +++-- 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index acdd2849f5285..6c2247168cc95 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -302,6 +302,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie case OrtResourceCountKind_TotalBytes: indexed_sub_graph->AppendNodeCost(ResourceCount{ort_cost.value.total_bytes}); break; + case OrtResourceCountKind_None: + [[fallthrough]]; default: LOGS(logger, WARNING) << "Unknown OrtResourceCountKind: " << static_cast(ort_cost.kind) << "; skipping cost."; diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index b07b42d10ce1c..07ba198a9c323 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -15,6 +15,15 @@ #include #include +#include + +#ifdef _WIN32 +#include +#define ORT_TEST_PID _getpid() +#else +#include +#define ORT_TEST_PID getpid() +#endif namespace onnxruntime { namespace test { @@ -287,11 +296,13 @@ TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { auto h = SharedWeightGraph::Create(); - // Write a stats file with known costs + // Write a stats file with known costs (unique per PID to avoid parallel collisions) std::error_code ec; auto stats_dir = std::filesystem::temp_directory_path(ec); ASSERT_FALSE(ec) << ec.message(); - auto stats_path = stats_dir / "test_resource_accountant_stats.csv"; + std::ostringstream fname; + fname << "test_resource_accountant_stats_" << ORT_TEST_PID << ".csv"; + auto stats_path = stats_dir / fname.str(); // Get the unique node names the accountant will look up std::string name_a = IResourceAccountant::MakeUniqueNodeName(*h.node_a); @@ -340,7 +351,9 @@ TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { std::error_code ec; auto stats_dir = std::filesystem::temp_directory_path(ec); ASSERT_FALSE(ec) << ec.message(); - auto stats_path = stats_dir / "test_resource_accountant_empty_stats.csv"; + std::ostringstream fname; + fname << "test_resource_accountant_empty_stats_" << ORT_TEST_PID << ".csv"; + auto stats_path = stats_dir / fname.str(); { std::ofstream ofs(stats_path); diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index e539ed132083f..5eede700a69c6 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -27,6 +27,7 @@ #include "core/session/inference_session.h" #include "core/framework/error_code_helper.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_env.h" #include "core/session/utils.h" #include "test/util/include/asserts.h" @@ -196,7 +197,7 @@ class CudaPluginPartitioningTest : public ::testing::Test { if (budget_kb > 0) { std::string config_value = std::to_string(budget_kb) + ","; ASSERT_STATUS_OK(ort_options.value.config_options.AddConfigEntry( - "session.resource_cuda_partitioning_settings", config_value.c_str())); + kOrtSessionOptionsResourceCudaPartitioningSettings, config_value.c_str())); } // Create the session wrapper — gives us access to the graph after partitioning. @@ -304,7 +305,7 @@ class CudaResourcePartitioningTest : public ::testing::Test { if (budget_kb > 0) { std::string config_value = std::to_string(budget_kb) + ","; - so.AddConfigEntry("session.resource_cuda_partitioning_settings", + so.AddConfigEntry(kOrtSessionOptionsResourceCudaPartitioningSettings, config_value.c_str()); } From 9c8a10442ec8efec93c045efe819e935fe4bbaa6 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 10 Apr 2026 11:02:59 -0700 Subject: [PATCH 06/14] Build error --- .../providers/cuda/plugin/cuda_resource_partitioning_test.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index 5eede700a69c6..4fb3add7dc806 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -301,7 +301,8 @@ class CudaResourcePartitioningTest : public ::testing::Test { Ort::Session CreateSessionWithBudget(const ORTCHAR_T* model_path, size_t budget_kb) { Ort::SessionOptions so; - so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, {}); + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, + std::unordered_map{}); if (budget_kb > 0) { std::string config_value = std::to_string(budget_kb) + ","; From 440259d9f30679af944ce6cc95e074c4073ca04d Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 10 Apr 2026 12:14:33 -0700 Subject: [PATCH 07/14] Address Cuda Plugin EP matching and correct device id --- include/onnxruntime/core/graph/constants.h | 1 + .../core/framework/layering_annotations.cc | 15 +++++++++++---- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 4ed16ffd27264..2f2462dfa92a9 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; diff --git a/onnxruntime/core/framework/layering_annotations.cc b/onnxruntime/core/framework/layering_annotations.cc index 91df102abef17..207f39ea64939 100644 --- a/onnxruntime/core/framework/layering_annotations.cc +++ b/onnxruntime/core/framework/layering_annotations.cc @@ -183,7 +183,8 @@ bool MatchEpDevice(const EpDeviceView& ep, if (target_specifier.empty()) { if (ep.device_type == OrtDevice::GPU) return true; // Heuristic fallback for common GPU EPs if hardware info is missing - return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider; + return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider || + ep.ep_name == kDmlExecutionProvider; } // "gpu:" or "gpu:" if (ep.device_type == OrtDevice::GPU) { @@ -203,7 +204,7 @@ bool MatchEpDevice(const EpDeviceView& ep, ep.vendor_id == OrtDevice::VendorIds::INTEL) return true; // Heuristic: gpu:nvidia -> CUDA if (CaseInsensitiveCompare(target_specifier, "nvidia") && - ep.ep_name == kCudaExecutionProvider) return true; + (ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider)) return true; } return false; } @@ -225,7 +226,7 @@ bool MatchEpDevice(const EpDeviceView& ep, } // "cuda" if (CaseInsensitiveCompare(target_type_str, "cuda")) { - return ep.ep_name == kCudaExecutionProvider; + return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider; } // "dml" if (CaseInsensitiveCompare(target_type_str, "dml")) { @@ -284,7 +285,13 @@ std::optional EpLayeringMatcher::Match(gsl::spanvendor_id : 0u, - has_hw ? static_cast(ep_device.device->device_id) : OrtDevice::DeviceId{}, + // Prefer the device ordinal from device_memory_info (set by the EP factory to + // a runtime device ordinal such as a CUDA ordinal) over the OrtHardwareDevice::device_id + // which is a hardware-type identifier and not guaranteed to be a stable runtime ordinal. + ep_device.device_memory_info + ? ep_device.device_memory_info->device.Id() + : (has_hw ? static_cast(ep_device.device->device_id) + : OrtDevice::DeviceId{}), has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}}; if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) { From 87b48ff6befa2c3934e0e2f8b5e135784b8e0e2b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 10 Apr 2026 14:14:49 -0700 Subject: [PATCH 08/14] Address review comments --- .../core/session/onnxruntime_ep_c_api.h | 14 ++++----- onnxruntime/core/session/plugin_ep/ep_api.cc | 6 ++-- .../ep_plugin_provider_interfaces.cc | 29 ++++++++++++++++--- .../plugin/cuda_resource_partitioning_test.cc | 12 ++++---- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index f430e6496066d..a29c669ece3f8 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -956,7 +956,7 @@ struct OrtScanKernelHelper { */ typedef enum OrtResourceCountKind { OrtResourceCountKind_None = 0, ///< Unset / zero-cost sentinel. - OrtResourceCountKind_TotalBytes = 1, ///< Single size_t: total estimated bytes. + OrtResourceCountKind_TotalBytes = 1, ///< Single uint64_t: total estimated bytes. } OrtResourceCountKind; /** @@ -972,12 +972,12 @@ typedef enum OrtResourceCountKind { * \since Version 1.26. */ typedef struct OrtResourceCount { - OrtResourceCountKind kind; + uint32_t kind; ///< OrtResourceCountKind discriminator. Use uint32_t for ABI stability. uint32_t reserved_; ///< Alignment padding + future flags. Must be zero. union { - size_t total_bytes; ///< Active when kind == OrtResourceCountKind_TotalBytes. - uint8_t _storage[48]; ///< ABI reserve: all types must fit within 48 bytes. + uint64_t total_bytes; ///< Active when kind == OrtResourceCountKind_TotalBytes. + uint64_t _storage[6]; ///< ABI reserve (48 bytes): all types must fit within this. } value; #ifdef __cplusplus @@ -988,7 +988,7 @@ typedef struct OrtResourceCount { } /// Create a TotalBytes resource count. - static OrtResourceCount FromTotalBytes(size_t bytes) { + static OrtResourceCount FromTotalBytes(uint64_t bytes) { OrtResourceCount rc{}; rc.kind = OrtResourceCountKind_TotalBytes; rc.value.total_bytes = bytes; @@ -996,7 +996,7 @@ typedef struct OrtResourceCount { } /// Extract total_bytes. Caller must check kind == OrtResourceCountKind_TotalBytes first. - size_t AsTotalBytes() const { + uint64_t AsTotalBytes() const { return value.total_bytes; } #endif @@ -2087,7 +2087,7 @@ struct OrtEpApi { * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. * If the accountant has no explicit threshold (e.g. auto-detection mode), * the returned OrtResourceCount will have kind == OrtResourceCountKind_TotalBytes with - * value.total_bytes set to SIZE_MAX. + * value.total_bytes set to UINT64_MAX. * * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). * \param[out] budget Output parameter set to the total resource budget. diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index e4bc95d844c23..03cc776ded7f0 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1209,8 +1209,10 @@ OrtResourceCount ToOrtResourceCount(const onnxruntime::ResourceCount& rc) { using T = std::decay_t; if constexpr (std::is_same_v) { return OrtResourceCount::FromTotalBytes(val); + } else { + // If ResourceCount gains new variant members, add conversion branches above. + static_assert(sizeof(T) == 0, "Unhandled ResourceCount variant member in ToOrtResourceCount"); } - // Future variant members: add else-if branches here and return OrtResourceCount with appropriate kind. }, rc); } @@ -1243,7 +1245,7 @@ ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetResourceBudget, if (threshold) { *budget = ToOrtResourceCount(*threshold); } else { - *budget = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); + *budget = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); } return nullptr; API_IMPL_END diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 6909c6c42ecb4..3d92e6cf38f86 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -310,11 +310,12 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Convert OrtResourceCount tagged union back to internal ResourceCount (std::variant). const OrtResourceCount& ort_cost = cost_it->second; switch (ort_cost.kind) { + case OrtResourceCountKind_None: + indexed_sub_graph->AppendNodeCost(ResourceCount{size_t{0}}); + break; case OrtResourceCountKind_TotalBytes: indexed_sub_graph->AppendNodeCost(ResourceCount{ort_cost.value.total_bytes}); break; - case OrtResourceCountKind_None: - [[fallthrough]]; default: LOGS(logger, WARNING) << "Unknown OrtResourceCountKind: " << static_cast(ort_cost.kind) << "; skipping cost."; @@ -364,8 +365,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; - std::unordered_set capability_node_indices_set(capability_node_indices.begin(), - capability_node_indices.end()); + InlinedHashSet capability_node_indices_set(capability_node_indices.begin(), + capability_node_indices.end()); if (node_set.size() != capability_node_indices_set.size()) { LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() @@ -373,6 +374,26 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } + // Attach resource accounting for fused capabilities. + // Compute per-component-node costs from the accountant so that: + // - nodes_costs.size() == nodes.size() (required by IsAccountingEnabled()) + // - AccountForAllNodes() later calls CommitWeightsForNode() for each component, + // which finalizes the pending/committed weight dedup state in the accountant. + if (resource_accountant != nullptr) { + auto* fused_sub_graph = capabilities[0]->sub_graph.get(); + fused_sub_graph->SetAccountant(resource_accountant); + + for (NodeIndex idx : fused_sub_graph->nodes) { + const Node* node = graph_viewer.GetNode(idx); + // Append a cost for every node to keep nodes_costs aligned with nodes. + // If the node can't be found (shouldn't happen), append zero cost so + // CommitWeightsForNode() is still called for the correct node index. + fused_sub_graph->AppendNodeCost( + node != nullptr ? resource_accountant->ComputeResourceCount(*node) + : ResourceCount{size_t{0}}); + } + } + result.push_back(std::move(capabilities[0])); } else { LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index 4fb3add7dc806..514742376c3ff 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -110,31 +110,31 @@ TEST(OrtResourceCountTest, None_HasKindNone) { } TEST(OrtResourceCountTest, FromTotalBytes_RoundTrips) { - constexpr size_t kTestValue = 42 * 1024 * 1024; // 42 MB + constexpr uint64_t kTestValue = 42 * 1024 * 1024; // 42 MB OrtResourceCount rc = OrtResourceCount::FromTotalBytes(kTestValue); EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); EXPECT_EQ(rc.AsTotalBytes(), kTestValue); } TEST(OrtResourceCountTest, FromTotalBytes_MaxValue) { - OrtResourceCount rc = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); + OrtResourceCount rc = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(rc.AsTotalBytes(), std::numeric_limits::max()); + EXPECT_EQ(rc.AsTotalBytes(), std::numeric_limits::max()); } TEST(OrtResourceCountTest, FromTotalBytes_Zero) { OrtResourceCount rc = OrtResourceCount::FromTotalBytes(0); EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(rc.AsTotalBytes(), size_t{0}); + EXPECT_EQ(rc.AsTotalBytes(), uint64_t{0}); } TEST(OrtResourceCountTest, CopySemantics) { OrtResourceCount original = OrtResourceCount::FromTotalBytes(12345); OrtResourceCount copy = original; EXPECT_EQ(copy.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(copy.AsTotalBytes(), size_t{12345}); + EXPECT_EQ(copy.AsTotalBytes(), uint64_t{12345}); copy.value.total_bytes = 99999; - EXPECT_EQ(original.AsTotalBytes(), size_t{12345}); + EXPECT_EQ(original.AsTotalBytes(), uint64_t{12345}); } TEST(OrtResourceCountTest, ReservedFieldIsZero) { From c256c8ca465fc04feae1357088893db2345d7359 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 10 Apr 2026 16:11:48 -0700 Subject: [PATCH 09/14] Move accounting to Host --- .../core/framework/resource_accountant.h | 16 +- .../core/session/onnxruntime_cxx_api.h | 52 ------ .../core/session/onnxruntime_cxx_inline.h | 38 ---- .../core/session/onnxruntime_ep_c_api.h | 172 ------------------ .../core/framework/graph_partitioner.cc | 4 +- .../core/framework/resource_accountant.cc | 2 +- .../core/providers/cuda/plugin/cuda_ep.cc | 63 +------ .../cuda/plugin/cuda_kernel_adapter.h | 1 - onnxruntime/core/session/abi_ep_types.h | 4 - onnxruntime/core/session/plugin_ep/ep_api.cc | 145 --------------- onnxruntime/core/session/plugin_ep/ep_api.h | 14 -- .../ep_plugin_provider_interfaces.cc | 134 ++++++++------ .../framework/resource_accountant_test.cc | 28 ++- .../plugin/cuda_resource_partitioning_test.cc | 47 +---- 14 files changed, 127 insertions(+), 593 deletions(-) diff --git a/include/onnxruntime/core/framework/resource_accountant.h b/include/onnxruntime/core/framework/resource_accountant.h index 7bb5a993d140b..0e89082b0ec6a 100644 --- a/include/onnxruntime/core/framework/resource_accountant.h +++ b/include/onnxruntime/core/framework/resource_accountant.h @@ -61,9 +61,14 @@ class IResourceAccountant { bool IsStopIssued() const noexcept { return stop_assignment_; } - // Called before each GetCapability pass to discard pending weight tracking - // from a previous (discarded) pass. Default no-op for stats-based accountants. - virtual void ResetPendingWeights() {} + // Called before each GetCapability pass to reset per-pass state: + // clears the stop flag (which only applies to the pass that set it) + // and discards pending weight tracking from a previous (discarded) pass. + // Subclasses override ResetPendingWeightsImpl for EP-specific cleanup. + void ResetForNewPass() { + stop_assignment_ = false; + ResetPendingWeightsImpl(); + } // Called when a node's cost is committed (AccountForNode/AccountForAllNodes). // Moves the node's pending weights into the committed set so they persist @@ -72,6 +77,11 @@ class IResourceAccountant { static std::string MakeUniqueNodeName(const Node& node); + protected: + // Override to discard EP-specific pending weight tracking. + // Default no-op for stats-based accountants. + virtual void ResetPendingWeightsImpl() {} + private: bool stop_assignment_ = false; std::optional threshold_; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 15743fe62242e..7f0b16e2fdee7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3730,57 +3730,5 @@ using UnownedSharedPrePackedWeightCache = ///< Wraps OrtEpApi::GetEnvConfigEntries() Ort::KeyValuePairs GetEnvConfigEntries(); -/// \brief Non-owning C++ wrapper for resource budget queries on OrtEpGraphSupportInfo. -/// -/// Constructed from the OrtEpGraphSupportInfo* passed to OrtEp::GetCapability. -/// Provides convenient methods for resource-constrained node selection. -/// All costs and budgets use OrtResourceCount, the ABI-stable tagged union. -/// -/// Example use in a plugin EP's GetCapability implementation: -/// \code -/// OrtStatus* GetCapabilityImpl(OrtEp*, const OrtGraph* graph, -/// OrtEpGraphSupportInfo* info) noexcept { -/// Ort::ResourceBudget budget(info); -/// if (budget.HasBudget()) { -/// OrtResourceCount remaining = budget.GetBudget(); -/// OrtResourceCount consumed = budget.GetConsumedResources(); -/// for (const OrtNode* node : candidates) { -/// OrtResourceCount cost = budget.ComputeNodeCost(Ort::ConstNode{node}); -/// if (cost.AsTotalBytes() > remaining.AsTotalBytes() - consumed.AsTotalBytes()) { -/// budget.SignalStopAssignment(); -/// break; -/// } -/// budget.ReportAcceptedNodeCost(Ort::ConstNode{node}, cost); -/// } -/// } -/// } -/// \endcode -struct ResourceBudget { - explicit ResourceBudget(OrtEpGraphSupportInfo* info) : info_(info) {} - - /// Returns true if a resource budget is configured for this EP. - bool HasBudget() const; - - /// Returns the total resource budget. Only valid if HasBudget() is true. - OrtResourceCount GetBudget() const; - - /// Returns the amount of resources already consumed. - OrtResourceCount GetConsumedResources() const; - - /// Computes the estimated resource cost of the given node. - OrtResourceCount ComputeNodeCost(ConstNode node) const; - - /// Reports that the plugin accepted a node at the given cost. - void ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost); - - /// Returns true if stop has been signaled (by this or another EP). - bool IsStopIssued() const; - - /// Signals that this EP wants to stop receiving nodes. - void SignalStopAssignment(); - - private: - OrtEpGraphSupportInfo* info_; -}; } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 162f513fcae14..cb145c2b6c10a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -4175,42 +4175,4 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c return OpSchema{schema}; } -// ResourceBudget implementation -inline bool ResourceBudget::HasBudget() const { - bool has_budget = false; - ThrowOnError(GetEpApi().EpGraphSupportInfo_HasResourceBudget(info_, &has_budget)); - return has_budget; -} - -inline OrtResourceCount ResourceBudget::GetBudget() const { - OrtResourceCount budget = OrtResourceCount::None(); - ThrowOnError(GetEpApi().EpGraphSupportInfo_GetResourceBudget(info_, &budget)); - return budget; -} - -inline OrtResourceCount ResourceBudget::GetConsumedResources() const { - OrtResourceCount consumed = OrtResourceCount::None(); - ThrowOnError(GetEpApi().EpGraphSupportInfo_GetConsumedResources(info_, &consumed)); - return consumed; -} - -inline OrtResourceCount ResourceBudget::ComputeNodeCost(ConstNode node) const { - OrtResourceCount cost = OrtResourceCount::None(); - ThrowOnError(GetEpApi().EpGraphSupportInfo_ComputeNodeResourceCost(info_, node, &cost)); - return cost; -} - -inline void ResourceBudget::ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost) { - ThrowOnError(GetEpApi().EpGraphSupportInfo_ReportAcceptedNodeCost(info_, node, cost)); -} - -inline bool ResourceBudget::IsStopIssued() const { - bool stop = false; - ThrowOnError(GetEpApi().EpGraphSupportInfo_IsStopIssued(info_, &stop)); - return stop; -} - -inline void ResourceBudget::SignalStopAssignment() { - ThrowOnError(GetEpApi().EpGraphSupportInfo_SignalStopAssignment(info_)); -} } // namespace Ort diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index a29c669ece3f8..8ff15e5c35ed5 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -946,62 +946,6 @@ struct OrtScanKernelHelper { _In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output); }; -/** - * \brief Discriminator for the resource count type stored in an OrtResourceCount. - * - * New resource accounting types can be added by appending new enum values. - * The OrtResourceCount union storage is large enough to hold all current and future types. - * - * \since Version 1.26. - */ -typedef enum OrtResourceCountKind { - OrtResourceCountKind_None = 0, ///< Unset / zero-cost sentinel. - OrtResourceCountKind_TotalBytes = 1, ///< Single uint64_t: total estimated bytes. -} OrtResourceCountKind; - -/** - * \brief ABI-stable tagged union representing a resource cost or budget. - * - * This struct is a C-safe variant that can be passed by value across the plugin DLL boundary. - * The `kind` field selects which union member is active. The `_storage` member reserves space - * for future resource types without changing the struct layout. - * - * Adding new resource types requires only: (a) a new OrtResourceCountKind enum value, - * (b) a new union member. No new C API functions are needed. - * - * \since Version 1.26. - */ -typedef struct OrtResourceCount { - uint32_t kind; ///< OrtResourceCountKind discriminator. Use uint32_t for ABI stability. - uint32_t reserved_; ///< Alignment padding + future flags. Must be zero. - - union { - uint64_t total_bytes; ///< Active when kind == OrtResourceCountKind_TotalBytes. - uint64_t _storage[6]; ///< ABI reserve (48 bytes): all types must fit within this. - } value; - -#ifdef __cplusplus - /// Create a zero-cost (None) resource count. - static OrtResourceCount None() { - OrtResourceCount rc{}; - return rc; - } - - /// Create a TotalBytes resource count. - static OrtResourceCount FromTotalBytes(uint64_t bytes) { - OrtResourceCount rc{}; - rc.kind = OrtResourceCountKind_TotalBytes; - rc.value.total_bytes = bytes; - return rc; - } - - /// Extract total_bytes. Caller must check kind == OrtResourceCountKind_TotalBytes first. - uint64_t AsTotalBytes() const { - return value.total_bytes; - } -#endif -} OrtResourceCount; - /** * \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider. * @@ -2066,122 +2010,6 @@ struct OrtEpApi { ORT_API2_STATUS(ProfilingEventsContainer_AddEvents, _In_ OrtProfilingEventsContainer* events_container, _In_reads_(num_events) const OrtProfilingEvent* const* events, _In_ size_t num_events); - - /** \brief Query whether resource accounting is active for this GetCapability call. - * - * Returns true if a resource accountant is attached to the given OrtEpGraphSupportInfo instance, - * meaning the session was configured with resource-constrained partitioning settings. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[out] has_budget Output parameter set to true if a resource budget is active. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* has_budget); - - /** \brief Get the total resource budget. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * If the accountant has no explicit threshold (e.g. auto-detection mode), - * the returned OrtResourceCount will have kind == OrtResourceCountKind_TotalBytes with - * value.total_bytes set to UINT64_MAX. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[out] budget Output parameter set to the total resource budget. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* budget); - - /** \brief Get the amount of resources already consumed from prior partitioning passes or previously assigned nodes. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[out] consumed Output parameter set to the consumed resource amount. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* consumed); - - /** \brief Compute the estimated resource cost for a node. - * - * Uses pre-recorded memory statistics if available, otherwise estimates from initializer sizes - * and static output shapes with a safety multiplier. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[in] node The OrtNode for which to compute the resource cost. - * \param[out] cost Output parameter set to the estimated resource cost. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _Out_ OrtResourceCount* cost); - - /** \brief Report that a node was accepted and its cost should be tracked. - * - * The cost is stored internally so the host can attach it to the IndexedSubGraph after - * GetCapability returns. This does NOT commit the cost to the accountant's consumed amount — - * that happens later during node assignment by the graph partitioner. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[in] node The OrtNode whose cost is being reported. - * \param[in] cost The cost (as returned by EpGraphSupportInfo_ComputeNodeResourceCost). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _In_ OrtResourceCount cost); - - /** \brief Query whether a previous GetCapability pass already signaled stop. - * - * Returns true if EpGraphSupportInfo_SignalStopAssignment was called in a prior pass - * (or by another mechanism). The plugin can use this to early-exit from GetCapability - * without re-evaluating nodes. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * \param[out] is_stopped Output parameter set to true if stop was previously signaled. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* is_stopped); - - /** \brief Signal that the EP wants to stop accepting further nodes due to budget exhaustion. - * - * After this call, the accountant's stop flag is set. Subsequent GetCapability calls for this EP - * will see EpGraphSupportInfo_IsStopIssued() returning true and can return early. - * - * Only valid if EpGraphSupportInfo_HasResourceBudget returns true. - * - * \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability(). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.26. - */ - ORT_API2_STATUS(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info); }; /** diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index cc65142318d02..a3def4a81bb50 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -277,7 +277,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); if (params.resource_accountant) { - params.resource_accountant->ResetPendingWeights(); + params.resource_accountant->ResetForNewPass(); } capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); @@ -348,7 +348,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); if (params.resource_accountant) { - params.resource_accountant->ResetPendingWeights(); + params.resource_accountant->ResetForNewPass(); } capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); diff --git a/onnxruntime/core/framework/resource_accountant.cc b/onnxruntime/core/framework/resource_accountant.cc index 68610ebb4be17..bd450bb4112ef 100644 --- a/onnxruntime/core/framework/resource_accountant.cc +++ b/onnxruntime/core/framework/resource_accountant.cc @@ -120,7 +120,7 @@ class SizeBasedStatsAccountant : public IResourceAccountant { } } - void ResetPendingWeights() override { + void ResetPendingWeightsImpl() override { pending_weights_.clear(); pending_weights_by_node_.clear(); } diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index c3187e9bbda2c..54b8c9cb4a216 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -18,7 +18,7 @@ #include #include -#include "core/common/safeint.h" +#include "core/graph/constants.h" namespace onnxruntime { namespace cuda_plugin { @@ -183,18 +183,6 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( auto* ep = static_cast(this_ptr); const OrtEpApi& ep_api = ep->factory_.GetEpApi(); - // Early exit if a previous GetCapability pass already signaled stop. - // This mirrors the in-tree CUDA EP's check at the top of GetCapability(). - Ort::ResourceBudget resource_budget(graph_support_info); - bool has_budget = resource_budget.HasBudget(); - if (has_budget && resource_budget.IsStopIssued()) { - Ort::Status log_status(Ort::GetApi().Logger_LogMessage( - &ep->logger_, ORT_LOGGING_LEVEL_WARNING, - "CUDA Plugin EP returning due to Stop Set", - ORT_FILE, __LINE__, __FUNCTION__)); - return nullptr; - } - Ort::ConstGraph graph{ort_graph}; std::vector all_nodes = graph.GetNodes(); @@ -241,59 +229,16 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( gsl::span(tentative_nodes.data(), tentative_nodes.size()), cpu_preferred_nodes)); - // Phase 3: Add final supported nodes (tentative minus CPU-preferred), - // respecting the optional resource budget. - // resource_budget and has_budget were computed at the top of this function. - size_t budget_bytes = std::numeric_limits::max(); - size_t consumed_bytes = 0; - if (has_budget) { - budget_bytes = resource_budget.GetBudget().AsTotalBytes(); - consumed_bytes = resource_budget.GetConsumedResources().AsTotalBytes(); - } + // Phase 3: Add final supported nodes (tentative minus CPU-preferred). + // Resource budget enforcement is handled by the host after GetCapability returns. for (const OrtNode* ort_node : candidate_nodes) { if (cpu_preferred_nodes.count(ort_node) != 0) { continue; } - // Previously assigned nodes (ep_name matched) are already accounted for. - Ort::ConstNode node{ort_node}; - bool previously_assigned = !node.GetEpName().empty(); - - if (has_budget && !previously_assigned) { - OrtResourceCount cost = resource_budget.ComputeNodeCost(node); - size_t cost_bytes = cost.AsTotalBytes(); - size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; - - { - // Log per-node cost information (mirrors in-tree CUDA EP logging) - std::string msg = "CUDA Plugin EP Node: " + node.GetName() + - " Memory usage: " + std::to_string(cost_bytes) + - " would be consumed: " + std::to_string(would_be_consumed) + - " threshold: " + std::to_string(budget_bytes); - Ort::Status log_status(Ort::GetApi().Logger_LogMessage( - &ep->logger_, ORT_LOGGING_LEVEL_INFO, - msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - } - - if (would_be_consumed > budget_bytes) { - { - std::string msg = "CUDA Plugin EP Halting assignment due to capacity threshold at node: " + - node.GetName(); - Ort::Status log_status(Ort::GetApi().Logger_LogMessage( - &ep->logger_, ORT_LOGGING_LEVEL_WARNING, - msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); - } - resource_budget.SignalStopAssignment(); - break; // topological-order halt - } - - consumed_bytes = would_be_consumed; - resource_budget.ReportAcceptedNodeCost(node, cost); - } - RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( - graph_support_info, node)); + graph_support_info, ort_node)); } return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 735e3eb660c62..b1da0aa816a03 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -99,7 +99,6 @@ class OrtStreamAdapter { #include "core/providers/common.h" namespace onnxruntime { -inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; // Forward declaration of GetEnvironmentVar for plugin builds on Windows. // Defined in provider_api_shims.cc; mirrors the provider_api.h declaration diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index c7ae4704494ff..8e0691b985dc1 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -55,8 +55,4 @@ struct OrtEpGraphSupportInfo { // Optional resource accountant for capacity-aware partitioning. // Owned by the graph partitioner; lifetime exceeds this struct. onnxruntime::IResourceAccountant* resource_accountant = nullptr; - - // Per-node costs reported by the plugin via EpGraphSupportInfo_ReportAcceptedNodeCost. - // Keyed by the OrtNode pointer (same lifetime as the EpGraph). - std::vector> accepted_node_costs; }; diff --git a/onnxruntime/core/session/plugin_ep/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc index 03cc776ded7f0..d56f4299402b5 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -5,7 +5,6 @@ #include #include -#include #include #include #include @@ -34,7 +33,6 @@ #include "core/session/plugin_ep/ep_control_flow_kernel_impls.h" #include "core/session/utils.h" #include "core/common/profiler_common.h" -#include "core/framework/resource_accountant.h" #include "core/session/plugin_ep/ep_event_profiling.h" using namespace onnxruntime; @@ -1200,138 +1198,6 @@ ORT_API_STATUS_IMPL(ProfilingEventsContainer_AddEvents, API_IMPL_END } -// Resource accounting for capacity-aware partitioning - -namespace { -// Convert internal ResourceCount (std::variant) to the C-safe tagged union. -OrtResourceCount ToOrtResourceCount(const onnxruntime::ResourceCount& rc) { - return std::visit([](auto&& val) -> OrtResourceCount { - using T = std::decay_t; - if constexpr (std::is_same_v) { - return OrtResourceCount::FromTotalBytes(val); - } else { - // If ResourceCount gains new variant members, add conversion branches above. - static_assert(sizeof(T) == 0, "Unhandled ResourceCount variant member in ToOrtResourceCount"); - } - }, - rc); -} - -} // namespace - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_HasResourceBudget, - _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* has_budget) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(has_budget == nullptr, ORT_INVALID_ARGUMENT, "has_budget must not be NULL"); - *has_budget = (graph_support_info->resource_accountant != nullptr); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetResourceBudget, - _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* budget) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(budget == nullptr, ORT_INVALID_ARGUMENT, "budget must not be NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - - auto threshold = accountant->GetThreshold(); - if (threshold) { - *budget = ToOrtResourceCount(*threshold); - } else { - *budget = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); - } - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetConsumedResources, - _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* consumed) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(consumed == nullptr, ORT_INVALID_ARGUMENT, "consumed must not be NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - - *consumed = ToOrtResourceCount(accountant->GetConsumedAmount()); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_ComputeNodeResourceCost, - _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _Out_ OrtResourceCount* cost) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(node == nullptr, ORT_INVALID_ARGUMENT, "OrtNode is NULL"); - ORT_API_RETURN_IF(cost == nullptr, ORT_INVALID_ARGUMENT, "cost must not be NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - - const auto* ep_node = onnxruntime::EpNode::ToInternal(node); - ORT_API_RETURN_IF(ep_node == nullptr, ORT_INVALID_ARGUMENT, "Invalid OrtNode variant"); - const onnxruntime::Node& internal_node = ep_node->GetInternalNode(); - - *cost = ToOrtResourceCount(accountant->ComputeResourceCount(internal_node)); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_ReportAcceptedNodeCost, - _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _In_ OrtResourceCount cost) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(node == nullptr, ORT_INVALID_ARGUMENT, "OrtNode is NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - ORT_API_RETURN_IF(cost.reserved_ != 0, ORT_INVALID_ARGUMENT, "OrtResourceCount reserved_ field must be zero"); - ORT_API_RETURN_IF(cost.kind != OrtResourceCountKind_None && cost.kind != OrtResourceCountKind_TotalBytes, - ORT_INVALID_ARGUMENT, "Unsupported OrtResourceCountKind value"); - - graph_support_info->accepted_node_costs.emplace_back(node, cost); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_IsStopIssued, - _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* is_stopped) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - ORT_API_RETURN_IF(is_stopped == nullptr, ORT_INVALID_ARGUMENT, "is_stopped must not be NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - - *is_stopped = accountant->IsStopIssued(); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(EpGraphSupportInfo_SignalStopAssignment, - _In_ OrtEpGraphSupportInfo* graph_support_info) { - API_IMPL_BEGIN - ORT_API_RETURN_IF(graph_support_info == nullptr, ORT_INVALID_ARGUMENT, - "OrtEpGraphSupportInfo instance is NULL"); - auto* accountant = graph_support_info->resource_accountant; - ORT_API_RETURN_IF(accountant == nullptr, ORT_INVALID_ARGUMENT, "No resource accountant is active"); - - accountant->SetStopAssignment(); - return nullptr; - API_IMPL_END -} - static constexpr OrtEpApi ort_ep_api = { // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, // and no functions can be removed (the implementation needs to change to return an error). @@ -1421,15 +1287,6 @@ static constexpr OrtEpApi ort_ep_api = { &OrtExecutionProviderApi::ProfilingEvent_GetArgValue, &OrtExecutionProviderApi::ProfilingEventsContainer_AddEvents, // End of Version 25 - DO NOT MODIFY ABOVE - - &OrtExecutionProviderApi::EpGraphSupportInfo_HasResourceBudget, - &OrtExecutionProviderApi::EpGraphSupportInfo_GetResourceBudget, - &OrtExecutionProviderApi::EpGraphSupportInfo_GetConsumedResources, - &OrtExecutionProviderApi::EpGraphSupportInfo_ComputeNodeResourceCost, - &OrtExecutionProviderApi::EpGraphSupportInfo_ReportAcceptedNodeCost, - &OrtExecutionProviderApi::EpGraphSupportInfo_IsStopIssued, - &OrtExecutionProviderApi::EpGraphSupportInfo_SignalStopAssignment, - // End of Version 26 - DO NOT MODIFY ABOVE }; // checks that we don't violate the rule that the functions must remain in the slots they were originally assigned @@ -1441,8 +1298,6 @@ static_assert(offsetof(OrtEpApi, GetEnvConfigEntries) / sizeof(void*) == 49, "Size of version 24 API cannot change"); static_assert(offsetof(OrtEpApi, ProfilingEventsContainer_AddEvents) / sizeof(void*) == 72, "Size of version 25 API cannot change"); -static_assert(offsetof(OrtEpApi, EpGraphSupportInfo_SignalStopAssignment) / sizeof(void*) == 79, - "Size of version 26 API cannot change"); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 9e2c3affa333f..e32e267a75ba5 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -179,18 +179,4 @@ ORT_API_STATUS_IMPL(ProfilingEvent_GetDurationUs, _In_ const OrtProfilingEvent* ORT_API_STATUS_IMPL(ProfilingEvent_GetArgValue, _In_ const OrtProfilingEvent* event, _In_ const char* key, _Outptr_result_maybenull_ const char** out); -// Resource accounting for capacity-aware partitioning -ORT_API_STATUS_IMPL(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* has_budget); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* budget); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ OrtResourceCount* consumed); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _Out_ OrtResourceCount* cost); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info, - _In_ const OrtNode* node, _In_ OrtResourceCount cost); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info, - _Out_ bool* is_stopped); -ORT_API_STATUS_IMPL(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info); } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 3d92e6cf38f86..2789fac436c7b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -15,6 +15,7 @@ #include "core/framework/plugin_ep_stream.h" #include "core/framework/resource_accountant.h" #include "core/common/inlined_containers.h" +#include "core/common/safeint.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -241,6 +242,14 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const logging::Logger& logger = GetEpLoggerOrDefault(); + // Early exit if a previous GetCapability pass already signaled stop (e.g., budget exhausted). + // The framework calls GetCapability multiple times (e.g., after layout transformation), + // and each EP is responsible for checking the stop flag. + if (resource_accountant != nullptr && resource_accountant->IsStopIssued()) { + LOGS(logger, WARNING) << Type() << " returning due to stop already set"; + return {}; + } + std::unique_ptr ep_graph = nullptr; if (Status status = EpGraph::Create(graph_viewer, ep_graph, true); !status.IsOK()) { LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); @@ -262,16 +271,16 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - // Build a mapping from OrtNode* to accepted cost for resource accounting. - // The plugin reports accepted nodes and their costs via EpGraphSupportInfo_ReportAcceptedNodeCost. - // Costs are OrtResourceCount tagged unions that are converted back to internal ResourceCount - // (std::variant) when attaching to IndexedSubGraph. - InlinedHashMap node_cost_map; - if (resource_accountant != nullptr && !api_graph_support_info.accepted_node_costs.empty()) { - node_cost_map.reserve(api_graph_support_info.accepted_node_costs.size()); - for (const auto& [ort_node, cost] : api_graph_support_info.accepted_node_costs) { - node_cost_map[ort_node] = cost; - } + // Host-side resource budget enforcement. + // The host computes costs and enforces the budget uniformly for all node grouping kinds. + // Plugin EPs only propose supported nodes; the host decides which to accept. + const bool has_budget = resource_accountant != nullptr && resource_accountant->GetThreshold().has_value(); + size_t budget_bytes = std::numeric_limits::max(); + size_t consumed_bytes = 0; + + if (has_budget) { + budget_bytes = std::get(*resource_accountant->GetThreshold()); + consumed_bytes = std::get(resource_accountant->GetConsumedAmount()); } // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. @@ -289,46 +298,50 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { if (node_grouping.nodes.size() != 1) { - // The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide - // an invalid node. However, we check here too just in case this changes. LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node " << "when calling EpGraphSupportInfo_AddSingleNode()."; return {}; } - auto indexed_sub_graph = std::make_unique(); + const Node& internal_node = node_grouping.nodes[0]->GetInternalNode(); + const NodeIndex node_index = internal_node.Index(); - const NodeIndex node_index = node_grouping.nodes[0]->GetInternalNode().Index(); - indexed_sub_graph->nodes.push_back(node_index); - - // Attach resource accounting if the plugin reported a cost for this node. + // Host-side budget enforcement for single nodes. if (resource_accountant != nullptr) { - const OrtNode* ort_node_key = static_cast(node_grouping.nodes[0]); - auto cost_it = node_cost_map.find(ort_node_key); - if (cost_it != node_cost_map.end()) { - indexed_sub_graph->SetAccountant(resource_accountant); - // Convert OrtResourceCount tagged union back to internal ResourceCount (std::variant). - const OrtResourceCount& ort_cost = cost_it->second; - switch (ort_cost.kind) { - case OrtResourceCountKind_None: - indexed_sub_graph->AppendNodeCost(ResourceCount{size_t{0}}); - break; - case OrtResourceCountKind_TotalBytes: - indexed_sub_graph->AppendNodeCost(ResourceCount{ort_cost.value.total_bytes}); - break; - default: - LOGS(logger, WARNING) << "Unknown OrtResourceCountKind: " - << static_cast(ort_cost.kind) << "; skipping cost."; - break; + ResourceCount cost = resource_accountant->ComputeResourceCount(internal_node); + size_t cost_bytes = std::get(cost); + + if (has_budget) { + size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; + + LOGS(logger, INFO) << Type() << " node: " << internal_node.Name() + << " (" << internal_node.OpType() << ")" + << " cost: " << cost_bytes + << " would_be_consumed: " << would_be_consumed + << " budget: " << budget_bytes; + + if (would_be_consumed > budget_bytes) { + LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " + << internal_node.Name(); + resource_accountant->SetStopAssignment(); + break; // stop processing further groupings } + + consumed_bytes = would_be_consumed; } - } - result.push_back(std::make_unique(std::move(indexed_sub_graph))); + auto indexed_sub_graph = std::make_unique(); + indexed_sub_graph->nodes.push_back(node_index); + indexed_sub_graph->SetAccountant(resource_accountant); + indexed_sub_graph->AppendNodeCost(cost); + result.push_back(std::make_unique(std::move(indexed_sub_graph))); + } else { + auto indexed_sub_graph = std::make_unique(); + indexed_sub_graph->nodes.push_back(node_index); + result.push_back(std::make_unique(std::move(indexed_sub_graph))); + } } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { if (node_grouping.nodes.empty()) { - // The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide - // an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes. LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " << "when specifying supported nodes."; return {}; @@ -361,8 +374,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always - // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. + // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; InlinedHashSet capability_node_indices_set(capability_node_indices.begin(), @@ -374,23 +386,43 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - // Attach resource accounting for fused capabilities. - // Compute per-component-node costs from the accountant so that: - // - nodes_costs.size() == nodes.size() (required by IsAccountingEnabled()) - // - AccountForAllNodes() later calls CommitWeightsForNode() for each component, - // which finalizes the pending/committed weight dedup state in the accountant. + // Host-side budget enforcement for fused capabilities. + // Compute per-component-node costs from the accountant and check total against budget. if (resource_accountant != nullptr) { auto* fused_sub_graph = capabilities[0]->sub_graph.get(); fused_sub_graph->SetAccountant(resource_accountant); + size_t group_cost_bytes = 0; + InlinedVector node_costs; + node_costs.reserve(fused_sub_graph->nodes.size()); + for (NodeIndex idx : fused_sub_graph->nodes) { const Node* node = graph_viewer.GetNode(idx); - // Append a cost for every node to keep nodes_costs aligned with nodes. - // If the node can't be found (shouldn't happen), append zero cost so - // CommitWeightsForNode() is still called for the correct node index. - fused_sub_graph->AppendNodeCost( - node != nullptr ? resource_accountant->ComputeResourceCount(*node) - : ResourceCount{size_t{0}}); + ResourceCount cost = node != nullptr + ? resource_accountant->ComputeResourceCount(*node) + : ResourceCount{size_t{0}}; + group_cost_bytes = SafeInt(group_cost_bytes) + std::get(cost); + node_costs.push_back(cost); + } + + if (has_budget) { + size_t would_be_consumed = SafeInt(consumed_bytes) + group_cost_bytes; + + LOGS(logger, INFO) << Type() << " fused group cost: " << group_cost_bytes + << " would_be_consumed: " << would_be_consumed + << " budget: " << budget_bytes; + + if (would_be_consumed > budget_bytes) { + LOGS(logger, WARNING) << Type() << " halting assignment: fused group exceeds budget."; + resource_accountant->SetStopAssignment(); + break; // stop processing further groupings + } + + consumed_bytes = would_be_consumed; + } + + for (const auto& cost : node_costs) { + fused_sub_graph->AppendNodeCost(cost); } } diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index 07ba198a9c323..694a624c39a4d 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -132,8 +132,8 @@ TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { << "AccountForAllNodes should sum pre-stored costs (3000 + 1500)"; } -// Verifies that ResetPendingWeights + re-probe produces correct results. -// After probing (which only writes to pending), resetting pending and +// Verifies that ResetForNewPass + re-probe produces correct results. +// After probing (which only writes to pending), resetting for a new pass and // re-probing should see the full weight cost again since nothing was committed. TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { auto h = SharedWeightGraph::Create(); @@ -147,7 +147,7 @@ TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { EXPECT_EQ(GetSizeT(cost_b), size_t{1500}); // Discard the pass (simulating capabilities.clear() before second GetCapability) - accountant->ResetPendingWeights(); + accountant->ResetForNewPass(); // Re-probe: weight_W was never committed, so it should be counted again IndexedSubGraph sub_graph; @@ -157,7 +157,23 @@ TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { sub_graph.AccountForNode(h.node_a->Index(), recomputed_cost); EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}) - << "After ResetPendingWeights, re-probe should see full weight cost"; + << "After ResetForNewPass, re-probe should see full weight cost"; +} + +// ResetForNewPass clears the stop flag so a second GetCapability pass +// (e.g., after layout transformation) can run from scratch. +TEST(ResourceAccountantTest, ResetForNewPass_ClearsStopFlag) { + std::optional acc_map; + auto* accountant = CreateAdHocAccountant(/*limit_kb=*/1, PathString(), acc_map); + + // Simulate first pass exhausting budget and setting stop. + accountant->SetStopAssignment(); + EXPECT_TRUE(accountant->IsStopIssued()); + + // Framework discards first-pass results and resets for second pass. + accountant->ResetForNewPass(); + EXPECT_FALSE(accountant->IsStopIssued()) + << "ResetForNewPass should clear the stop flag for the next pass"; } // Each node has a unique initializer. AccountForAllNodes sums both. @@ -269,8 +285,8 @@ TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { accountant->CommitWeightsForNode(h.node_a->Index()); EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}); - // Reset pending to simulate new GetCapability pass - accountant->ResetPendingWeights(); + // Reset for new pass to simulate new GetCapability pass + accountant->ResetForNewPass(); // EP2 probes node_B: weight_W already committed, only output counted auto cost_b = accountant->ComputeResourceCount(*h.node_b); diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index 514742376c3ff..fc73ce57d97f7 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -4,9 +4,9 @@ // Integration tests for resource-constrained partitioning through the CUDA plugin EP. // // Two test levels: -// 1. OrtResourceCount struct tests — validate the C-safe tagged union. -// 2. Partitioning verification tests — use InferenceSessionWrapper to inspect +// 1. Partitioning verification tests — use InferenceSessionWrapper to inspect // per-node EP assignments after partitioning through the plugin EP. +// 2. E2E session tests — validate output correctness under budget constraints. #if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) @@ -15,7 +15,6 @@ #include #include #include -#include #include #include @@ -100,48 +99,6 @@ OrtEnv& GetOrtEnv() { } // namespace -// --------------------------------------------------------------------------- -// OrtResourceCount struct tests -// --------------------------------------------------------------------------- - -TEST(OrtResourceCountTest, None_HasKindNone) { - OrtResourceCount rc = OrtResourceCount::None(); - EXPECT_EQ(rc.kind, OrtResourceCountKind_None); -} - -TEST(OrtResourceCountTest, FromTotalBytes_RoundTrips) { - constexpr uint64_t kTestValue = 42 * 1024 * 1024; // 42 MB - OrtResourceCount rc = OrtResourceCount::FromTotalBytes(kTestValue); - EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(rc.AsTotalBytes(), kTestValue); -} - -TEST(OrtResourceCountTest, FromTotalBytes_MaxValue) { - OrtResourceCount rc = OrtResourceCount::FromTotalBytes(std::numeric_limits::max()); - EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(rc.AsTotalBytes(), std::numeric_limits::max()); -} - -TEST(OrtResourceCountTest, FromTotalBytes_Zero) { - OrtResourceCount rc = OrtResourceCount::FromTotalBytes(0); - EXPECT_EQ(rc.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(rc.AsTotalBytes(), uint64_t{0}); -} - -TEST(OrtResourceCountTest, CopySemantics) { - OrtResourceCount original = OrtResourceCount::FromTotalBytes(12345); - OrtResourceCount copy = original; - EXPECT_EQ(copy.kind, OrtResourceCountKind_TotalBytes); - EXPECT_EQ(copy.AsTotalBytes(), uint64_t{12345}); - copy.value.total_bytes = 99999; - EXPECT_EQ(original.AsTotalBytes(), uint64_t{12345}); -} - -TEST(OrtResourceCountTest, ReservedFieldIsZero) { - OrtResourceCount rc = OrtResourceCount::FromTotalBytes(100); - EXPECT_EQ(rc.reserved_, uint32_t{0}); -} - // --------------------------------------------------------------------------- // Lower-level partitioning tests that verify per-node EP assignments // --------------------------------------------------------------------------- From 5f33bf7890157dec3b6a447d8cb3fe78d50241db Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 10 Apr 2026 17:41:23 -0700 Subject: [PATCH 10/14] Address review comments --- .../ep_plugin_provider_interfaces.cc | 16 ++-- .../framework/resource_accountant_test.cc | 79 +++++++++++-------- .../plugin/cuda_resource_partitioning_test.cc | 1 + 3 files changed, 57 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 2789fac436c7b..10bc6e96a1c1a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -314,11 +314,11 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie if (has_budget) { size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; - LOGS(logger, INFO) << Type() << " node: " << internal_node.Name() - << " (" << internal_node.OpType() << ")" - << " cost: " << cost_bytes - << " would_be_consumed: " << would_be_consumed - << " budget: " << budget_bytes; + LOGS(logger, VERBOSE) << Type() << " node: " << internal_node.Name() + << " (" << internal_node.OpType() << ")" + << " cost: " << cost_bytes + << " would_be_consumed: " << would_be_consumed + << " budget: " << budget_bytes; if (would_be_consumed > budget_bytes) { LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " @@ -408,9 +408,9 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie if (has_budget) { size_t would_be_consumed = SafeInt(consumed_bytes) + group_cost_bytes; - LOGS(logger, INFO) << Type() << " fused group cost: " << group_cost_bytes - << " would_be_consumed: " << would_be_consumed - << " budget: " << budget_bytes; + LOGS(logger, VERBOSE) << Type() << " fused group cost: " << group_cost_bytes + << " would_be_consumed: " << would_be_consumed + << " budget: " << budget_bytes; if (would_be_consumed > budget_bytes) { LOGS(logger, WARNING) << Type() << " halting assignment: fused group exceeds budget."; diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index 694a624c39a4d..fb9032b1c42ff 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -36,19 +36,21 @@ size_t GetSizeT(const ResourceCount& rc) { } // Helper to create a real SizeBasedStatsAccountant in ad-hoc mode (no stats file) via factory. -IResourceAccountant* CreateAdHocAccountant( +// Must be called from within a TEST body (or via ASSERT_NO_FATAL_FAILURE) because it uses ASSERT_*. +void CreateAdHocAccountant( size_t limit_kb, const std::filesystem::path& model_path, - std::optional& acc_map) { + std::optional& acc_map, + IResourceAccountant*& out) { ConfigOptions config; std::string setting = std::to_string(limit_kb) + ","; - ORT_THROW_IF_ERROR(config.AddConfigEntry( + ASSERT_STATUS_OK(config.AddConfigEntry( kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); - ORT_THROW_IF_ERROR(CreateAccountants(config, model_path, acc_map)); - ORT_ENFORCE(acc_map.has_value()); + ASSERT_STATUS_OK(CreateAccountants(config, model_path, acc_map)); + ASSERT_TRUE(acc_map.has_value()); auto it = acc_map->find(kCudaExecutionProvider); - ORT_ENFORCE(it != acc_map->end()); - return it->second.get(); + ASSERT_NE(it, acc_map->end()); + out = it->second.get(); } } // namespace @@ -60,8 +62,7 @@ struct SharedWeightGraph { Node* node_a = nullptr; Node* node_b = nullptr; - static SharedWeightGraph Create() { - SharedWeightGraph h; + static void Create(SharedWeightGraph& h) { std::unordered_map dom; dom[kOnnxDomain] = 12; h.model = std::make_unique( @@ -94,9 +95,7 @@ struct SharedWeightGraph { h.node_a = &h.graph->AddNode("node_A", "Add", "A", {ia, wa}, {oa}); h.node_b = &h.graph->AddNode("node_B", "Add", "B", {ib, wa}, {ob}); - auto status = h.graph->Resolve(); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return h; + ASSERT_STATUS_OK(h.graph->Resolve()); } }; @@ -108,9 +107,11 @@ struct SharedWeightGraph { // AccountForAllNodes sums pre-stored per-node costs // that already have correct within-pass weight deduplication. TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); std::optional acc_map; - auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(h.node_a->Index()); @@ -136,9 +137,11 @@ TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { // After probing (which only writes to pending), resetting for a new pass and // re-probing should see the full weight cost again since nothing was committed. TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); std::optional acc_map; - auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); // Probing pass populates pending weights auto cost_a = accountant->ComputeResourceCount(*h.node_a); @@ -164,7 +167,8 @@ TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { // (e.g., after layout transformation) can run from scratch. TEST(ResourceAccountantTest, ResetForNewPass_ClearsStopFlag) { std::optional acc_map; - auto* accountant = CreateAdHocAccountant(/*limit_kb=*/1, PathString(), acc_map); + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/1, PathString(), acc_map, accountant)); // Simulate first pass exhausting budget and setting stop. accountant->SetStopAssignment(); @@ -217,7 +221,8 @@ TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { ASSERT_STATUS_OK(graph.Resolve()); std::optional acc_map; - auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(node1.Index()); @@ -236,11 +241,13 @@ TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { // AccountForNode per-node and AccountForAllNodes bulk produce same result. TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); // Per-node path std::optional acc_map1; - auto* acc1 = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map1); + IResourceAccountant* acc1 = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map1, acc1)); IndexedSubGraph sub1; sub1.nodes.push_back(h.node_a->Index()); sub1.nodes.push_back(h.node_b->Index()); @@ -253,7 +260,8 @@ TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { // Bulk path std::optional acc_map2; - auto* acc2 = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map2); + IResourceAccountant* acc2 = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map2, acc2)); IndexedSubGraph sub2; sub2.nodes.push_back(h.node_a->Index()); sub2.nodes.push_back(h.node_b->Index()); @@ -272,9 +280,11 @@ TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { // correctly sees weight_W as already accounted. // node_A cost: 3000, node_B cost after commit: (0 + 1000) * 1.5 = 1500 TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); std::optional acc_map; - auto* accountant = CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map); + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); // EP1 probes and commits node_A IndexedSubGraph sub1; @@ -308,9 +318,19 @@ TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { // Stats-based path and factory tests // --------------------------------------------------------------------------- +// RAII helper to remove a temp file on scope exit, even if a test assertion fails. +struct ScopedFileRemover { + std::filesystem::path path; + ~ScopedFileRemover() { + std::error_code ec; + std::filesystem::remove(path, ec); + } +}; + // Stats-based path: cost is sum of all NodeAllocationStats fields. TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); // Write a stats file with known costs (unique per PID to avoid parallel collisions) std::error_code ec; @@ -319,6 +339,7 @@ TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { std::ostringstream fname; fname << "test_resource_accountant_stats_" << ORT_TEST_PID << ".csv"; auto stats_path = stats_dir / fname.str(); + ScopedFileRemover stats_cleanup{stats_path}; // Get the unique node names the accountant will look up std::string name_a = IResourceAccountant::MakeUniqueNodeName(*h.node_a); @@ -355,14 +376,12 @@ TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { auto threshold = accountant->GetThreshold(); ASSERT_TRUE(threshold.has_value()); EXPECT_EQ(std::get(*threshold), size_t{500 * 1024}); - - std::error_code remove_ec; - std::filesystem::remove(stats_path, remove_ec); } // Stats-based path returns 0 for unknown nodes. TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); std::error_code ec; auto stats_dir = std::filesystem::temp_directory_path(ec); @@ -370,6 +389,7 @@ TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { std::ostringstream fname; fname << "test_resource_accountant_empty_stats_" << ORT_TEST_PID << ".csv"; auto stats_path = stats_dir / fname.str(); + ScopedFileRemover stats_cleanup{stats_path}; { std::ofstream ofs(stats_path); @@ -389,9 +409,6 @@ TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { auto cost = accountant->ComputeResourceCount(*h.node_a); EXPECT_EQ(std::get(cost), size_t{0}); - - std::error_code remove_ec; - std::filesystem::remove(stats_path, remove_ec); } // Factory with no limit and no stats file creates accountant with no threshold. diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index fc73ce57d97f7..da64966a0f664 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include #include From a8967a5ca1fde95ef9bcf9ab5a24c358a6f6db2b Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 13 Apr 2026 10:15:57 -0700 Subject: [PATCH 11/14] Address double accounting bug for already assigned nodes --- .../plugin_ep/ep_plugin_provider_interfaces.cc | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 10bc6e96a1c1a..75bd5bcd4e3f2 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -306,8 +306,15 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const Node& internal_node = node_grouping.nodes[0]->GetInternalNode(); const NodeIndex node_index = internal_node.Index(); + // Node already assigned from a previous pass (e.g., before layout transformation + // or after function inlining). Its cost was already committed — skip re-computation to avoid + // double-charging the output-size component. + // FindFirstNodeAssignedToOtherEP already filtered out nodes assigned to a different EP, + // so a non-empty EP type here means it was assigned to this EP. + const bool previously_assigned = !internal_node.GetExecutionProviderType().empty(); + // Host-side budget enforcement for single nodes. - if (resource_accountant != nullptr) { + if (resource_accountant != nullptr && !previously_assigned) { ResourceCount cost = resource_accountant->ComputeResourceCount(internal_node); size_t cost_bytes = std::get(cost); @@ -388,6 +395,8 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Host-side budget enforcement for fused capabilities. // Compute per-component-node costs from the accountant and check total against budget. + // Skip cost computation for nodes already assigned to this EP from a previous pass + // to avoid double-charging the output-size component. if (resource_accountant != nullptr) { auto* fused_sub_graph = capabilities[0]->sub_graph.get(); fused_sub_graph->SetAccountant(resource_accountant); @@ -398,7 +407,9 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie for (NodeIndex idx : fused_sub_graph->nodes) { const Node* node = graph_viewer.GetNode(idx); - ResourceCount cost = node != nullptr + const bool node_already_assigned = + node != nullptr && !node->GetExecutionProviderType().empty(); + ResourceCount cost = (node != nullptr && !node_already_assigned) ? resource_accountant->ComputeResourceCount(*node) : ResourceCount{size_t{0}}; group_cost_bytes = SafeInt(group_cost_bytes) + std::get(cost); From 1509c6cf4771d344ba507411ece01a5c64b83383 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 13 Apr 2026 10:21:08 -0700 Subject: [PATCH 12/14] Address other review comments --- .../core/framework/graph_partitioner.cc | 10 +++- .../ep_plugin_provider_interfaces.cc | 48 ++++++++----------- .../plugin/cuda_resource_partitioning_test.cc | 7 ++- 3 files changed, 36 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index a3def4a81bb50..2bebab9862e7c 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -1189,7 +1189,15 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, for (const auto& ep : execution_providers) { IResourceAccountant* resource_accountant = nullptr; if (acc_map.has_value()) { - auto hit = acc_map->find(ep->Type()); + // Plugin EPs have a different Type() than the in-tree EP they replace + // (e.g., kCudaPluginExecutionProvider vs kCudaExecutionProvider), but the + // accountant is registered under the in-tree EP name. Translate the key + // so plugin EPs find the correct accountant. + const auto& ep_type = ep->Type(); + const auto accountant_key = ep_type == kCudaPluginExecutionProvider + ? std::string{kCudaExecutionProvider} + : ep_type; + auto hit = acc_map->find(accountant_key); if (hit != acc_map->end()) { resource_accountant = hit->second.get(); } diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 75bd5bcd4e3f2..b02e1e4d3fa54 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -313,40 +313,34 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // so a non-empty EP type here means it was assigned to this EP. const bool previously_assigned = !internal_node.GetExecutionProviderType().empty(); + auto indexed_sub_graph = std::make_unique(); + indexed_sub_graph->nodes.push_back(node_index); + // Host-side budget enforcement for single nodes. if (resource_accountant != nullptr && !previously_assigned) { ResourceCount cost = resource_accountant->ComputeResourceCount(internal_node); size_t cost_bytes = std::get(cost); - - if (has_budget) { - size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; - - LOGS(logger, VERBOSE) << Type() << " node: " << internal_node.Name() - << " (" << internal_node.OpType() << ")" - << " cost: " << cost_bytes - << " would_be_consumed: " << would_be_consumed - << " budget: " << budget_bytes; - - if (would_be_consumed > budget_bytes) { - LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " - << internal_node.Name(); - resource_accountant->SetStopAssignment(); - break; // stop processing further groupings - } - - consumed_bytes = would_be_consumed; + size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; + + LOGS(logger, VERBOSE) << Type() << " node: " << internal_node.Name() + << " (" << internal_node.OpType() << ")" + << " cost: " << cost_bytes + << " would_be_consumed: " << would_be_consumed + << " budget: " << budget_bytes; + + if (has_budget && would_be_consumed > budget_bytes) { + LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " + << internal_node.Name(); + resource_accountant->SetStopAssignment(); + break; // stop processing further groupings } - auto indexed_sub_graph = std::make_unique(); - indexed_sub_graph->nodes.push_back(node_index); + consumed_bytes = would_be_consumed; indexed_sub_graph->SetAccountant(resource_accountant); indexed_sub_graph->AppendNodeCost(cost); - result.push_back(std::make_unique(std::move(indexed_sub_graph))); - } else { - auto indexed_sub_graph = std::make_unique(); - indexed_sub_graph->nodes.push_back(node_index); - result.push_back(std::make_unique(std::move(indexed_sub_graph))); } + + result.push_back(std::make_unique(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { if (node_grouping.nodes.empty()) { LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " @@ -428,10 +422,10 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie resource_accountant->SetStopAssignment(); break; // stop processing further groupings } - - consumed_bytes = would_be_consumed; } + consumed_bytes = SafeInt(consumed_bytes) + group_cost_bytes; + for (const auto& cost : node_costs) { fused_sub_graph->AppendNodeCost(cost); } diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index da64966a0f664..b903bc40e6c84 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -22,6 +22,7 @@ #include #include +#include "core/graph/constants.h" #include "core/graph/model.h" #include "core/session/abi_session_options_impl.h" #include "core/session/inference_session.h" @@ -215,14 +216,18 @@ TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { // node with initializers or known output shapes, so nodes must be offloaded. LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1, [](const Graph& graph) { bool has_cpu_node = false; + bool has_plugin_node = false; for (const auto& node : graph.Nodes()) { if (node.GetExecutionProviderType() == kCpuExecutionProvider) { has_cpu_node = true; - break; + } else if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { + has_plugin_node = true; } } EXPECT_TRUE(has_cpu_node) << "With a 1 KB budget, at least some nodes should be offloaded to CPU"; + EXPECT_TRUE(has_plugin_node) + << "Budget enforcement should be partial, not all-or-nothing CPU fallback"; }); } From a5b1b09df371d30675319a7e4a73eef35f1fd633 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 13 Apr 2026 13:36:46 -0700 Subject: [PATCH 13/14] Address review comments --- .../core/framework/resource_accountant.h | 12 ++ .../core/framework/resource_accountant.cc | 32 ++++ .../ep_plugin_provider_interfaces.cc | 45 +++--- .../plugin/cuda_resource_partitioning_test.cc | 148 +++++++++++++++--- 4 files changed, 194 insertions(+), 43 deletions(-) diff --git a/include/onnxruntime/core/framework/resource_accountant.h b/include/onnxruntime/core/framework/resource_accountant.h index 0e89082b0ec6a..d28c5e45c99a2 100644 --- a/include/onnxruntime/core/framework/resource_accountant.h +++ b/include/onnxruntime/core/framework/resource_accountant.h @@ -26,6 +26,18 @@ struct Node; // for different EPs using ResourceCount = std::variant; +// Type-erased arithmetic for ResourceCount values. +// Implementations use std::visit so the compiler enforces exhaustive handling +// of all variant members — adding a new type to ResourceCount will produce +// build errors at each call site that must be addressed. +// +// NOTE: These functions are NOT available through the provider bridge (shared library EPs). +// Budget enforcement for bridge-based EPs (e.g., in-tree CUDA EP) will be moved to the +// graph partitioner in a follow-up PR. +ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b); +bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b); +std::string FormatResourceCount(const ResourceCount& rc); + /// /// This class is used for graph partitioning by EPs /// It stores the cumulative amount of the resource such as diff --git a/onnxruntime/core/framework/resource_accountant.cc b/onnxruntime/core/framework/resource_accountant.cc index bd450bb4112ef..019bbdd8611be 100644 --- a/onnxruntime/core/framework/resource_accountant.cc +++ b/onnxruntime/core/framework/resource_accountant.cc @@ -317,4 +317,36 @@ std::string IResourceAccountant::MakeUniqueNodeName(const Node& node) { return result; } +ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b) { + return std::visit( + [](auto lhs, auto rhs) -> ResourceCount { + static_assert(std::is_same_v, + "AddResourceCounts requires both operands to hold the same type. " + "Handle the new ResourceCount variant member."); + if constexpr (std::is_integral_v) { + return static_cast(SafeInt(lhs) + rhs); + } else { + return lhs + rhs; + } + }, + a, b); +} + +bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b) { + return std::visit( + [](auto lhs, auto rhs) -> bool { + static_assert(std::is_same_v, + "ResourceCountExceeds requires both operands to hold the same type. " + "Handle the new ResourceCount variant member."); + return lhs > rhs; + }, + a, b); +} + +std::string FormatResourceCount(const ResourceCount& rc) { + return std::visit( + [](auto val) -> std::string { return std::to_string(val); }, + rc); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index b02e1e4d3fa54..f50f87cb4100d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -15,7 +15,6 @@ #include "core/framework/plugin_ep_stream.h" #include "core/framework/resource_accountant.h" #include "core/common/inlined_containers.h" -#include "core/common/safeint.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -275,13 +274,10 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // The host computes costs and enforces the budget uniformly for all node grouping kinds. // Plugin EPs only propose supported nodes; the host decides which to accept. const bool has_budget = resource_accountant != nullptr && resource_accountant->GetThreshold().has_value(); - size_t budget_bytes = std::numeric_limits::max(); - size_t consumed_bytes = 0; - - if (has_budget) { - budget_bytes = std::get(*resource_accountant->GetThreshold()); - consumed_bytes = std::get(resource_accountant->GetConsumedAmount()); - } + ResourceCount consumed = resource_accountant != nullptr + ? resource_accountant->GetConsumedAmount() + : ResourceCount{}; + ResourceCount budget = has_budget ? *resource_accountant->GetThreshold() : ResourceCount{}; // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { @@ -319,23 +315,22 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // Host-side budget enforcement for single nodes. if (resource_accountant != nullptr && !previously_assigned) { ResourceCount cost = resource_accountant->ComputeResourceCount(internal_node); - size_t cost_bytes = std::get(cost); - size_t would_be_consumed = SafeInt(consumed_bytes) + cost_bytes; + ResourceCount would_be_consumed = AddResourceCounts(consumed, cost); LOGS(logger, VERBOSE) << Type() << " node: " << internal_node.Name() << " (" << internal_node.OpType() << ")" - << " cost: " << cost_bytes - << " would_be_consumed: " << would_be_consumed - << " budget: " << budget_bytes; + << " cost: " << FormatResourceCount(cost) + << " would_be_consumed: " << FormatResourceCount(would_be_consumed) + << " budget: " << FormatResourceCount(budget); - if (has_budget && would_be_consumed > budget_bytes) { + if (has_budget && ResourceCountExceeds(would_be_consumed, budget)) { LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " << internal_node.Name(); resource_accountant->SetStopAssignment(); break; // stop processing further groupings } - consumed_bytes = would_be_consumed; + consumed = would_be_consumed; indexed_sub_graph->SetAccountant(resource_accountant); indexed_sub_graph->AppendNodeCost(cost); } @@ -395,7 +390,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie auto* fused_sub_graph = capabilities[0]->sub_graph.get(); fused_sub_graph->SetAccountant(resource_accountant); - size_t group_cost_bytes = 0; + ResourceCount group_cost{}; InlinedVector node_costs; node_costs.reserve(fused_sub_graph->nodes.size()); @@ -405,26 +400,26 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie node != nullptr && !node->GetExecutionProviderType().empty(); ResourceCount cost = (node != nullptr && !node_already_assigned) ? resource_accountant->ComputeResourceCount(*node) - : ResourceCount{size_t{0}}; - group_cost_bytes = SafeInt(group_cost_bytes) + std::get(cost); + : ResourceCount{}; + group_cost = AddResourceCounts(group_cost, cost); node_costs.push_back(cost); } - if (has_budget) { - size_t would_be_consumed = SafeInt(consumed_bytes) + group_cost_bytes; + ResourceCount would_be_consumed = AddResourceCounts(consumed, group_cost); - LOGS(logger, VERBOSE) << Type() << " fused group cost: " << group_cost_bytes - << " would_be_consumed: " << would_be_consumed - << " budget: " << budget_bytes; + if (has_budget) { + LOGS(logger, VERBOSE) << Type() << " fused group cost: " << FormatResourceCount(group_cost) + << " would_be_consumed: " << FormatResourceCount(would_be_consumed) + << " budget: " << FormatResourceCount(budget); - if (would_be_consumed > budget_bytes) { + if (ResourceCountExceeds(would_be_consumed, budget)) { LOGS(logger, WARNING) << Type() << " halting assignment: fused group exceeds budget."; resource_accountant->SetStopAssignment(); break; // stop processing further groupings } } - consumed_bytes = SafeInt(consumed_bytes) + group_cost_bytes; + consumed = would_be_consumed; for (const auto& cost : node_costs) { fused_sub_graph->AppendNodeCost(cost); diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index b903bc40e6c84..dce0c09787c55 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -93,6 +93,75 @@ Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { return Ort::ConstEpDevice{nullptr}; } +// Build a serialized ONNX model with a chain of Add nodes. +// Each node adds its own initializer (of `weight_elements` floats) to the +// previous node's output, producing a linear graph: +// input -> Add(w0) -> Add(w1) -> ... -> Add(wN-1) -> output +// The initializer size directly controls what the ad-hoc resource accountant +// computes per node, giving us precise budget targeting. +std::string BuildAddChainModel(size_t num_nodes, int64_t weight_elements) { + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(ONNX_NAMESPACE::IR_VERSION); + auto* opset = model.add_opset_import(); + opset->set_domain(""); + opset->set_version(13); + + auto* graph = model.mutable_graph(); + graph->set_name("add_chain"); + + // Shared shape for all tensors. + auto set_type_shape = [weight_elements](ONNX_NAMESPACE::TypeProto* tp) { + auto* tensor_type = tp->mutable_tensor_type(); + tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_type->mutable_shape()->add_dim()->set_dim_value(weight_elements); + }; + + // Graph input. + auto* graph_input = graph->add_input(); + graph_input->set_name("input"); + set_type_shape(graph_input->mutable_type()); + + std::string prev_output = "input"; + for (size_t i = 0; i < num_nodes; ++i) { + std::string weight_name = "w_" + std::to_string(i); + std::string output_name = (i + 1 < num_nodes) + ? "t_" + std::to_string(i) + : "output"; + + // Initializer with known byte size = weight_elements * sizeof(float). + auto* init = graph->add_initializer(); + init->set_name(weight_name); + init->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + init->add_dims(weight_elements); + // Use raw_data for compactness — zeros are fine. + init->set_raw_data(std::string(weight_elements * sizeof(float), '\0')); + + // Weight input value_info (needed for valid graph). + auto* w_input = graph->add_input(); + w_input->set_name(weight_name); + set_type_shape(w_input->mutable_type()); + + // Add node. + auto* node = graph->add_node(); + node->set_op_type("Add"); + node->set_name("add_" + std::to_string(i)); + node->add_input(prev_output); + node->add_input(weight_name); + node->add_output(output_name); + + prev_output = output_name; + } + + // Graph output. + auto* graph_output = graph->add_output(); + graph_output->set_name("output"); + set_type_shape(graph_output->mutable_type()); + + std::string serialized; + model.SerializeToString(&serialized); + return serialized; +} + // Get the internal OrtEnv* from the C++ Ort::Env wrapper. // Ort::Env inherits Base which has operator OrtEnv*(). OrtEnv& GetOrtEnv() { @@ -172,6 +241,36 @@ class CudaPluginPartitioningTest : public ::testing::Test { verifier(session.GetGraph()); } + // Overload that loads a model from serialized bytes (e.g., from BuildAddChainModel). + void LoadAndVerifyPartitioning(const std::string& model_bytes, + size_t budget_kb, + const std::function& verifier) { + OrtSessionOptions ort_options; + + const OrtEpDevice* device_ptr = static_cast(cuda_device_); + auto ep_devices_span = gsl::make_span(&device_ptr, 1); + + std::unique_ptr factory; + ASSERT_STATUS_OK(CreateIExecutionProviderFactoryForEpDevices( + GetOrtEnv().GetEnvironment(), ep_devices_span, factory)); + + ort_options.provider_factories.push_back(std::move(factory)); + + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + ASSERT_STATUS_OK(ort_options.value.config_options.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, config_value.c_str())); + } + + InferenceSessionWrapper session(ort_options.value, GetOrtEnv().GetEnvironment()); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + + OrtStatus* status = InitializeSession(&ort_options, session); + ASSERT_STATUS_OK(ToStatusAndRelease(status)); + + verifier(session.GetGraph()); + } + std::unique_ptr registration_; Ort::ConstEpDevice cuda_device_{nullptr}; }; @@ -206,29 +305,42 @@ TEST_F(CudaPluginPartitioningTest, LargeBudget_AllNodesCudaPlugin) { }); } -// With a tiny budget (1 KB), nodes should be offloaded to CPU because -// the resource accountant will run out of budget. +// With a small budget, the resource accountant should assign fewer nodes +// to the plugin EP than the no-budget baseline. TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { - // Use a model with multiple nodes so we can see some go to CPU. - constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/transformers/tiny_gpt2_beamsearch.onnx"); - - // 1 KB budget — ad-hoc accountant will compute non-zero cost for any - // node with initializers or known output shapes, so nodes must be offloaded. - LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1, [](const Graph& graph) { - bool has_cpu_node = false; - bool has_plugin_node = false; + // Build a chain of 6 Add nodes, each with a 256-element float initializer + // (1 KB per weight). The ad-hoc accountant adds weight + output sizes with + // a 1.5x multiplier, so each node costs roughly 1.5 * (1 KB + 1 KB) = 3 KB. + // A 10 KB budget should accept ~3 nodes before halting. + const std::string model = BuildAddChainModel(/*num_nodes=*/6, /*weight_elements=*/256); + + // Baseline: count plugin nodes with no budget. + size_t baseline_plugin_count = 0; + LoadAndVerifyPartitioning(model, /*budget_kb=*/0, [&](const Graph& graph) { for (const auto& node : graph.Nodes()) { - if (node.GetExecutionProviderType() == kCpuExecutionProvider) { - has_cpu_node = true; - } else if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { - has_plugin_node = true; + if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { + ++baseline_plugin_count; } } - EXPECT_TRUE(has_cpu_node) - << "With a 1 KB budget, at least some nodes should be offloaded to CPU"; - EXPECT_TRUE(has_plugin_node) - << "Budget enforcement should be partial, not all-or-nothing CPU fallback"; }); + ASSERT_GT(baseline_plugin_count, size_t{1}) + << "Baseline must have multiple plugin nodes for the test to be meaningful"; + + // Now run with a 10 KB budget — should accept some but not all nodes. + size_t constrained_plugin_count = 0; + LoadAndVerifyPartitioning(model, /*budget_kb=*/10, [&](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { + ++constrained_plugin_count; + } + } + }); + + EXPECT_GT(constrained_plugin_count, size_t{0}) + << "Budget should be large enough to accept at least one node"; + EXPECT_LT(constrained_plugin_count, baseline_plugin_count) + << "A 10 KB budget should reduce plugin EP node count from the no-budget baseline (" + << baseline_plugin_count << " nodes)"; } // --------------------------------------------------------------------------- From 25bf8015eea7d7dd99b60c5e818b2dbb3ab566c1 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 13 Apr 2026 13:57:59 -0700 Subject: [PATCH 14/14] Address review comments --- onnxruntime/test/framework/resource_accountant_test.cc | 5 ++++- .../cuda/plugin/cuda_resource_partitioning_test.cc | 10 +++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index fb9032b1c42ff..fe5ec0b039200 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -31,8 +31,11 @@ namespace test { namespace { // Helper to extract size_t from ResourceCount variant. +// Uses std::get_if so test failures produce a clear assertion rather than std::bad_variant_access. size_t GetSizeT(const ResourceCount& rc) { - return std::get(rc); + const auto* value = std::get_if(&rc); + EXPECT_NE(value, nullptr) << "ResourceCount does not hold size_t"; + return value != nullptr ? *value : 0; } // Helper to create a real SizeBasedStatsAccountant in ad-hoc mode (no stats file) via factory. diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc index dce0c09787c55..23bef914eaf15 100644 --- a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -86,7 +86,7 @@ class ScopedCudaPluginRegistration { Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { auto ep_devices = env.GetEpDevices(); for (const auto& device : ep_devices) { - if (strcmp(device.EpName(), "CudaPluginExecutionProvider") == 0) { + if (strcmp(device.EpName(), kCudaPluginExecutionProvider) == 0) { return device; } } @@ -197,7 +197,9 @@ class CudaPluginPartitioningTest : public ::testing::Test { void TearDown() override { registration_.reset(); - cudaDeviceSynchronize(); + if (cuda_device_) { + cudaDeviceSynchronize(); + } } // Load a model through the CUDA plugin EP with the given resource budget, @@ -370,7 +372,9 @@ class CudaResourcePartitioningTest : public ::testing::Test { void TearDown() override { registration_.reset(); - cudaDeviceSynchronize(); + if (cuda_device_) { + cudaDeviceSynchronize(); + } } Ort::Session CreateSessionWithBudget(const ORTCHAR_T* model_path,