Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions include/onnxruntime/core/framework/resource_accountant.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@ struct Node;
// for different EPs
using ResourceCount = std::variant<size_t>;

// Type-erased arithmetic for ResourceCount values.
// Implementations use std::visit so the compiler enforces exhaustive handling
// of all variant members — adding a new type to ResourceCount will produce
// build errors at the std::visit sites that must be updated to handle it.
//
// NOTE: These functions are NOT available through the provider bridge (shared library EPs).
// Budget enforcement for bridge-based EPs (e.g., in-tree CUDA EP) will be moved to the
// graph partitioner in a follow-up PR.
// Returns the sum of the values held by `a` and `b`. Both operands must hold the
// same variant alternative (checked at compile time in the implementation);
// integral alternatives are added through SafeInt.
ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b);
// Returns true when the value held by `a` is strictly greater than the value held
// by `b`. Both operands must hold the same variant alternative.
bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b);
// Returns the value held by `rc` rendered with std::to_string.
std::string FormatResourceCount(const ResourceCount& rc);
Comment thread
yuslepukhin marked this conversation as resolved.

/// <summary>
/// This class is used for graph partitioning by EPs
/// It stores the cumulative amount of the resource such as
Expand Down Expand Up @@ -61,9 +73,14 @@ class IResourceAccountant {

bool IsStopIssued() const noexcept { return stop_assignment_; }

// Called before each GetCapability pass to discard pending weight tracking
// from a previous (discarded) pass. Default no-op for stats-based accountants.
virtual void ResetPendingWeights() {}
// Called before each GetCapability pass to reset per-pass state:
// clears the stop flag (which only applies to the pass that set it)
// and discards pending weight tracking from a previous (discarded) pass.
// Subclasses override ResetPendingWeightsImpl for EP-specific cleanup.
void ResetForNewPass() {
// Clear the stop flag first so a stop issued during an earlier pass does not
// carry over into the pass that is about to run.
stop_assignment_ = false;
// Subclass hook for discarding pending weight tracking; the base
// implementation does nothing.
ResetPendingWeightsImpl();
}

// Called when a node's cost is committed (AccountForNode/AccountForAllNodes).
// Moves the node's pending weights into the committed set so they persist
Expand All @@ -72,6 +89,11 @@ class IResourceAccountant {

static std::string MakeUniqueNodeName(const Node& node);

protected:
// Override to discard EP-specific pending weight tracking.
// Default no-op for stats-based accountants.
// Invoked by ResetForNewPass() after the stop flag has been cleared.
virtual void ResetPendingWeightsImpl() {}

private:
bool stop_assignment_ = false;
std::optional<ResourceCount> threshold_;
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30;

constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";
constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider";
constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3729,5 +3729,6 @@ using UnownedSharedPrePackedWeightCache =

///< Wraps OrtEpApi::GetEnvConfigEntries()
Ort::KeyValuePairs GetEnvConfigEntries();

} // namespace Ort
#include "onnxruntime_cxx_inline.h"
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -4174,4 +4174,5 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c
ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema));
return OpSchema{schema};
}

} // namespace Ort
14 changes: 11 additions & 3 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer));

if (params.resource_accountant) {
params.resource_accountant->ResetPendingWeights();
params.resource_accountant->ResetForNewPass();
}
capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant,
graph_optimizer_registry);
Expand Down Expand Up @@ -348,7 +348,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer));

if (params.resource_accountant) {
params.resource_accountant->ResetPendingWeights();
params.resource_accountant->ResetForNewPass();
}
capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant,
graph_optimizer_registry);
Expand Down Expand Up @@ -1189,7 +1189,15 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
for (const auto& ep : execution_providers) {
IResourceAccountant* resource_accountant = nullptr;
if (acc_map.has_value()) {
auto hit = acc_map->find(ep->Type());
// Plugin EPs have a different Type() than the in-tree EP they replace
// (e.g., kCudaPluginExecutionProvider vs kCudaExecutionProvider), but the
// accountant is registered under the in-tree EP name. Translate the key
// so plugin EPs find the correct accountant.
const auto& ep_type = ep->Type();
const auto accountant_key = ep_type == kCudaPluginExecutionProvider
? std::string{kCudaExecutionProvider}
: ep_type;
auto hit = acc_map->find(accountant_key);
if (hit != acc_map->end()) {
resource_accountant = hit->second.get();
}
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/framework/layering_annotations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ bool MatchEpDevice(const EpDeviceView& ep,
if (target_specifier.empty()) {
if (ep.device_type == OrtDevice::GPU) return true;
// Heuristic fallback for common GPU EPs if hardware info is missing
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider;
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider ||
ep.ep_name == kDmlExecutionProvider;
}
// "gpu:<vendor>" or "gpu:<index>"
if (ep.device_type == OrtDevice::GPU) {
Expand All @@ -203,7 +204,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
ep.vendor_id == OrtDevice::VendorIds::INTEL) return true;
// Heuristic: gpu:nvidia -> CUDA
if (CaseInsensitiveCompare(target_specifier, "nvidia") &&
ep.ep_name == kCudaExecutionProvider) return true;
(ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider)) return true;
}
return false;
}
Expand All @@ -225,7 +226,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
}
// "cuda"
if (CaseInsensitiveCompare(target_type_str, "cuda")) {
return ep.ep_name == kCudaExecutionProvider;
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider;
}
// "dml"
if (CaseInsensitiveCompare(target_type_str, "dml")) {
Expand Down Expand Up @@ -284,7 +285,13 @@ std::optional<std::string> EpLayeringMatcher::Match(gsl::span<const OrtEpDevice*
ep_device.ep_name,
device_type,
has_hw ? ep_device.device->vendor_id : 0u,
has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id) : OrtDevice::DeviceId{},
// Prefer the device ordinal from device_memory_info (set by the EP factory to
// a runtime device ordinal such as a CUDA ordinal) over the OrtHardwareDevice::device_id
// which is a hardware-type identifier and not guaranteed to be a stable runtime ordinal.
ep_device.device_memory_info
? ep_device.device_memory_info->device.Id()
: (has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id)
: OrtDevice::DeviceId{}),
has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}};

