diff --git a/include/onnxruntime/core/framework/resource_accountant.h b/include/onnxruntime/core/framework/resource_accountant.h index 7bb5a993d140b..d28c5e45c99a2 100644 --- a/include/onnxruntime/core/framework/resource_accountant.h +++ b/include/onnxruntime/core/framework/resource_accountant.h @@ -26,6 +26,18 @@ struct Node; // for different EPs using ResourceCount = std::variant; +// Type-erased arithmetic for ResourceCount values. +// Implementations use std::visit so the compiler enforces exhaustive handling +// of all variant members — adding a new type to ResourceCount will produce +// build errors at each call site that must be addressed. +// +// NOTE: These functions are NOT available through the provider bridge (shared library EPs). +// Budget enforcement for bridge-based EPs (e.g., in-tree CUDA EP) will be moved to the +// graph partitioner in a follow-up PR. +ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b); +bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b); +std::string FormatResourceCount(const ResourceCount& rc); + /// /// This class is used for graph partitioning by EPs /// It stores the cumulative amount of the resource such as @@ -61,9 +73,14 @@ class IResourceAccountant { bool IsStopIssued() const noexcept { return stop_assignment_; } - // Called before each GetCapability pass to discard pending weight tracking - // from a previous (discarded) pass. Default no-op for stats-based accountants. - virtual void ResetPendingWeights() {} + // Called before each GetCapability pass to reset per-pass state: + // clears the stop flag (which only applies to the pass that set it) + // and discards pending weight tracking from a previous (discarded) pass. + // Subclasses override ResetPendingWeightsImpl for EP-specific cleanup. + void ResetForNewPass() { + stop_assignment_ = false; + ResetPendingWeightsImpl(); + } // Called when a node's cost is committed (AccountForNode/AccountForAllNodes). // Moves the node's pending weights into the committed set so they persist @@ -72,6 +89,11 @@ class IResourceAccountant { static std::string MakeUniqueNodeName(const Node& node); + protected: + // Override to discard EP-specific pending weight tracking. + // Default no-op for stats-based accountants. + virtual void ResetPendingWeightsImpl() {} + private: bool stop_assignment_ = false; std::optional threshold_; diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 4ed16ffd27264..2f2462dfa92a9 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 493f399e83a09..7f0b16e2fdee7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -3729,5 +3729,6 @@ using UnownedSharedPrePackedWeightCache = ///< Wraps OrtEpApi::GetEnvConfigEntries() Ort::KeyValuePairs GetEnvConfigEntries(); + } // namespace Ort #include "onnxruntime_cxx_inline.h" diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 93f7273a6b4b5..cb145c2b6c10a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -4174,4 +4174,5 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema)); return OpSchema{schema}; } + } // namespace Ort diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index cc65142318d02..2bebab9862e7c 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -277,7 +277,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); if (params.resource_accountant) { - params.resource_accountant->ResetPendingWeights(); + params.resource_accountant->ResetForNewPass(); } capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); @@ -348,7 +348,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer)); if (params.resource_accountant) { - params.resource_accountant->ResetPendingWeights(); + params.resource_accountant->ResetForNewPass(); } capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant, graph_optimizer_registry); @@ -1189,7 +1189,15 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, for (const auto& ep : execution_providers) { IResourceAccountant* resource_accountant = nullptr; if (acc_map.has_value()) { - auto hit = acc_map->find(ep->Type()); + // Plugin EPs have a different Type() than the in-tree EP they replace + // (e.g., kCudaPluginExecutionProvider vs kCudaExecutionProvider), but the + // accountant is registered under the in-tree EP name. Translate the key + // so plugin EPs find the correct accountant. + const auto& ep_type = ep->Type(); + const auto accountant_key = ep_type == kCudaPluginExecutionProvider + ? std::string{kCudaExecutionProvider} + : ep_type; + auto hit = acc_map->find(accountant_key); if (hit != acc_map->end()) { resource_accountant = hit->second.get(); } diff --git a/onnxruntime/core/framework/layering_annotations.cc b/onnxruntime/core/framework/layering_annotations.cc index 91df102abef17..207f39ea64939 100644 --- a/onnxruntime/core/framework/layering_annotations.cc +++ b/onnxruntime/core/framework/layering_annotations.cc @@ -183,7 +183,8 @@ bool MatchEpDevice(const EpDeviceView& ep, if (target_specifier.empty()) { if (ep.device_type == OrtDevice::GPU) return true; // Heuristic fallback for common GPU EPs if hardware info is missing - return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider; + return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider || + ep.ep_name == kDmlExecutionProvider; } // "gpu:" or "gpu:" if (ep.device_type == OrtDevice::GPU) { @@ -203,7 +204,7 @@ bool MatchEpDevice(const EpDeviceView& ep, ep.vendor_id == OrtDevice::VendorIds::INTEL) return true; // Heuristic: gpu:nvidia -> CUDA if (CaseInsensitiveCompare(target_specifier, "nvidia") && - ep.ep_name == kCudaExecutionProvider) return true; + (ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider)) return true; } return false; } @@ -225,7 +226,7 @@ bool MatchEpDevice(const EpDeviceView& ep, } // "cuda" if (CaseInsensitiveCompare(target_type_str, "cuda")) { - return ep.ep_name == kCudaExecutionProvider; + return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider; } // "dml" if (CaseInsensitiveCompare(target_type_str, "dml")) { @@ -284,7 +285,13 @@ std::optional EpLayeringMatcher::Match(gsl::spanvendor_id : 0u, - has_hw ? static_cast(ep_device.device->device_id) : OrtDevice::DeviceId{}, + // Prefer the device ordinal from device_memory_info (set by the EP factory to + // a runtime device ordinal such as a CUDA ordinal) over the OrtHardwareDevice::device_id + // which is a hardware-type identifier and not guaranteed to be a stable runtime ordinal. + ep_device.device_memory_info + ? ep_device.device_memory_info->device.Id() + : (has_hw ? static_cast(ep_device.device->device_id) + : OrtDevice::DeviceId{}), has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}}; if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) { diff --git a/onnxruntime/core/framework/resource_accountant.cc b/onnxruntime/core/framework/resource_accountant.cc index 68610ebb4be17..019bbdd8611be 100644 --- a/onnxruntime/core/framework/resource_accountant.cc +++ b/onnxruntime/core/framework/resource_accountant.cc @@ -120,7 +120,7 @@ class SizeBasedStatsAccountant : public IResourceAccountant { } } - void ResetPendingWeights() override { + void ResetPendingWeightsImpl() override { pending_weights_.clear(); pending_weights_by_node_.clear(); } @@ -317,4 +317,36 @@ std::string IResourceAccountant::MakeUniqueNodeName(const Node& node) { return result; } +ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b) { + return std::visit( + [](auto lhs, auto rhs) -> ResourceCount { + static_assert(std::is_same_v, + "AddResourceCounts requires both operands to hold the same type. " + "Handle the new ResourceCount variant member."); + if constexpr (std::is_integral_v) { + return static_cast(SafeInt(lhs) + rhs); + } else { + return lhs + rhs; + } + }, + a, b); +} + +bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b) { + return std::visit( + [](auto lhs, auto rhs) -> bool { + static_assert(std::is_same_v, + "ResourceCountExceeds requires both operands to hold the same type. " + "Handle the new ResourceCount variant member."); + return lhs > rhs; + }, + a, b); +} + +std::string FormatResourceCount(const ResourceCount& rc) { + return std::visit( + [](auto val) -> std::string { return std::to_string(val); }, + rc); +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc index f589249a883c3..54b8c9cb4a216 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc +++ b/onnxruntime/core/providers/cuda/plugin/cuda_ep.cc @@ -11,12 +11,15 @@ #include "ep/get_capability_utils.h" #include +#include #include #include #include #include #include +#include "core/graph/constants.h" + namespace onnxruntime { namespace cuda_plugin { @@ -227,12 +230,15 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl( cpu_preferred_nodes)); // Phase 3: Add final supported nodes (tentative minus CPU-preferred). + // Resource budget enforcement is handled by the host after GetCapability returns. + for (const OrtNode* ort_node : candidate_nodes) { - if (cpu_preferred_nodes.count(ort_node) == 0) { - Ort::ConstNode node{ort_node}; - RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( - graph_support_info, node)); + if (cpu_preferred_nodes.count(ort_node) != 0) { + continue; } + + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode( + graph_support_info, ort_node)); } return nullptr; diff --git a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h index 735e3eb660c62..b1da0aa816a03 100644 --- a/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h +++ b/onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h @@ -99,7 +99,6 @@ class OrtStreamAdapter { #include "core/providers/common.h" namespace onnxruntime { -inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider"; // Forward declaration of GetEnvironmentVar for plugin builds on Windows. // Defined in provider_api_shims.cc; mirrors the provider_api.h declaration diff --git a/onnxruntime/core/session/abi_ep_types.h b/onnxruntime/core/session/abi_ep_types.h index deaadf7c67e6e..8e0691b985dc1 100644 --- a/onnxruntime/core/session/abi_ep_types.h +++ b/onnxruntime/core/session/abi_ep_types.h @@ -16,6 +16,7 @@ namespace onnxruntime { struct EpGraph; struct EpNode; +class IResourceAccountant; } // namespace onnxruntime /// @@ -50,4 +51,8 @@ struct OrtEpGraphSupportInfo { const onnxruntime::EpGraph& ort_graph; std::vector node_groupings; const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup; + + // Optional resource accountant for capacity-aware partitioning. + // Owned by the graph partitioner; lifetime exceeds this struct. + onnxruntime::IResourceAccountant* resource_accountant = nullptr; }; diff --git a/onnxruntime/core/session/plugin_ep/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h index 4608318f388ee..e32e267a75ba5 100644 --- a/onnxruntime/core/session/plugin_ep/ep_api.h +++ b/onnxruntime/core/session/plugin_ep/ep_api.h @@ -178,4 +178,5 @@ ORT_API_STATUS_IMPL(ProfilingEvent_GetDurationUs, _In_ const OrtProfilingEvent* _Out_ int64_t* out); ORT_API_STATUS_IMPL(ProfilingEvent_GetArgValue, _In_ const OrtProfilingEvent* event, _In_ const char* key, _Outptr_result_maybenull_ const char** out); + } // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index d32967f8b37e3..f50f87cb4100d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -13,6 +13,8 @@ #include "core/framework/error_code_helper.h" #include "core/framework/plugin_data_transfer.h" #include "core/framework/plugin_ep_stream.h" +#include "core/framework/resource_accountant.h" +#include "core/common/inlined_containers.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -236,10 +238,17 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const GraphOptimizerRegistry& graph_optimizer_registry, IResourceAccountant* resource_accountant) const { ORT_UNUSED_PARAMETER(graph_optimizer_registry); // TODO: Add support - ORT_UNUSED_PARAMETER(resource_accountant); // TODO: Add support? Not used by prioritized EPs const logging::Logger& logger = GetEpLoggerOrDefault(); + // Early exit if a previous GetCapability pass already signaled stop (e.g., budget exhausted). + // The framework calls GetCapability multiple times (e.g., after layout transformation), + // and each EP is responsible for checking the stop flag. + if (resource_accountant != nullptr && resource_accountant->IsStopIssued()) { + LOGS(logger, WARNING) << Type() << " returning due to stop already set"; + return {}; + } + std::unique_ptr ep_graph = nullptr; if (Status status = EpGraph::Create(graph_viewer, ep_graph, true); !status.IsOK()) { LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); @@ -247,6 +256,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie } OrtEpGraphSupportInfo api_graph_support_info(*ep_graph, kernel_lookup); + api_graph_support_info.resource_accountant = resource_accountant; Status status = ToStatusAndRelease(ort_ep_->GetCapability(ort_ep_.get(), ep_graph->ToExternal(), &api_graph_support_info)); if (!status.IsOK()) { @@ -260,6 +270,15 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } + // Host-side resource budget enforcement. + // The host computes costs and enforces the budget uniformly for all node grouping kinds. + // Plugin EPs only propose supported nodes; the host decides which to accept. + const bool has_budget = resource_accountant != nullptr && resource_accountant->GetThreshold().has_value(); + ResourceCount consumed = resource_accountant != nullptr + ? resource_accountant->GetConsumedAmount() + : ResourceCount{}; + ResourceCount budget = has_budget ? *resource_accountant->GetThreshold() : ResourceCount{}; + // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { // Skip this node grouping if any node has already been assigned to another EP. @@ -275,21 +294,50 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kSingleAssignedNode) { if (node_grouping.nodes.size() != 1) { - // The EpGraphSupportInfo_AddSingleNode() C API should already return an error if the EP tries to provide - // an invalid node. However, we check here too just in case this changes. LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " did not specify exactly one valid node " << "when calling EpGraphSupportInfo_AddSingleNode()."; return {}; } + const Node& internal_node = node_grouping.nodes[0]->GetInternalNode(); + const NodeIndex node_index = internal_node.Index(); + + // Node already assigned from a previous pass (e.g., before layout transformation + // or after function inlining). Its cost was already committed — skip re-computation to avoid + // double-charging the output-size component. + // FindFirstNodeAssignedToOtherEP already filtered out nodes assigned to a different EP, + // so a non-empty EP type here means it was assigned to this EP. + const bool previously_assigned = !internal_node.GetExecutionProviderType().empty(); + auto indexed_sub_graph = std::make_unique(); + indexed_sub_graph->nodes.push_back(node_index); + + // Host-side budget enforcement for single nodes. + if (resource_accountant != nullptr && !previously_assigned) { + ResourceCount cost = resource_accountant->ComputeResourceCount(internal_node); + ResourceCount would_be_consumed = AddResourceCounts(consumed, cost); + + LOGS(logger, VERBOSE) << Type() << " node: " << internal_node.Name() + << " (" << internal_node.OpType() << ")" + << " cost: " << FormatResourceCount(cost) + << " would_be_consumed: " << FormatResourceCount(would_be_consumed) + << " budget: " << FormatResourceCount(budget); + + if (has_budget && ResourceCountExceeds(would_be_consumed, budget)) { + LOGS(logger, WARNING) << Type() << " halting assignment due to budget at node: " + << internal_node.Name(); + resource_accountant->SetStopAssignment(); + break; // stop processing further groupings + } + + consumed = would_be_consumed; + indexed_sub_graph->SetAccountant(resource_accountant); + indexed_sub_graph->AppendNodeCost(cost); + } - indexed_sub_graph->nodes.push_back(node_grouping.nodes[0]->GetInternalNode().Index()); result.push_back(std::make_unique(std::move(indexed_sub_graph))); } else if (node_grouping.kind == OrtEpGraphSupportInfo::NodeGroupingKind::kFusedNode) { if (node_grouping.nodes.empty()) { - // The EpGraphSupportInfo_AddNodesToFuse() C API should already return an error if the EP tries to provide - // an empty array of nodes from OrtEp::GetCapability(). However, we check here too just in case this changes. LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() << " set an empty array of nodes " << "when specifying supported nodes."; return {}; @@ -322,12 +370,11 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. We expect this to always - // be true because we've already checked that the EP did not try to claim nodes already assigned to another EP. + // Log an error if the nodes in node_set do not match the nodes in capabilities[0]. // TODO(adrianlizarraga): This check can be removed when we stop using utils::CreateSupportedPartitions() above. std::vector& capability_node_indices = capabilities[0]->sub_graph->nodes; - std::unordered_set capability_node_indices_set(capability_node_indices.begin(), - capability_node_indices.end()); + InlinedHashSet capability_node_indices_set(capability_node_indices.begin(), + capability_node_indices.end()); if (node_set.size() != capability_node_indices_set.size()) { LOGS(logger, ERROR) << "OrtEp::GetCapability() for " << Type() @@ -335,6 +382,50 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } + // Host-side budget enforcement for fused capabilities. + // Compute per-component-node costs from the accountant and check total against budget. + // Skip cost computation for nodes already assigned to this EP from a previous pass + // to avoid double-charging the output-size component. + if (resource_accountant != nullptr) { + auto* fused_sub_graph = capabilities[0]->sub_graph.get(); + fused_sub_graph->SetAccountant(resource_accountant); + + ResourceCount group_cost{}; + InlinedVector node_costs; + node_costs.reserve(fused_sub_graph->nodes.size()); + + for (NodeIndex idx : fused_sub_graph->nodes) { + const Node* node = graph_viewer.GetNode(idx); + const bool node_already_assigned = + node != nullptr && !node->GetExecutionProviderType().empty(); + ResourceCount cost = (node != nullptr && !node_already_assigned) + ? resource_accountant->ComputeResourceCount(*node) + : ResourceCount{}; + group_cost = AddResourceCounts(group_cost, cost); + node_costs.push_back(cost); + } + + ResourceCount would_be_consumed = AddResourceCounts(consumed, group_cost); + + if (has_budget) { + LOGS(logger, VERBOSE) << Type() << " fused group cost: " << FormatResourceCount(group_cost) + << " would_be_consumed: " << FormatResourceCount(would_be_consumed) + << " budget: " << FormatResourceCount(budget); + + if (ResourceCountExceeds(would_be_consumed, budget)) { + LOGS(logger, WARNING) << Type() << " halting assignment: fused group exceeds budget."; + resource_accountant->SetStopAssignment(); + break; // stop processing further groupings + } + } + + consumed = would_be_consumed; + + for (const auto& cost : node_costs) { + fused_sub_graph->AppendNodeCost(cost); + } + } + result.push_back(std::move(capabilities[0])); } else { LOGS(logger, ERROR) << "PluginExecutionProvider::GetCapability() has invalid NodeGroupingKind: " diff --git a/onnxruntime/test/framework/resource_accountant_test.cc b/onnxruntime/test/framework/resource_accountant_test.cc index a102fe4e7770b..fe5ec0b039200 100644 --- a/onnxruntime/test/framework/resource_accountant_test.cc +++ b/onnxruntime/test/framework/resource_accountant_test.cc @@ -2,103 +2,61 @@ // Licensed under the MIT License. #include "core/framework/resource_accountant.h" +#include "core/framework/config_options.h" #include "core/graph/indexed_sub_graph.h" #include "core/graph/constants.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" #include "test/util/include/asserts.h" #include "test/util/include/test_environment.h" -namespace onnxruntime { -namespace test { - -// Test accountant mimicking SizeBasedStatsAccountant ad-hoc path: -// Uses pending/committed weight sets so that: -// - Within a GetCapability pass, shared weights are deduped -// - Across passes, only committed weights persist and pending are discarded -class TestDedupAccountant : public IResourceAccountant { - public: - TestDedupAccountant() = default; - - ResourceCount GetConsumedAmount() const override { - return consumed_; - } - - void AddConsumedAmount(const ResourceCount& amount) noexcept override { - if (std::holds_alternative(amount)) { - consumed_ += std::get(amount); - } - } - - void RemoveConsumedAmount(const ResourceCount& amount) noexcept override { - if (std::holds_alternative(amount)) { - consumed_ -= std::get(amount); - } - } +#include +#include +#include - ResourceCount ComputeResourceCount(const Node& node) override { - const auto* graph = node.GetContainingGraph(); - if (graph == nullptr) { - return static_cast(0); - } - - size_t total = 0; - for (const auto* input_def : node.InputDefs()) { - if (!input_def->Exists()) { - continue; - } - const auto& name = input_def->Name(); - constexpr bool check_outer_scope = true; - const auto* init = graph->GetInitializer(name, check_outer_scope); - if (init != nullptr) { - if (committed_weights_.count(name) > 0) { - continue; - } - if (pending_weights_.count(name) > 0) { - continue; - } - auto it = weight_sizes_.find(name); - if (it != weight_sizes_.end()) { - total += it->second; - } - pending_weights_.insert(name); - pending_weights_by_node_[node.Index()].insert(name); - } - } - return total; - } +#ifdef _WIN32 +#include +#define ORT_TEST_PID _getpid() +#else +#include +#define ORT_TEST_PID getpid() +#endif - void ResetPendingWeights() override { - pending_weights_.clear(); - pending_weights_by_node_.clear(); - } +namespace onnxruntime { +namespace test { - void CommitWeightsForNode(NodeIndex node_index) override { - auto it = pending_weights_by_node_.find(node_index); - if (it != pending_weights_by_node_.end()) { - for (const auto& name : it->second) { - pending_weights_.erase(name); - } - committed_weights_.insert(it->second.begin(), it->second.end()); - pending_weights_by_node_.erase(it); - } - } +namespace { - void RegisterWeight(const std::string& name, size_t size) { - weight_sizes_[name] = size; - } +// Helper to extract size_t from ResourceCount variant. +// Uses std::get_if so test failures produce a clear assertion rather than std::bad_variant_access. +size_t GetSizeT(const ResourceCount& rc) { + const auto* value = std::get_if(&rc); + EXPECT_NE(value, nullptr) << "ResourceCount does not hold size_t"; + return value != nullptr ? *value : 0; +} - size_t GetConsumedSizeT() const { return consumed_; } +// Helper to create a real SizeBasedStatsAccountant in ad-hoc mode (no stats file) via factory. +// Must be called from within a TEST body (or via ASSERT_NO_FATAL_FAILURE) because it uses ASSERT_*. +void CreateAdHocAccountant( + size_t limit_kb, + const std::filesystem::path& model_path, + std::optional& acc_map, + IResourceAccountant*& out) { + ConfigOptions config; + std::string setting = std::to_string(limit_kb) + ","; + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + ASSERT_STATUS_OK(CreateAccountants(config, model_path, acc_map)); + ASSERT_TRUE(acc_map.has_value()); + auto it = acc_map->find(kCudaExecutionProvider); + ASSERT_NE(it, acc_map->end()); + out = it->second.get(); +} - private: - size_t consumed_ = 0; - InlinedHashSet committed_weights_; - InlinedHashSet pending_weights_; - InlinedHashMap> pending_weights_by_node_; - InlinedHashMap weight_sizes_; -}; +} // namespace // Two Add nodes that share a single initializer weight_W. struct SharedWeightGraph { @@ -107,8 +65,7 @@ struct SharedWeightGraph { Node* node_a = nullptr; Node* node_b = nullptr; - static SharedWeightGraph Create() { - SharedWeightGraph h; + static void Create(SharedWeightGraph& h) { std::unordered_map dom; dom[kOnnxDomain] = 12; h.model = std::make_unique( @@ -141,68 +98,95 @@ struct SharedWeightGraph { h.node_a = &h.graph->AddNode("node_A", "Add", "A", {ia, wa}, {oa}); h.node_b = &h.graph->AddNode("node_B", "Add", "B", {ib, wa}, {ob}); - auto status = h.graph->Resolve(); - ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); - return h; + ASSERT_STATUS_OK(h.graph->Resolve()); } }; -// Regression: AccountForAllNodes sums pre-stored per-node costs +// Ad-hoc path expected costs for SharedWeightGraph: +// weight_W = 250 floats = 1000 bytes, each output = 250 floats = 1000 bytes +// node_A: (1000 init + 1000 out) * 1.5 = 3000 +// node_B: (0 deduped + 1000 out) * 1.5 = 1500 + +// AccountForAllNodes sums pre-stored per-node costs // that already have correct within-pass weight deduplication. TEST(ResourceAccountantTest, AccountForAllNodes_CorrectlyUsesPreStoredCosts) { - auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); + std::optional acc_map; + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(h.node_a->Index()); sub_graph.nodes.push_back(h.node_b->Index()); - sub_graph.SetAccountant(&accountant); + sub_graph.SetAccountant(accountant); - auto cost_a = accountant.ComputeResourceCount(*h.node_a); + auto cost_a = accountant->ComputeResourceCount(*h.node_a); sub_graph.AppendNodeCost(cost_a); - EXPECT_EQ(std::get(cost_a), size_t{1000}); + EXPECT_EQ(GetSizeT(cost_a), size_t{3000}); - auto cost_b = accountant.ComputeResourceCount(*h.node_b); + auto cost_b = accountant->ComputeResourceCount(*h.node_b); sub_graph.AppendNodeCost(cost_b); - EXPECT_EQ(std::get(cost_b), size_t{0}); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}); ASSERT_TRUE(sub_graph.IsAccountingEnabled()); sub_graph.AccountForAllNodes(); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "AccountForAllNodes should sum pre-stored costs (1000 + 0)"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{4500}) + << "AccountForAllNodes should sum pre-stored costs (3000 + 1500)"; } -// Verifies that ResetPendingWeights + re-probe produces correct results. -// After probing (which only writes to pending), resetting pending and +// Verifies that ResetForNewPass + re-probe produces correct results. +// After probing (which only writes to pending), resetting for a new pass and // re-probing should see the full weight cost again since nothing was committed. TEST(ResourceAccountantTest, ComputeAndAccountForNode_CorrectAfterReset) { - auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); + std::optional acc_map; + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); // Probing pass populates pending weights - auto cost_a = accountant.ComputeResourceCount(*h.node_a); - EXPECT_EQ(std::get(cost_a), size_t{1000}); - auto cost_b = accountant.ComputeResourceCount(*h.node_b); - EXPECT_EQ(std::get(cost_b), size_t{0}); + auto cost_a = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(GetSizeT(cost_a), size_t{3000}); + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}); // Discard the pass (simulating capabilities.clear() before second GetCapability) - accountant.ResetPendingWeights(); + accountant->ResetForNewPass(); // Re-probe: weight_W was never committed, so it should be counted again IndexedSubGraph sub_graph; sub_graph.nodes.push_back(h.node_a->Index()); - sub_graph.SetAccountant(&accountant); - auto recomputed_cost = accountant.ComputeResourceCount(*h.node_a); + sub_graph.SetAccountant(accountant); + auto recomputed_cost = accountant->ComputeResourceCount(*h.node_a); sub_graph.AccountForNode(h.node_a->Index(), recomputed_cost); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "After ResetPendingWeights, re-probe should see full weight cost"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}) + << "After ResetForNewPass, re-probe should see full weight cost"; +} + +// ResetForNewPass clears the stop flag so a second GetCapability pass +// (e.g., after layout transformation) can run from scratch. +TEST(ResourceAccountantTest, ResetForNewPass_ClearsStopFlag) { + std::optional acc_map; + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/1, PathString(), acc_map, accountant)); + + // Simulate first pass exhausting budget and setting stop. + accountant->SetStopAssignment(); + EXPECT_TRUE(accountant->IsStopIssued()); + + // Framework discards first-pass results and resets for second pass. + accountant->ResetForNewPass(); + EXPECT_FALSE(accountant->IsStopIssued()) + << "ResetForNewPass should clear the stop flag for the next pass"; } // Each node has a unique initializer. AccountForAllNodes sums both. +// weight_1 = 100 floats = 400 bytes, weight_2 = 100 floats = 400 bytes, outputs = 400 bytes each +// node1: (400 init + 400 out) * 1.5 = 1200 +// node2: (400 init + 400 out) * 1.5 = 1200 TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { std::unordered_map dom; dom[kOnnxDomain] = 12; @@ -239,88 +223,227 @@ TEST(ResourceAccountantTest, AccountForAllNodes_NoSharedWeights) { auto& node2 = graph.AddNode("n2", "Add", "", {out1, w2}, {out2}); ASSERT_STATUS_OK(graph.Resolve()); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_1", 400); - accountant.RegisterWeight("weight_2", 600); + std::optional acc_map; + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); IndexedSubGraph sub_graph; sub_graph.nodes.push_back(node1.Index()); sub_graph.nodes.push_back(node2.Index()); - sub_graph.SetAccountant(&accountant); + sub_graph.SetAccountant(accountant); - sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node1)); - sub_graph.AppendNodeCost(accountant.ComputeResourceCount(node2)); + sub_graph.AppendNodeCost(accountant->ComputeResourceCount(node1)); + sub_graph.AppendNodeCost(accountant->ComputeResourceCount(node2)); ASSERT_TRUE(sub_graph.IsAccountingEnabled()); sub_graph.AccountForAllNodes(); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "No shared weights: should sum all costs (400 + 600)"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{2400}) + << "No shared weights: should sum all costs (1200 + 1200)"; } // AccountForNode per-node and AccountForAllNodes bulk produce same result. TEST(ResourceAccountantTest, AccountForNode_MatchesAccountForAllNodes) { - auto h = SharedWeightGraph::Create(); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); // Per-node path - TestDedupAccountant acc1; - acc1.RegisterWeight("weight_W", 1000); + std::optional acc_map1; + IResourceAccountant* acc1 = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map1, acc1)); IndexedSubGraph sub1; sub1.nodes.push_back(h.node_a->Index()); sub1.nodes.push_back(h.node_b->Index()); - sub1.SetAccountant(&acc1); - sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_a)); - sub1.AppendNodeCost(acc1.ComputeResourceCount(*h.node_b)); + sub1.SetAccountant(acc1); + sub1.AppendNodeCost(acc1->ComputeResourceCount(*h.node_a)); + sub1.AppendNodeCost(acc1->ComputeResourceCount(*h.node_b)); sub1.AccountForNode(0); sub1.AccountForNode(1); - size_t per_node = acc1.GetConsumedSizeT(); + size_t per_node = GetSizeT(acc1->GetConsumedAmount()); // Bulk path - TestDedupAccountant acc2; - acc2.RegisterWeight("weight_W", 1000); + std::optional acc_map2; + IResourceAccountant* acc2 = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map2, acc2)); IndexedSubGraph sub2; sub2.nodes.push_back(h.node_a->Index()); sub2.nodes.push_back(h.node_b->Index()); - sub2.SetAccountant(&acc2); - sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_a)); - sub2.AppendNodeCost(acc2.ComputeResourceCount(*h.node_b)); + sub2.SetAccountant(acc2); + sub2.AppendNodeCost(acc2->ComputeResourceCount(*h.node_a)); + sub2.AppendNodeCost(acc2->ComputeResourceCount(*h.node_b)); sub2.AccountForAllNodes(); - size_t bulk = acc2.GetConsumedSizeT(); + size_t bulk = GetSizeT(acc2->GetConsumedAmount()); EXPECT_EQ(per_node, bulk) << "Per-node and bulk should produce identical results"; - EXPECT_EQ(per_node, size_t{1000}); + EXPECT_EQ(per_node, size_t{4500}); } // Cross-subgraph dedup: EP1 commits node_A, EP2 probes node_B and // correctly sees weight_W as already accounted. +// node_A cost: 3000, node_B cost after commit: (0 + 1000) * 1.5 = 1500 TEST(ResourceAccountantTest, CrossSubGraph_DedupWorks) { - auto h = SharedWeightGraph::Create(); - TestDedupAccountant accountant; - accountant.RegisterWeight("weight_W", 1000); + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); + std::optional acc_map; + IResourceAccountant* accountant = nullptr; + ASSERT_NO_FATAL_FAILURE(CreateAdHocAccountant(/*limit_kb=*/100, PathString(), acc_map, accountant)); // EP1 probes and commits node_A IndexedSubGraph sub1; sub1.nodes.push_back(h.node_a->Index()); - sub1.SetAccountant(&accountant); - sub1.AppendNodeCost(accountant.ComputeResourceCount(*h.node_a)); + sub1.SetAccountant(accountant); + sub1.AppendNodeCost(accountant->ComputeResourceCount(*h.node_a)); sub1.AccountForNode(0); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}); + accountant->CommitWeightsForNode(h.node_a->Index()); + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{3000}); + + // Reset for new pass to simulate new GetCapability pass + accountant->ResetForNewPass(); - // EP2 probes node_B: weight_W already committed - auto cost_b = accountant.ComputeResourceCount(*h.node_b); - EXPECT_EQ(std::get(cost_b), size_t{0}) - << "weight_W was committed by EP1, should be deduped for EP2"; + // EP2 probes node_B: weight_W already committed, only output counted + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(GetSizeT(cost_b), size_t{1500}) + << "weight_W was committed by EP1, only output (1000 * 1.5) counted"; - // EP2 commits node_B with cost 0 + // EP2 commits node_B IndexedSubGraph sub2; sub2.nodes.push_back(h.node_b->Index()); - sub2.SetAccountant(&accountant); + sub2.SetAccountant(accountant); sub2.AppendNodeCost(cost_b); sub2.AccountForNode(0); - EXPECT_EQ(accountant.GetConsumedSizeT(), size_t{1000}) - << "Total should still be 1000 - weight_W counted once across both"; + EXPECT_EQ(GetSizeT(accountant->GetConsumedAmount()), size_t{4500}) + << "Total should be 3000 + 1500 - weight_W initializer counted once"; +} + +// --------------------------------------------------------------------------- +// Stats-based path and factory tests +// --------------------------------------------------------------------------- + +// RAII helper to remove a temp file on scope exit, even if a test assertion fails. +struct ScopedFileRemover { + std::filesystem::path path; + ~ScopedFileRemover() { + std::error_code ec; + std::filesystem::remove(path, ec); + } +}; + +// Stats-based path: cost is sum of all NodeAllocationStats fields. +TEST(RealAccountantTest, StatsPath_ComputesCostFromStatsFile) { + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); + + // Write a stats file with known costs (unique per PID to avoid parallel collisions) + std::error_code ec; + auto stats_dir = std::filesystem::temp_directory_path(ec); + ASSERT_FALSE(ec) << ec.message(); + std::ostringstream fname; + fname << "test_resource_accountant_stats_" << ORT_TEST_PID << ".csv"; + auto stats_path = stats_dir / fname.str(); + ScopedFileRemover stats_cleanup{stats_path}; + + // Get the unique node names the accountant will look up + std::string name_a = IResourceAccountant::MakeUniqueNodeName(*h.node_a); + std::string name_b = IResourceAccountant::MakeUniqueNodeName(*h.node_b); + + { + std::ofstream ofs(stats_path); + ASSERT_TRUE(ofs.is_open()); + ofs << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n"; + // input_sizes=100, initializers=200, dynamic=300, temp=400 -> total=1000 + ofs << name_a << ",100,200,300,400\n"; + // input_sizes=50, initializers=0, dynamic=150, temp=0 -> total=200 + ofs << name_b << ",50,0,150,0\n"; + } + + // Factory expects stats file relative to model_path dir + ConfigOptions config; + std::string setting = "500," + stats_path.filename().string(); + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, stats_dir / "dummy_model.onnx", acc_map)); + ASSERT_TRUE(acc_map.has_value()); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + + auto cost_a = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(std::get(cost_a), size_t{1000}); + + auto cost_b = accountant->ComputeResourceCount(*h.node_b); + EXPECT_EQ(std::get(cost_b), size_t{200}); + + // Threshold should be 500 KB = 512000 bytes + auto threshold = accountant->GetThreshold(); + ASSERT_TRUE(threshold.has_value()); + EXPECT_EQ(std::get(*threshold), size_t{500 * 1024}); +} + +// Stats-based path returns 0 for unknown nodes. +TEST(RealAccountantTest, StatsPath_UnknownNodeReturnsZero) { + SharedWeightGraph h; + ASSERT_NO_FATAL_FAILURE(SharedWeightGraph::Create(h)); + + std::error_code ec; + auto stats_dir = std::filesystem::temp_directory_path(ec); + ASSERT_FALSE(ec) << ec.message(); + std::ostringstream fname; + fname << "test_resource_accountant_empty_stats_" << ORT_TEST_PID << ".csv"; + auto stats_path = stats_dir / fname.str(); + ScopedFileRemover stats_cleanup{stats_path}; + + { + std::ofstream ofs(stats_path); + ASSERT_TRUE(ofs.is_open()); + ofs << "#name,input_sizes,initializers_sizes,total_dynamic_sizes,total_temp_allocations\n"; + // No entries for our nodes + } + + ConfigOptions config; + std::string setting = "1000," + stats_path.filename().string(); + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, setting.c_str())); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, stats_dir / "dummy_model.onnx", acc_map)); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + + auto cost = accountant->ComputeResourceCount(*h.node_a); + EXPECT_EQ(std::get(cost), size_t{0}); +} + +// Factory with no limit and no stats file creates accountant with no threshold. +TEST(RealAccountantTest, Factory_NoLimitNoStats) { + ConfigOptions config; + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, ",")); + + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, PathString(), acc_map)); + ASSERT_TRUE(acc_map.has_value()); + auto* accountant = acc_map->at(kCudaExecutionProvider).get(); + EXPECT_FALSE(accountant->GetThreshold().has_value()); +} + +// Factory returns empty optional when no config is set. +TEST(RealAccountantTest, Factory_NoConfigReturnsEmpty) { + ConfigOptions config; + std::optional acc_map; + ASSERT_STATUS_OK(CreateAccountants(config, PathString(), acc_map)); + EXPECT_FALSE(acc_map.has_value()); +} + +// Factory rejects malformed config (missing comma). +TEST(RealAccountantTest, Factory_MalformedConfigReturnsError) { + ConfigOptions config; + ASSERT_STATUS_OK(config.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, "1000")); // missing comma + + std::optional acc_map; + auto status = CreateAccountants(config, PathString(), acc_map); + EXPECT_FALSE(status.IsOK()); } } // namespace test diff --git a/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc new file mode 100644 index 0000000000000..23bef914eaf15 --- /dev/null +++ b/onnxruntime/test/providers/cuda/plugin/cuda_resource_partitioning_test.cc @@ -0,0 +1,454 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Integration tests for resource-constrained partitioning through the CUDA plugin EP. +// +// Two test levels: +// 1. Partitioning verification tests — use InferenceSessionWrapper to inspect +// per-node EP assignments after partitioning through the plugin EP. +// 2. E2E session tests — validate output correctness under budget constraints. + +#if defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "core/graph/constants.h" +#include "core/graph/model.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/inference_session.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/ort_env.h" +#include "core/session/utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/file_util.h" +#include "test/util/include/inference_session_wrapper.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { +namespace { + +constexpr const char* kResourcePartitioningRegistrationName = "CudaPluginResourceTest"; + +// Resolve the CUDA plugin EP shared library path. +std::filesystem::path GetCudaPluginLibraryPath() { + return GetSharedLibraryFileName(ORT_TSTR("onnxruntime_providers_cuda_plugin")); +} + +// RAII handle that registers/unregisters the CUDA plugin EP library. +class ScopedCudaPluginRegistration { + public: + ScopedCudaPluginRegistration(Ort::Env& env, const char* registration_name) + : env_(env), name_(registration_name) { + auto lib_path = GetCudaPluginLibraryPath(); + if (!std::filesystem::exists(lib_path)) { + available_ = false; + return; + } + env_.RegisterExecutionProviderLibrary(name_.c_str(), lib_path.c_str()); + available_ = true; + } + + ~ScopedCudaPluginRegistration() { + if (available_) { + try { + env_.UnregisterExecutionProviderLibrary(name_.c_str()); + } catch (...) { + } + } + } + + bool IsAvailable() const { return available_; } + + ScopedCudaPluginRegistration(const ScopedCudaPluginRegistration&) = delete; + ScopedCudaPluginRegistration& operator=(const ScopedCudaPluginRegistration&) = delete; + + private: + Ort::Env& env_; + std::string name_; + bool available_ = false; +}; + +// Find the CUDA plugin EP device after registration. +Ort::ConstEpDevice FindCudaPluginDevice(Ort::Env& env) { + auto ep_devices = env.GetEpDevices(); + for (const auto& device : ep_devices) { + if (strcmp(device.EpName(), kCudaPluginExecutionProvider) == 0) { + return device; + } + } + return Ort::ConstEpDevice{nullptr}; +} + +// Build a serialized ONNX model with a chain of Add nodes. +// Each node adds its own initializer (of `weight_elements` floats) to the +// previous node's output, producing a linear graph: +// input -> Add(w0) -> Add(w1) -> ... -> Add(wN-1) -> output +// The initializer size directly controls what the ad-hoc resource accountant +// computes per node, giving us precise budget targeting. +std::string BuildAddChainModel(size_t num_nodes, int64_t weight_elements) { + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(ONNX_NAMESPACE::IR_VERSION); + auto* opset = model.add_opset_import(); + opset->set_domain(""); + opset->set_version(13); + + auto* graph = model.mutable_graph(); + graph->set_name("add_chain"); + + // Shared shape for all tensors. + auto set_type_shape = [weight_elements](ONNX_NAMESPACE::TypeProto* tp) { + auto* tensor_type = tp->mutable_tensor_type(); + tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + tensor_type->mutable_shape()->add_dim()->set_dim_value(weight_elements); + }; + + // Graph input. + auto* graph_input = graph->add_input(); + graph_input->set_name("input"); + set_type_shape(graph_input->mutable_type()); + + std::string prev_output = "input"; + for (size_t i = 0; i < num_nodes; ++i) { + std::string weight_name = "w_" + std::to_string(i); + std::string output_name = (i + 1 < num_nodes) + ? "t_" + std::to_string(i) + : "output"; + + // Initializer with known byte size = weight_elements * sizeof(float). + auto* init = graph->add_initializer(); + init->set_name(weight_name); + init->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + init->add_dims(weight_elements); + // Use raw_data for compactness — zeros are fine. + init->set_raw_data(std::string(weight_elements * sizeof(float), '\0')); + + // Weight input value_info (needed for valid graph). + auto* w_input = graph->add_input(); + w_input->set_name(weight_name); + set_type_shape(w_input->mutable_type()); + + // Add node. + auto* node = graph->add_node(); + node->set_op_type("Add"); + node->set_name("add_" + std::to_string(i)); + node->add_input(prev_output); + node->add_input(weight_name); + node->add_output(output_name); + + prev_output = output_name; + } + + // Graph output. + auto* graph_output = graph->add_output(); + graph_output->set_name("output"); + set_type_shape(graph_output->mutable_type()); + + std::string serialized; + model.SerializeToString(&serialized); + return serialized; +} + +// Get the internal OrtEnv* from the C++ Ort::Env wrapper. +// Ort::Env inherits Base which has operator OrtEnv*(). +OrtEnv& GetOrtEnv() { + return *static_cast(*ort_env); +} + +} // namespace + +// --------------------------------------------------------------------------- +// Lower-level partitioning tests that verify per-node EP assignments +// --------------------------------------------------------------------------- + +class CudaPluginPartitioningTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kResourcePartitioningRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + if (cuda_device_) { + cudaDeviceSynchronize(); + } + } + + // Load a model through the CUDA plugin EP with the given resource budget, + // then call the verifier to inspect graph node assignments. + // + // Uses InferenceSessionWrapper + OrtSessionOptions + InitializeSession + // so that the plugin factory creates the EP and partitioning runs normally, + // but we can access the graph via wrapper.GetGraph(). + void LoadAndVerifyPartitioning(const ORTCHAR_T* model_path, + size_t budget_kb, + const std::function& verifier) { + OrtSessionOptions ort_options; + + // Create the plugin EP factory from the registered device. + const OrtEpDevice* device_ptr = static_cast(cuda_device_); + auto ep_devices_span = gsl::make_span(&device_ptr, 1); + + std::unique_ptr factory; + ASSERT_STATUS_OK(CreateIExecutionProviderFactoryForEpDevices( + GetOrtEnv().GetEnvironment(), ep_devices_span, factory)); + + ort_options.provider_factories.push_back(std::move(factory)); + + // Set resource partitioning budget if requested. + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + ASSERT_STATUS_OK(ort_options.value.config_options.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, config_value.c_str())); + } + + // Create the session wrapper — gives us access to the graph after partitioning. + InferenceSessionWrapper session(ort_options.value, GetOrtEnv().GetEnvironment()); + ASSERT_STATUS_OK(session.Load(model_path)); + + // InitializeSession iterates provider_factories, creates the plugin EP, + // registers it with the session, and calls session.Initialize() which + // runs graph partitioning (invoking plugin GetCapability). + OrtStatus* status = InitializeSession(&ort_options, session); + ASSERT_STATUS_OK(ToStatusAndRelease(status)); + + verifier(session.GetGraph()); + } + + // Overload that loads a model from serialized bytes (e.g., from BuildAddChainModel). + void LoadAndVerifyPartitioning(const std::string& model_bytes, + size_t budget_kb, + const std::function& verifier) { + OrtSessionOptions ort_options; + + const OrtEpDevice* device_ptr = static_cast(cuda_device_); + auto ep_devices_span = gsl::make_span(&device_ptr, 1); + + std::unique_ptr factory; + ASSERT_STATUS_OK(CreateIExecutionProviderFactoryForEpDevices( + GetOrtEnv().GetEnvironment(), ep_devices_span, factory)); + + ort_options.provider_factories.push_back(std::move(factory)); + + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + ASSERT_STATUS_OK(ort_options.value.config_options.AddConfigEntry( + kOrtSessionOptionsResourceCudaPartitioningSettings, config_value.c_str())); + } + + InferenceSessionWrapper session(ort_options.value, GetOrtEnv().GetEnvironment()); + ASSERT_STATUS_OK(session.Load(model_bytes.data(), static_cast(model_bytes.size()))); + + OrtStatus* status = InitializeSession(&ort_options, session); + ASSERT_STATUS_OK(ToStatusAndRelease(status)); + + verifier(session.GetGraph()); + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +// With no resource budget, all CUDA-supported nodes should be assigned to the plugin EP. +TEST_F(CudaPluginPartitioningTest, NoBudget_AllNodesCudaPlugin) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/mul_1.onnx"); + + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/0, [](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + // With no budget constraint, all nodes that the CUDA plugin supports + // should be assigned to it. The plugin EP type name may vary, so just + // verify it's NOT assigned to CPU. + EXPECT_NE(node.GetExecutionProviderType(), kCpuExecutionProvider) + << "Node " << node.Name() << " (" << node.OpType() + << ") unexpectedly assigned to CPU with no budget constraint"; + } + }); +} + +// With a very large budget, all nodes should still be on the plugin EP (same as no budget). +TEST_F(CudaPluginPartitioningTest, LargeBudget_AllNodesCudaPlugin) { + constexpr const ORTCHAR_T* model_path = ORT_TSTR("testdata/mul_1.onnx"); + + // 1 GB — effectively unlimited and safe across 32-bit/64-bit builds + LoadAndVerifyPartitioning(model_path, /*budget_kb=*/1024 * 1024, [](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + EXPECT_NE(node.GetExecutionProviderType(), kCpuExecutionProvider) + << "Node " << node.Name() << " (" << node.OpType() + << ") unexpectedly assigned to CPU with large budget"; + } + }); +} + +// With a small budget, the resource accountant should assign fewer nodes +// to the plugin EP than the no-budget baseline. +TEST_F(CudaPluginPartitioningTest, TinyBudget_NodesOffloadedToCpu) { + // Build a chain of 6 Add nodes, each with a 256-element float initializer + // (1 KB per weight). The ad-hoc accountant adds weight + output sizes with + // a 1.5x multiplier, so each node costs roughly 1.5 * (1 KB + 1 KB) = 3 KB. + // A 10 KB budget should accept ~3 nodes before halting. + const std::string model = BuildAddChainModel(/*num_nodes=*/6, /*weight_elements=*/256); + + // Baseline: count plugin nodes with no budget. + size_t baseline_plugin_count = 0; + LoadAndVerifyPartitioning(model, /*budget_kb=*/0, [&](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { + ++baseline_plugin_count; + } + } + }); + ASSERT_GT(baseline_plugin_count, size_t{1}) + << "Baseline must have multiple plugin nodes for the test to be meaningful"; + + // Now run with a 10 KB budget — should accept some but not all nodes. + size_t constrained_plugin_count = 0; + LoadAndVerifyPartitioning(model, /*budget_kb=*/10, [&](const Graph& graph) { + for (const auto& node : graph.Nodes()) { + if (node.GetExecutionProviderType() == kCudaPluginExecutionProvider) { + ++constrained_plugin_count; + } + } + }); + + EXPECT_GT(constrained_plugin_count, size_t{0}) + << "Budget should be large enough to accept at least one node"; + EXPECT_LT(constrained_plugin_count, baseline_plugin_count) + << "A 10 KB budget should reduce plugin EP node count from the no-budget baseline (" + << baseline_plugin_count << " nodes)"; +} + +// --------------------------------------------------------------------------- +// E2E tests (existing high-level session tests, kept for coverage) +// --------------------------------------------------------------------------- + +class CudaResourcePartitioningTest : public ::testing::Test { + protected: + void SetUp() override { + int device_count = 0; + cudaError_t err = cudaGetDeviceCount(&device_count); + if (err != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "No CUDA device available."; + } + + registration_ = std::make_unique( + *ort_env, kResourcePartitioningRegistrationName); + if (!registration_->IsAvailable()) { + GTEST_SKIP() << "CUDA plugin EP library not found."; + } + + cuda_device_ = FindCudaPluginDevice(*ort_env); + if (!cuda_device_) { + GTEST_SKIP() << "No CUDA plugin EP device found after registration."; + } + } + + void TearDown() override { + registration_.reset(); + if (cuda_device_) { + cudaDeviceSynchronize(); + } + } + + Ort::Session CreateSessionWithBudget(const ORTCHAR_T* model_path, + size_t budget_kb) { + Ort::SessionOptions so; + so.AppendExecutionProvider_V2(*ort_env, {cuda_device_}, + std::unordered_map{}); + + if (budget_kb > 0) { + std::string config_value = std::to_string(budget_kb) + ","; + so.AddConfigEntry(kOrtSessionOptionsResourceCudaPartitioningSettings, + config_value.c_str()); + } + + return Ort::Session(*ort_env, model_path, so); + } + + std::unique_ptr registration_; + Ort::ConstEpDevice cuda_device_{nullptr}; +}; + +TEST_F(CudaResourcePartitioningTest, NoBudget_SessionCreatesSuccessfully) { + auto model_path = ORT_TSTR("testdata/mul_1.onnx"); + ASSERT_NO_THROW(CreateSessionWithBudget(model_path, 0)); +} + +TEST_F(CudaResourcePartitioningTest, BudgetConstrained_ProducesValidOutput) { + auto model_path = ORT_TSTR("testdata/mul_1.onnx"); + Ort::Session session = CreateSessionWithBudget(model_path, 100); + + auto input_name = session.GetInputNameAllocated(0, Ort::AllocatorWithDefaultOptions()); + auto output_name = session.GetOutputNameAllocated(0, Ort::AllocatorWithDefaultOptions()); + + auto type_info = session.GetInputTypeInfo(0); + auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + for (auto& dim : shape) { + if (dim < 0) dim = 1; + } + + size_t num_elements = 1; + for (auto dim : shape) { + num_elements *= static_cast(dim); + } + + std::vector input_data(num_elements, 2.0f); + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + Ort::Value input_tensor = Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), + shape.data(), shape.size()); + + const char* input_names[] = {input_name.get()}; + const char* output_names[] = {output_name.get()}; + + auto outputs = session.Run(Ort::RunOptions{nullptr}, + input_names, &input_tensor, 1, + output_names, 1); + + ASSERT_EQ(outputs.size(), size_t{1}); + ASSERT_TRUE(outputs[0].IsTensor()); + + auto* output_data = outputs[0].GetTensorData(); + auto output_shape = outputs[0].GetTensorTypeAndShapeInfo().GetShape(); + size_t output_count = 1; + for (auto dim : output_shape) { + output_count *= static_cast(dim); + } + for (size_t i = 0; i < output_count; ++i) { + EXPECT_FALSE(std::isnan(output_data[i])) << "NaN at index " << i; + EXPECT_FALSE(std::isinf(output_data[i])) << "Inf at index " << i; + } +} + +} // namespace test +} // namespace onnxruntime + +#endif // defined(ORT_UNIT_TEST_HAS_CUDA_PLUGIN_EP)