Commit f7113bd

[CUDA EP Plugin] ResourceAccountant integration (#28028)
This pull request introduces several enhancements and refactorings to the resource accounting and execution provider (EP) infrastructure, with a focus on better support for plugin-based CUDA execution providers. The most significant changes include the addition of type-erased arithmetic for resource accounting, improved handling of resource budgets for plugin EPs, and more robust device matching logic. These updates increase maintainability, enforce stricter type safety, and ensure correct resource tracking across both in-tree and plugin-based EPs.

**Resource accounting improvements:**

* Added type-erased arithmetic functions (`AddResourceCounts`, `ResourceCountExceeds`, `FormatResourceCount`) for `ResourceCount` to enforce exhaustive handling of variant types and improve type safety.
* Refactored the `IResourceAccountant` interface: replaced `ResetPendingWeights` with `ResetForNewPass`, which resets both the stop flag and pending weights, and introduced a protected `ResetPendingWeightsImpl` for subclass-specific cleanup.

**Plugin CUDA EP and resource budget enforcement:**

* Added the `kCudaPluginExecutionProvider` constant and updated logic to ensure plugin EPs correctly map to their in-tree accountant counterparts and are included in device matching and partitioning.
* Updated plugin EP infrastructure to pass and utilize resource accountant pointers, enabling host-side resource budget enforcement for plugin EPs and ensuring correct node assignment.

**Device matching and partitioning:**

* Improved device matching heuristics to consider both in-tree and plugin CUDA EPs, and updated logic to prefer runtime device ordinals for more reliable device selection.

Other minor changes include code style cleanups and additional includes for completeness.
1 parent ffbc5e8 commit f7113bd

14 files changed

Lines changed: 925 additions & 174 deletions

include/onnxruntime/core/framework/resource_accountant.h

Lines changed: 25 additions & 3 deletions
@@ -26,6 +26,18 @@ struct Node;
 // for different EPs
 using ResourceCount = std::variant<size_t>;

+// Type-erased arithmetic for ResourceCount values.
+// Implementations use std::visit so the compiler enforces exhaustive handling
+// of all variant members — adding a new type to ResourceCount will produce
+// build errors at each call site that must be addressed.
+//
+// NOTE: These functions are NOT available through the provider bridge (shared library EPs).
+// Budget enforcement for bridge-based EPs (e.g., in-tree CUDA EP) will be moved to the
+// graph partitioner in a follow-up PR.
+ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b);
+bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b);
+std::string FormatResourceCount(const ResourceCount& rc);
+
 /// <summary>
 /// This class is used for graph partitioning by EPs
 /// It stores the cumulative amount of the resource such as
@@ -61,9 +73,14 @@ class IResourceAccountant {

   bool IsStopIssued() const noexcept { return stop_assignment_; }

-  // Called before each GetCapability pass to discard pending weight tracking
-  // from a previous (discarded) pass. Default no-op for stats-based accountants.
-  virtual void ResetPendingWeights() {}
+  // Called before each GetCapability pass to reset per-pass state:
+  // clears the stop flag (which only applies to the pass that set it)
+  // and discards pending weight tracking from a previous (discarded) pass.
+  // Subclasses override ResetPendingWeightsImpl for EP-specific cleanup.
+  void ResetForNewPass() {
+    stop_assignment_ = false;
+    ResetPendingWeightsImpl();
+  }

   // Called when a node's cost is committed (AccountForNode/AccountForAllNodes).
   // Moves the node's pending weights into the committed set so they persist
@@ -72,6 +89,11 @@ class IResourceAccountant {

   static std::string MakeUniqueNodeName(const Node& node);

+ protected:
+  // Override to discard EP-specific pending weight tracking.
+  // Default no-op for stats-based accountants.
+  virtual void ResetPendingWeightsImpl() {}
+
  private:
   bool stop_assignment_ = false;
   std::optional<ResourceCount> threshold_;
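
The change above follows the non-virtual interface idiom: the public `ResetForNewPass` is not virtual, so every subclass gets the stop-flag reset, and only the protected hook is overridable. A minimal sketch of that shape follows. `AccountantSketch`, `PendingTrackingAccountant`, `IssueStop`, `AddPending` and `PendingCount` are hypothetical names used for illustration and are not part of the commit.

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for IResourceAccountant; only the reset logic is modeled.
class AccountantSketch {
 public:
  virtual ~AccountantSketch() = default;

  void IssueStop() noexcept { stop_assignment_ = true; }
  bool IsStopIssued() const noexcept { return stop_assignment_; }

  // Non-virtual entry point: the base class always clears the stop flag,
  // then delegates subclass-specific cleanup to the protected hook.
  void ResetForNewPass() {
    stop_assignment_ = false;
    ResetPendingWeightsImpl();
  }

 protected:
  // Default no-op, mirroring the header in this commit.
  virtual void ResetPendingWeightsImpl() {}

 private:
  bool stop_assignment_ = false;
};

// Hypothetical subclass that tracks pending weight names during a pass.
class PendingTrackingAccountant : public AccountantSketch {
 public:
  void AddPending(std::string name) { pending_.push_back(std::move(name)); }
  std::size_t PendingCount() const noexcept { return pending_.size(); }

 protected:
  void ResetPendingWeightsImpl() override { pending_.clear(); }

 private:
  std::vector<std::string> pending_;
};
```

With this layout a subclass cannot forget to clear the stop flag, which was possible when `ResetPendingWeights` itself was the virtual entry point.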

include/onnxruntime/core/graph/constants.h

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30;

 constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
 constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
+constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";
 constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider";
 constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
 constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 1 addition & 0 deletions
@@ -3729,5 +3729,6 @@ using UnownedSharedPrePackedWeightCache =

 ///< Wraps OrtEpApi::GetEnvConfigEntries()
 Ort::KeyValuePairs GetEnvConfigEntries();
+
 } // namespace Ort
 #include "onnxruntime_cxx_inline.h"

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 1 addition & 0 deletions
@@ -4174,4 +4174,5 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c
   ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema));
   return OpSchema{schema};
 }
+
 } // namespace Ort

onnxruntime/core/framework/graph_partitioner.cc

Lines changed: 11 additions & 3 deletions
@@ -277,7 +277,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
   ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer));

   if (params.resource_accountant) {
-    params.resource_accountant->ResetPendingWeights();
+    params.resource_accountant->ResetForNewPass();
   }
   capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant,
                                   graph_optimizer_registry);
@@ -348,7 +348,7 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params, const l
   ORT_RETURN_IF_ERROR(create_graph_viewer(sub_graph_holder, graph_viewer));

   if (params.resource_accountant) {
-    params.resource_accountant->ResetPendingWeights();
+    params.resource_accountant->ResetForNewPass();
   }
   capabilities = get_capabilities(current_ep, *graph_viewer, kernel_lookup, params.resource_accountant,
                                   graph_optimizer_registry);
@@ -1189,7 +1189,15 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params,
   for (const auto& ep : execution_providers) {
     IResourceAccountant* resource_accountant = nullptr;
     if (acc_map.has_value()) {
-      auto hit = acc_map->find(ep->Type());
+      // Plugin EPs have a different Type() than the in-tree EP they replace
+      // (e.g., kCudaPluginExecutionProvider vs kCudaExecutionProvider), but the
+      // accountant is registered under the in-tree EP name. Translate the key
+      // so plugin EPs find the correct accountant.
+      const auto& ep_type = ep->Type();
+      const auto accountant_key = ep_type == kCudaPluginExecutionProvider
+                                      ? std::string{kCudaExecutionProvider}
+                                      : ep_type;
+      auto hit = acc_map->find(accountant_key);
       if (hit != acc_map->end()) {
         resource_accountant = hit->second.get();
       }
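
The key translation in the last hunk can be read in isolation as a small lookup helper. The sketch below assumes a map of integer budgets in place of the real accountant map. `AccountantKeyFor` and `FindBudget` are hypothetical helpers, not functions from the commit; the two string constants copy the values shown in `constants.h`.

```cpp
#include <string>
#include <string_view>
#include <unordered_map>

// These two literals match the constants shown in constants.h in this commit.
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";

// Hypothetical helper: map an EP type to the key its accountant is registered under.
// Only the CUDA plugin EP is translated; every other type is used as-is.
inline std::string AccountantKeyFor(std::string_view ep_type) {
  if (ep_type == kCudaPluginExecutionProvider) {
    return std::string{kCudaExecutionProvider};
  }
  return std::string{ep_type};
}

// Hypothetical lookup over budgets keyed by in-tree EP name.
// Returns nullptr when nothing is registered for the EP.
inline const int* FindBudget(const std::unordered_map<std::string, int>& budgets,
                             std::string_view ep_type) {
  auto hit = budgets.find(AccountantKeyFor(ep_type));
  return hit != budgets.end() ? &hit->second : nullptr;
}
```

Under these assumptions, a lookup for `"CudaPluginExecutionProvider"` resolves to the entry registered under `"CUDAExecutionProvider"`, while an unregistered EP type yields `nullptr` and runs without an accountant.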

onnxruntime/core/framework/layering_annotations.cc

Lines changed: 11 additions & 4 deletions
@@ -183,7 +183,8 @@ bool MatchEpDevice(const EpDeviceView& ep,
     if (target_specifier.empty()) {
       if (ep.device_type == OrtDevice::GPU) return true;
       // Heuristic fallback for common GPU EPs if hardware info is missing
-      return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider;
+      return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider ||
+             ep.ep_name == kDmlExecutionProvider;
     }
     // "gpu:<vendor>" or "gpu:<index>"
     if (ep.device_type == OrtDevice::GPU) {
@@ -203,7 +204,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
           ep.vendor_id == OrtDevice::VendorIds::INTEL) return true;
       // Heuristic: gpu:nvidia -> CUDA
       if (CaseInsensitiveCompare(target_specifier, "nvidia") &&
-          ep.ep_name == kCudaExecutionProvider) return true;
+          (ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider)) return true;
     }
     return false;
   }
@@ -225,7 +226,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
   }
   // "cuda"
   if (CaseInsensitiveCompare(target_type_str, "cuda")) {
-    return ep.ep_name == kCudaExecutionProvider;
+    return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider;
   }
   // "dml"
   if (CaseInsensitiveCompare(target_type_str, "dml")) {
@@ -284,7 +285,13 @@ std::optional<std::string> EpLayeringMatcher::Match(gsl::span<const OrtEpDevice*
         ep_device.ep_name,
         device_type,
         has_hw ? ep_device.device->vendor_id : 0u,
-        has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id) : OrtDevice::DeviceId{},
+        // Prefer the device ordinal from device_memory_info (set by the EP factory to
+        // a runtime device ordinal such as a CUDA ordinal) over the OrtHardwareDevice::device_id
+        // which is a hardware-type identifier and not guaranteed to be a stable runtime ordinal.
+        ep_device.device_memory_info
+            ? ep_device.device_memory_info->device.Id()
+            : (has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id)
+                      : OrtDevice::DeviceId{}),
         has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}};

     if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) {
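
The nested conditional in the last hunk encodes a three-step preference: the runtime ordinal if memory info is present, then the hardware device id, then a default. A sketch of that order using `std::optional` in place of the nullable pointers follows. `SelectDeviceId` and the `DeviceId` alias are assumptions made for illustration, and the exact width of `OrtDevice::DeviceId` is not stated in this diff.

```cpp
#include <cstdint>
#include <optional>

using DeviceId = std::int16_t;  // assumption: a small signed ordinal type

// Hypothetical inputs:
//   memory_info_ordinal models device_memory_info->device.Id() when memory info exists.
//   hardware_device_id models OrtHardwareDevice::device_id when hardware info exists.
inline DeviceId SelectDeviceId(std::optional<DeviceId> memory_info_ordinal,
                               std::optional<std::uint32_t> hardware_device_id) {
  // 1. Prefer the runtime ordinal reported by the EP factory.
  if (memory_info_ordinal) return *memory_info_ordinal;
  // 2. Fall back to the hardware identifier, which may not be a runtime ordinal.
  if (hardware_device_id) return static_cast<DeviceId>(*hardware_device_id);
  // 3. Neither is known: use the default-constructed id.
  return DeviceId{};
}
```

For example, with a memory-info ordinal of 1 and a hardware id of `0x2204`, the function returns 1, so a rule such as `gpu:1` would be matched against the runtime ordinal rather than the hardware identifier.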

onnxruntime/core/framework/resource_accountant.cc

Lines changed: 33 additions & 1 deletion
@@ -120,7 +120,7 @@ class SizeBasedStatsAccountant : public IResourceAccountant {
     }
   }

-  void ResetPendingWeights() override {
+  void ResetPendingWeightsImpl() override {
     pending_weights_.clear();
     pending_weights_by_node_.clear();
   }
@@ -317,4 +317,36 @@ std::string IResourceAccountant::MakeUniqueNodeName(const Node& node) {
   return result;
 }

+ResourceCount AddResourceCounts(const ResourceCount& a, const ResourceCount& b) {
+  return std::visit(
+      [](auto lhs, auto rhs) -> ResourceCount {
+        static_assert(std::is_same_v<decltype(lhs), decltype(rhs)>,
+                      "AddResourceCounts requires both operands to hold the same type. "
+                      "Handle the new ResourceCount variant member.");
+        if constexpr (std::is_integral_v<decltype(lhs)>) {
+          return static_cast<decltype(lhs)>(SafeInt<decltype(lhs)>(lhs) + rhs);
+        } else {
+          return lhs + rhs;
+        }
+      },
+      a, b);
+}
+
+bool ResourceCountExceeds(const ResourceCount& a, const ResourceCount& b) {
+  return std::visit(
+      [](auto lhs, auto rhs) -> bool {
+        static_assert(std::is_same_v<decltype(lhs), decltype(rhs)>,
+                      "ResourceCountExceeds requires both operands to hold the same type. "
+                      "Handle the new ResourceCount variant member.");
+        return lhs > rhs;
+      },
+      a, b);
+}
+
+std::string FormatResourceCount(const ResourceCount& rc) {
+  return std::visit(
+      [](auto val) -> std::string { return std::to_string(val); },
+      rc);
+}
+
 } // namespace onnxruntime
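
A standalone sketch of the same `std::visit` technique is shown here, so the behavior can be tried outside the ONNX Runtime tree. It is not the committed implementation: the commit uses `SafeInt` for overflow detection, while this version swaps in a manual bounds check that throws `std::overflow_error`. The `*Sketch` function names are hypothetical.

```cpp
#include <cstddef>
#include <limits>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <variant>

using ResourceCount = std::variant<std::size_t>;

// Overflow-checked addition. Assumes unsigned integral members only;
// the manual check stands in for the SafeInt arithmetic used in the commit.
inline ResourceCount AddCountsSketch(const ResourceCount& a, const ResourceCount& b) {
  return std::visit(
      [](auto lhs, auto rhs) -> ResourceCount {
        static_assert(std::is_same_v<decltype(lhs), decltype(rhs)>,
                      "both operands must hold the same type");
        if (rhs > std::numeric_limits<decltype(lhs)>::max() - lhs) {
          throw std::overflow_error("ResourceCount addition overflowed");
        }
        return static_cast<decltype(lhs)>(lhs + rhs);
      },
      a, b);
}

// True when a is strictly greater than b.
inline bool CountExceedsSketch(const ResourceCount& a, const ResourceCount& b) {
  return std::visit(
      [](auto lhs, auto rhs) -> bool {
        static_assert(std::is_same_v<decltype(lhs), decltype(rhs)>,
                      "both operands must hold the same type");
        return lhs > rhs;
      },
      a, b);
}

// Decimal string form of the held value.
inline std::string FormatCountSketch(const ResourceCount& rc) {
  return std::visit([](auto val) -> std::string { return std::to_string(val); }, rc);
}
```

Because the generic lambda is instantiated for every pair of alternatives, adding a second member to the variant would trip the `static_assert` for the mixed pairs, which is the build-time signal the header comment describes.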

onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

Lines changed: 10 additions & 4 deletions
@@ -11,12 +11,15 @@
 #include "ep/get_capability_utils.h"

 #include <cstring>
+#include <limits>
 #include <stdexcept>
 #include <string>
 #include <string_view>
 #include <unordered_map>
 #include <unordered_set>

+#include "core/graph/constants.h"
+
 namespace onnxruntime {
 namespace cuda_plugin {

@@ -227,12 +230,15 @@ OrtStatus* ORT_API_CALL CudaEp::GetCapabilityImpl(
       cpu_preferred_nodes));

   // Phase 3: Add final supported nodes (tentative minus CPU-preferred).
+  // Resource budget enforcement is handled by the host after GetCapability returns.
+
   for (const OrtNode* ort_node : candidate_nodes) {
-    if (cpu_preferred_nodes.count(ort_node) == 0) {
-      Ort::ConstNode node{ort_node};
-      RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
-          graph_support_info, node));
+    if (cpu_preferred_nodes.count(ort_node) != 0) {
+      continue;
     }
+
+    RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
+        graph_support_info, ort_node));
   }

   return nullptr;

onnxruntime/core/providers/cuda/plugin/cuda_kernel_adapter.h

Lines changed: 0 additions & 1 deletion
@@ -99,7 +99,6 @@ class OrtStreamAdapter {
 #include "core/providers/common.h"

 namespace onnxruntime {
-inline constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";

 // Forward declaration of GetEnvironmentVar for plugin builds on Windows.
 // Defined in provider_api_shims.cc; mirrors the provider_api.h declaration

onnxruntime/core/session/abi_ep_types.h

Lines changed: 5 additions & 0 deletions
@@ -16,6 +16,7 @@
 namespace onnxruntime {
 struct EpGraph;
 struct EpNode;
+class IResourceAccountant;
 } // namespace onnxruntime

 /// <summary>
@@ -50,4 +51,8 @@ struct OrtEpGraphSupportInfo {
   const onnxruntime::EpGraph& ort_graph;
   std::vector<NodeGrouping> node_groupings;
   const onnxruntime::IExecutionProvider::IKernelLookup& kernel_lookup;
+
+  // Optional resource accountant for capacity-aware partitioning.
+  // Owned by the graph partitioner; lifetime exceeds this struct.
+  onnxruntime::IResourceAccountant* resource_accountant = nullptr;
 };
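
The host-side enforcement code itself is not among the hunks shown here, so the following is only a guess at the general shape, under stated assumptions: the host walks the nodes an EP claimed, commits each node's cost against an optional threshold, and stops assigning once the budget would be exceeded. `BudgetSketch`, `TryCommit`, `Consumed` and `FilterByBudget` are hypothetical and do not correspond to ONNX Runtime APIs.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical accountant: tracks consumed units against an optional threshold.
class BudgetSketch {
 public:
  explicit BudgetSketch(std::optional<std::size_t> threshold) : threshold_(threshold) {}

  // Commits the cost and returns true if it fits in the remaining budget.
  // Otherwise issues a stop and returns false; later calls also return false.
  bool TryCommit(std::size_t cost) {
    if (stop_issued_) return false;
    // consumed_ never exceeds *threshold_, so the subtraction cannot wrap.
    if (threshold_ && cost > *threshold_ - consumed_) {
      stop_issued_ = true;
      return false;
    }
    consumed_ += cost;
    return true;
  }

  bool IsStopIssued() const noexcept { return stop_issued_; }
  std::size_t Consumed() const noexcept { return consumed_; }

 private:
  std::optional<std::size_t> threshold_;
  std::size_t consumed_ = 0;
  bool stop_issued_ = false;
};

// Hypothetical host-side filter: returns the indices of claimed nodes that fit,
// in order, stopping at the first node that does not. A null accountant means
// no budget is configured and every node is accepted.
inline std::vector<std::size_t> FilterByBudget(const std::vector<std::size_t>& node_costs,
                                               BudgetSketch* accountant) {
  std::vector<std::size_t> accepted;
  for (std::size_t i = 0; i < node_costs.size(); ++i) {
    if (accountant != nullptr && !accountant->TryCommit(node_costs[i])) break;
    accepted.push_back(i);
  }
  return accepted;
}
```

The nullable pointer mirrors the `resource_accountant = nullptr` default in `OrtEpGraphSupportInfo`: when no accountant is attached, partitioning proceeds without any budget check.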