if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) {
Expand Down
34 changes: 33 additions & 1 deletion onnxruntime/core/framework/resource_accountant.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@
}
}

void ResetPendingWeights() override {
void ResetPendingWeightsImpl() override {
pending_weights_.clear();
pending_weights_by_node_.clear();
}
Expand Down Expand Up @@ -317,4 +317,36 @@
return result;
}

// Adds two resource counts and returns the result as a new ResourceCount.
// The visitor is instantiated for every pair of variant alternatives, so the
// static_assert rejects mixed-type pairs at compile time.
ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b) {
  const auto sum_visitor = [](auto left, auto right) -> ResourceCount {
    using LeftT = decltype(left);
    using RightT = decltype(right);
    static_assert(std::is_same_v<LeftT, RightT>,
                  "AddResourceCounts requires both operands to hold the same type. "
                  "Handle the new ResourceCount variant member.");
    if constexpr (!std::is_integral_v<LeftT>) {
      // Non-integral alternatives use plain addition.
      return left + right;
    } else {
      // Integral alternatives are added through SafeInt.
      return static_cast<LeftT>(SafeInt<LeftT>(left) + right);
    }
  };
  return std::visit(sum_visitor, a, b);
}

bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b) {
return std::visit(
[](auto lhs, auto rhs) -> bool {
static_assert(std::is_same_v<decltype(lhs), decltype(rhs)>,
"ResourceCountExceeds requires both operands to hold the same type. "
"Handle the new ResourceCount variant member.");
return lhs > rhs;
},
a, b);
}

std::string FormatResourceCount(const ResourceCount& rc) {
return std::visit(
[](auto val) -> std::string { return std::to_string(val); },

Check warning on line 348 in onnxruntime/core/framework/resource_accountant.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/framework/resource_accountant.cc:348: Add #include <string> for string [build/include_what_you_use] [4]
rc);
}

} // namespace onnxruntime
14 changes: 10 additions & 4 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@
#include "ep/get_capability_utils.h"

#include <cstring>
#include <limits>

Check warning on line 14 in onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: cuda_ep.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/cuda/plugin/cuda_ep.cc:14: Found C++ system header after other header. Should be: cuda_ep.h, c system, c++ system, other. [build/include_order] [4]
#include <stdexcept>
#include <string>
#include <string_view>
#include <unordered_map>
#include <unordered_set>

#include "core/graph/constants.h"

namespace onnxruntime {
namespace cuda_plugin {

Expand Down Expand Up @@ -227,12 +230,15 @@
cpu_preferred_nodes));

// Phase 3: Add final supported nodes (tentative minus CPU-preferred).
// Resource budget enforcement is handled by the host after GetCapability returns.

for (const OrtNode* ort_node : candidate_nodes) {
if (cpu_preferred_nodes.count(ort_node) == 0) {
Ort::ConstNode node{ort_node};
RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
graph_support_info, node));
if (cpu_preferred_nodes.count(ort_node) != 0) {
continue;
}

RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
graph_support_info, ort_node));
}

return nullptr;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ class OrtStreamAdapter {
#include "core/providers/common.h"

namespace onnxruntime {
inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";

// Forward declaration of GetEnvironmentVar for plugin builds on Windows.
// Defined in provider_api_shims.cc; mirrors the provider_api.h declaration
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/abi_ep_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
namespace onnxruntime {
struct EpGraph;
struct EpNode;
class IResourceAccountant;
} // namespace onnxruntime

/// <summary>
Expand Down Expand Up @@ -50,4 +51,8 @@ struct OrtEpGraphSupportInfo {
const onnxruntime::EpGraph& ort_graph;
std::vector<NodeGrouping> node_groupings;
const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup;

// Optional resource accountant for capacity-aware partitioning.
// Owned by the graph partitioner; lifetime exceeds this struct.
onnxruntime::IResourceAccountant* resource_accountant = nullptr;
};
1 change: 1 addition & 0 deletions onnxruntime/core/session/plugin_ep/ep_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,4 +178,5 @@ ORT_API_STATUS_IMPL(ProfilingEvent_GetDurationUs, _In_ const OrtProfilingEvent*
_Out_ int64_t* out);
ORT_API_STATUS_IMPL(ProfilingEvent_GetArgValue, _In_ const OrtProfilingEvent* event, _In_ const char* key,
_Outptr_result_maybenull_ const char** out);

} // namespace OrtExecutionProviderApi
Loading
Loading