Skip to content

Commit cf5a86b

Browse files
committed
First pass addressing review comments
1 parent 2578d1f commit cf5a86b

5 files changed

Lines changed: 164 additions & 41 deletions

File tree

onnxruntime/core/framework/allocation_planner.cc

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -913,11 +913,6 @@ class PlannerImpl {
913913
ProcessDef(index, node_output);
914914
OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
915915
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
916-
// Downstream nodes of certain providers may require a CPU accessible location override
917-
// to make sure the EP does not incur an unnecessary copy.
918-
// We only do it for CPU based EPs. We are not likely to encounter
919-
// non CPU devices here since they are already taken care of by using MemCpy nodes earlier.
920-
// However, we still ignore them.
921916
if (output_device.UsesCpuMemory()) {
922917
const auto& output_name = node_output->Name();
923918
const auto consumers = graph_viewer_.GetConsumerNodes(output_name);

onnxruntime/core/framework/utils.cc

Lines changed: 50 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -50,29 +50,48 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) {
5050
return provider.GetDevice().Type() == OrtDevice::CPU;
5151
}
5252

53-
// Returns true if no data transfer is needed between the two devices.
54-
// HOST_ACCESSIBLE memory is a superset — accessible by both host and device — so it can satisfy
55-
// DEFAULT memory requirements on the same physical device without a copy.
56-
static bool DevicesAreMemoryCompatible(const OrtDevice& a, const OrtDevice& b) {
57-
const bool a_is_cpu_mem = a.UsesCpuMemory();
58-
const bool b_is_cpu_mem = b.UsesCpuMemory();
59-
60-
// Both CPU-accessible: compatible unless both are HOST_ACCESSIBLE on different physical devices.
61-
if (a_is_cpu_mem && b_is_cpu_mem) {
62-
if (a.Type() == OrtDevice::CPU || b.Type() == OrtDevice::CPU) {
63-
return true;
53+
// Returns true if src memory can satisfy tgt's requirements without a data copy.
54+
//
55+
// HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly.
56+
// DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT
57+
// memory is device-only — the CPU cannot read it.
58+
//
59+
// For the mixed case, src alignment must meet tgt's minimum requirement.
60+
// Alignment 0 means "unspecified" and is treated as compatible with any requirement.
61+
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) {
62+
const bool src_is_cpu_mem = src.UsesCpuMemory();
63+
const bool tgt_is_cpu_mem = tgt.UsesCpuMemory();
64+
65+
// Identical devices are always compatible.
66+
if (src == tgt) {
67+
return true;
68+
}
69+
70+
// Alignment 0 means "unspecified" — treat as compatible with any alignment requirement.
71+
const bool is_alignment_satisfied = src.GetAlignment() == 0 || tgt.GetAlignment() == 0 ||
72+
src.GetAlignment() >= tgt.GetAlignment();
73+
74+
// Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory).
75+
if (src_is_cpu_mem && tgt_is_cpu_mem) {
76+
// CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device
77+
if (tgt.Type() == OrtDevice::CPU) {
78+
return is_alignment_satisfied;
6479
}
65-
return a.Type() == b.Type() &&
66-
a.Vendor() == b.Vendor() &&
67-
a.Id() == b.Id();
80+
// Both are HOST_ACCESSIBLE on some device: require the same physical device.
81+
return src.Type() == tgt.Type() &&
82+
src.Vendor() == tgt.Vendor() &&
83+
src.Id() == tgt.Id() && is_alignment_satisfied;
6884
}
6985

70-
// HOST_ACCESSIBLE <-> DEFAULT: compatible only on the same physical device.
71-
if ((a_is_cpu_mem != b_is_cpu_mem) &&
72-
a.Type() == b.Type() &&
73-
a.Vendor() == b.Vendor() &&
74-
a.Id() == b.Id()) {
75-
return true;
86+
// HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device —
87+
// the device can DMA from HOST_ACCESSIBLE memory directly.
88+
// The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers,
89+
// and DEFAULT memory is device-only so the CPU cannot read it.
90+
if (src_is_cpu_mem && !tgt_is_cpu_mem &&
91+
src.Type() == tgt.Type() &&
92+
src.Vendor() == tgt.Vendor() &&
93+
src.Id() == tgt.Id()) {
94+
return is_alignment_satisfied;
7695
}
7796

7897
return false;
@@ -146,16 +165,19 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
146165
}
147166

148167
// Populate device_fetches for the output-copy path.
149-
// Reuses a pre-allocated user buffer when the memory is compatible (same device or HOST_ACCESSIBLE
150-
// <-> DEFAULT on the same physical device); otherwise inserts an empty placeholder.
168+
// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if
169+
// the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e.,
170+
// CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy.
171+
// Otherwise inserts an empty placeholder for the EP to allocate into.
151172
static void PopulateDeviceFetches(gsl::span<const MLValueCopyInfo> fetch_copy_info,
152173
const std::vector<OrtValue>& fetches,
153174
std::vector<OrtValue>& device_fetches) {
175+
ORT_ENFORCE(fetch_copy_info.size() >= fetches.size());
154176
device_fetches.reserve(fetches.size());
155177
for (size_t i = 0; i < fetches.size(); ++i) {
156178
const auto& src = fetch_copy_info[i].source_device;
157179
const auto& tgt = fetch_copy_info[i].target_device;
158-
if ((src == tgt || DevicesAreMemoryCompatible(src, tgt)) && fetches[i].IsAllocated()) {
180+
if (CanSourceSatisfyTarget(tgt, src) && fetches[i].IsAllocated()) {
159181
device_fetches.push_back(fetches[i]);
160182
} else {
161183
device_fetches.push_back({});
@@ -178,10 +200,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
178200
std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr)
179201
#endif
180202
{
181-
// No data transfer needed if devices are the same or memory-compatible
182-
// (e.g. HOST_ACCESSIBLE <-> DEFAULT on the same physical device).
183-
if (copy_info.source_device == copy_info.target_device ||
184-
DevicesAreMemoryCompatible(copy_info.source_device, copy_info.target_device)) {
203+
// No data transfer needed if devices are identical, or the source can satisfy the target
204+
// (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device).
205+
if (CanSourceSatisfyTarget(copy_info.source_device, copy_info.target_device)) {
185206
target_mlvalue = source_mlvalue;
186207
return Status::OK();
187208
}
@@ -372,8 +393,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
372393
for (size_t i = 0, end = feed_locations.size(); i < end; ++i) {
373394
copy_info[i].source_device = feed_locations[i];
374395

375-
if (copy_info[i].source_device != copy_info[i].target_device &&
376-
!DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) {
396+
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
377397
copy_needed = true;
378398
}
379399
}
@@ -394,8 +414,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span<const OrtDevice* const>& fetch_
394414
copy_info[i].target_device = *alloc_info;
395415
}
396416

397-
if (copy_info[i].source_device != copy_info[i].target_device &&
398-
!DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) {
417+
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
399418
copy_needed = true;
400419
}
401420
}
@@ -702,9 +721,7 @@ ExecuteGraphImpl(const SessionState& session_state,
702721
feeds_to_use = device_feeds;
703722
}
704723

705-
auto num_outputs = fetches.size();
706724
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();
707-
708725
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
709726
PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches);
710727
p_fetches = &device_fetches;
@@ -847,7 +864,6 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
847864
p_feeds = device_feeds;
848865
}
849866

850-
auto num_outputs = fetches.size();
851867
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();
852868

853869
if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {

onnxruntime/core/framework/utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,11 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider);
5757

5858
bool IsMemcpyNode(const Node& node);
5959

60+
// Returns true if src memory can satisfy tgt's requirements without a data copy.
61+
// HOST_ACCESSIBLE -> DEFAULT is valid (device can access HOST_ACCESSIBLE memory directly).
62+
// DEFAULT -> HOST_ACCESSIBLE is NOT valid (CPU cannot read device-only memory).
63+
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt);
64+
6065
common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
6166
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);
6267

onnxruntime/core/session/provider_policy_context.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,8 +421,8 @@ void ProviderPolicyContext::FoldSelectedDevices(std::vector<const OrtEpDevice*>
421421
info.ep_factory = devices_selected[0]->ep_factory;
422422

423423
do {
424-
auto iter = std::find_if(devices_selected.begin(), devices_selected.end(), [&ep_name](const OrtEpDevice* d) {
425-
return d->ep_name == ep_name;
424+
auto iter = std::find_if(devices_selected.begin(), devices_selected.end(), [&ep_name, &info](const OrtEpDevice* d) {
425+
return d->ep_name == ep_name && d->ep_factory == info.ep_factory;
426426
});
427427

428428
if (iter != devices_selected.end()) {
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "gtest/gtest.h"
5+
#include "core/framework/utils.h"
6+
7+
namespace onnxruntime {
8+
namespace test {
9+
10+
constexpr OrtDevice::VendorId kTestVendor1 = 0x1234;
11+
constexpr OrtDevice::VendorId kTestVendor2 = 0x5678;
12+
13+
static OrtDevice Cpu() {
14+
return OrtDevice{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0};
15+
}
16+
17+
static OrtDevice HostAccessible(OrtDevice::VendorId vendor, OrtDevice::DeviceId id,
18+
OrtDevice::Alignment align = 0) {
19+
return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE, vendor, id, align};
20+
}
21+
22+
static OrtDevice Default(OrtDevice::VendorId vendor, OrtDevice::DeviceId id,
23+
OrtDevice::Alignment align = 0) {
24+
return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, vendor, id, align};
25+
}
26+
27+
TEST(CanSourceSatisfyTargetTest, CpuSourceHostAccessibleTarget) {
28+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(Cpu(), HostAccessible(kTestVendor1, 0)));
29+
}
30+
31+
TEST(CanSourceSatisfyTargetTest, HostAccessibleSourceCpuTarget) {
32+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(HostAccessible(kTestVendor1, 0), Cpu()));
33+
}
34+
35+
// src == tgt early return: identical devices are always compatible
36+
TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDevice) {
37+
auto dev = HostAccessible(kTestVendor1, 0, 16);
38+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(dev, dev));
39+
}
40+
41+
// Branch 3: both HOST_ACCESSIBLE, different physical device
42+
TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentId) {
43+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
44+
HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor1, 1)));
45+
}
46+
47+
TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentVendor) {
48+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
49+
HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor2, 0)));
50+
}
51+
52+
TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentAlignment) {
53+
// Different alignment => OrtDevice::operator== returns false
54+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
55+
HostAccessible(kTestVendor1, 0, 16), HostAccessible(kTestVendor1, 0, 32)));
56+
}
57+
58+
// Branch 4: HOST_ACCESSIBLE (src) -> DEFAULT (tgt), same physical device
59+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSameDevice) {
60+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(
61+
HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 0)));
62+
}
63+
64+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentSatisfied) {
65+
// src alignment >= tgt alignment: compatible
66+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(
67+
HostAccessible(kTestVendor1, 0, 64), Default(kTestVendor1, 0, 32)));
68+
}
69+
70+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentInsufficient) {
71+
// src alignment < tgt alignment: incompatible
72+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
73+
HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 64)));
74+
}
75+
76+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSrcAlignmentZero) {
77+
// 0 = unspecified, treated as wildcard
78+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(
79+
HostAccessible(kTestVendor1, 0, 0), Default(kTestVendor1, 0, 64)));
80+
}
81+
82+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultTgtAlignmentZero) {
83+
// 0 = unspecified, treated as wildcard
84+
EXPECT_TRUE(utils::CanSourceSatisfyTarget(
85+
HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 0)));
86+
}
87+
88+
TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentDeviceId) {
89+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
90+
HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 1)));
91+
}
92+
93+
// Branch 5: incompatible cases
94+
95+
TEST(CanSourceSatisfyTargetTest, DefaultToHostAccessibleRejected) {
96+
// Reversed direction: CPU cannot read DEFAULT (device-only) memory
97+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
98+
Default(kTestVendor1, 0), HostAccessible(kTestVendor1, 0)));
99+
}
100+
101+
TEST(CanSourceSatisfyTargetTest, DefaultToDefaultRejected) {
102+
EXPECT_FALSE(utils::CanSourceSatisfyTarget(
103+
Default(kTestVendor1, 0), Default(kTestVendor2, 0)));
104+
}
105+
106+
} // namespace test
107+
} // namespace onnxruntime

0 commit comments

Comments
 (0)