@@ -50,29 +50,48 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) {
5050 return provider.GetDevice ().Type () == OrtDevice::CPU;
5151}
5252
53- // Returns true if no data transfer is needed between the two devices.
54- // HOST_ACCESSIBLE memory is a superset — accessible by both host and device — so it can satisfy
55- // DEFAULT memory requirements on the same physical device without a copy.
56- static bool DevicesAreMemoryCompatible (const OrtDevice& a, const OrtDevice& b) {
57- const bool a_is_cpu_mem = a.UsesCpuMemory ();
58- const bool b_is_cpu_mem = b.UsesCpuMemory ();
59-
60- // Both CPU-accessible: compatible unless both are HOST_ACCESSIBLE on different physical devices.
61- if (a_is_cpu_mem && b_is_cpu_mem) {
62- if (a.Type () == OrtDevice::CPU || b.Type () == OrtDevice::CPU) {
63- return true ;
53+ // Returns true if src memory can satisfy tgt's requirements without a data copy.
54+ //
55+ // HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly.
56+ // DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT
57+ // memory is device-only — the CPU cannot read it.
58+ //
59+ // For the mixed case, src alignment must meet tgt's minimum requirement.
60+ // Alignment 0 means "unspecified" and is treated as compatible with any requirement.
61+ bool CanSourceSatisfyTarget (const OrtDevice& src, const OrtDevice& tgt) {
62+ const bool src_is_cpu_mem = src.UsesCpuMemory ();
63+ const bool tgt_is_cpu_mem = tgt.UsesCpuMemory ();
64+
65+ // Identical devices are always compatible.
66+ if (src == tgt) {
67+ return true ;
68+ }
69+
70+ // Alignment 0 means "unspecified" — treat as compatible with any alignment requirement.
71+ const bool is_alignment_satisfied = src.GetAlignment () == 0 || tgt.GetAlignment () == 0 ||
72+ src.GetAlignment () >= tgt.GetAlignment ();
73+
74+ // Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory).
75+ if (src_is_cpu_mem && tgt_is_cpu_mem) {
76+ // CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device
77+ if (tgt.Type () == OrtDevice::CPU) {
78+ return is_alignment_satisfied;
6479 }
65- return a.Type () == b.Type () &&
66- a.Vendor () == b.Vendor () &&
67- a.Id () == b.Id ();
80+ // Both are HOST_ACCESSIBLE on some device: require the same physical device.
81+ return src.Type () == tgt.Type () &&
82+ src.Vendor () == tgt.Vendor () &&
83+ src.Id () == tgt.Id () && is_alignment_satisfied;
6884 }
6985
70- // HOST_ACCESSIBLE <-> DEFAULT: compatible only on the same physical device.
71- if ((a_is_cpu_mem != b_is_cpu_mem) &&
72- a.Type () == b.Type () &&
73- a.Vendor () == b.Vendor () &&
74- a.Id () == b.Id ()) {
75- return true ;
86+ // HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device —
87+ // the device can DMA from HOST_ACCESSIBLE memory directly.
88+ // The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers,
89+ // and DEFAULT memory is device-only so the CPU cannot read it.
90+ if (src_is_cpu_mem && !tgt_is_cpu_mem &&
91+ src.Type () == tgt.Type () &&
92+ src.Vendor () == tgt.Vendor () &&
93+ src.Id () == tgt.Id ()) {
94+ return is_alignment_satisfied;
7695 }
7796
7897 return false ;
@@ -146,16 +165,19 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
146165}
147166
148167// Populate device_fetches for the output-copy path.
149- // Reuses a pre-allocated user buffer when the memory is compatible (same device or HOST_ACCESSIBLE
150- // <-> DEFAULT on the same physical device); otherwise inserts an empty placeholder.
168+ // When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if
169+ // the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e.,
170+ // CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy.
171+ // Otherwise inserts an empty placeholder for the EP to allocate into.
151172static void PopulateDeviceFetches (gsl::span<const MLValueCopyInfo> fetch_copy_info,
152173 const std::vector<OrtValue>& fetches,
153174 std::vector<OrtValue>& device_fetches) {
175+ ORT_ENFORCE (fetch_copy_info.size () >= fetches.size ());
154176 device_fetches.reserve (fetches.size ());
155177 for (size_t i = 0 ; i < fetches.size (); ++i) {
156178 const auto & src = fetch_copy_info[i].source_device ;
157179 const auto & tgt = fetch_copy_info[i].target_device ;
158- if ((src == tgt || DevicesAreMemoryCompatible (src, tgt) ) && fetches[i].IsAllocated ()) {
180+ if (CanSourceSatisfyTarget ( tgt, src ) && fetches[i].IsAllocated ()) {
159181 device_fetches.push_back (fetches[i]);
160182 } else {
161183 device_fetches.push_back ({});
@@ -178,10 +200,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
178200 std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr )
179201#endif
180202{
181- // No data transfer needed if devices are the same or memory-compatible
182- // (e.g. HOST_ACCESSIBLE <-> DEFAULT on the same physical device).
183- if (copy_info.source_device == copy_info.target_device ||
184- DevicesAreMemoryCompatible (copy_info.source_device , copy_info.target_device )) {
203+ // No data transfer needed if devices are identical, or the source can satisfy the target
204+ // (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device).
205+ if (CanSourceSatisfyTarget (copy_info.source_device , copy_info.target_device )) {
185206 target_mlvalue = source_mlvalue;
186207 return Status::OK ();
187208 }
@@ -372,8 +393,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
372393 for (size_t i = 0 , end = feed_locations.size (); i < end; ++i) {
373394 copy_info[i].source_device = feed_locations[i];
374395
375- if (copy_info[i].source_device != copy_info[i].target_device &&
376- !DevicesAreMemoryCompatible (copy_info[i].source_device , copy_info[i].target_device )) {
396+ if (!CanSourceSatisfyTarget (copy_info[i].source_device , copy_info[i].target_device )) {
377397 copy_needed = true ;
378398 }
379399 }
@@ -394,8 +414,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span<const OrtDevice* const>& fetch_
394414 copy_info[i].target_device = *alloc_info;
395415 }
396416
397- if (copy_info[i].source_device != copy_info[i].target_device &&
398- !DevicesAreMemoryCompatible (copy_info[i].source_device , copy_info[i].target_device )) {
417+ if (!CanSourceSatisfyTarget (copy_info[i].source_device , copy_info[i].target_device )) {
399418 copy_needed = true ;
400419 }
401420 }
@@ -702,9 +721,7 @@ ExecuteGraphImpl(const SessionState& session_state,
702721 feeds_to_use = device_feeds;
703722 }
704723
705- auto num_outputs = fetches.size ();
706724 const auto & fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo ();
707-
708725 if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
709726 PopulateDeviceFetches (fetch_copy_info, fetches, device_fetches);
710727 p_fetches = &device_fetches;
@@ -847,7 +864,6 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
847864 p_feeds = device_feeds;
848865 }
849866
850- auto num_outputs = fetches.size ();
851867 const auto & fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo ();
852868
853869 if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
0 commit comments