Fix issue w/ DML batched readback size#28680

Draft
adrastogi wants to merge 1 commit into
mainfrom
adrastogi/dml-overflow

@adrastogi

Description

Fixes DML batched GPU readback sizing to avoid uint32_t wraparound of the aggregate size, which could undersize the D3D12 readback heap. Adds checked size_t accumulation, safer readback validation, and regression coverage for overflow cases.

Motivation and Context

DML batched readback previously accumulated output tensor byte sizes in uint32_t, so multiple individually valid outputs could wrap the aggregate size used to allocate the D3D12 readback heap. This change tightens the sizing path and keeps the readback allocation consistent with the per-tensor copy sizes.
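The wraparound described above, and one way to guard against it, can be sketched as follows. This is a minimal illustration, not the PR's implementation: `CheckedTotalBytes` and `WrappingTotalBytes32` are hypothetical names, and the body of the PR's actual helper (`Dml::detail::ComputeTotalReadbackSize`, referenced in the tests) is not shown on this page.

```cpp
#include <cstddef>
#include <cstdint>
#include <limits>
#include <stdexcept>
#include <vector>

// Hypothetical helper (assumption, not the PR's code): sums per-tensor byte
// sizes in size_t and throws instead of wrapping when the total would overflow.
inline size_t CheckedTotalBytes(const std::vector<size_t>& sizes) {
  size_t total = 0;
  for (size_t size : sizes) {
    // Check before adding: size + total would exceed the size_t maximum.
    if (size > std::numeric_limits<size_t>::max() - total) {
      throw std::overflow_error("aggregate readback size overflows size_t");
    }
    total += size;
  }
  return total;
}

// Illustrates the failure mode: the same sum kept in uint32_t wraps silently.
inline uint32_t WrappingTotalBytes32(const std::vector<size_t>& sizes) {
  uint32_t total = 0;
  for (size_t size : sizes) {
    total += static_cast<uint32_t>(size);  // unsigned arithmetic wraps modulo 2^32
  }
  return total;
}
```

With two 3 GiB outputs (`0xC0000000` bytes each), the 32-bit sum wraps to `0x80000000` (2 GiB), so a heap sized from it would be smaller than the data copied into it. The checked version returns the full 6 GiB on a 64-bit build and throws if even `size_t` cannot hold the total.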

@adrastogi adrastogi requested a review from Copilot May 26, 2026 21:17
Copilot AI left a comment

Copilot encountered an error: Your billing is not configured or you have Copilot licenses from multiple standalone organizations or enterprises. To use premium requests, select a billing entity via the GitHub site, under Settings > Copilot > Features.


Copilot AI left a comment

Copilot's findings

  • Files reviewed: 4/4 changed files
  • Comments generated: 5

// Map the readback heap and copy it into the destination
void* readbackHeapData = nullptr;
ORT_THROW_IF_FAILED(m_readbackHeap->Map(0, nullptr, &readbackHeapData));
auto unmapReadbackHeap = gsl::finally([this]() { m_readbackHeap->Unmap(0, nullptr); });
Comment on lines 45 to +51

static Status HResultToStatus(HRESULT hr, const char* operation, const char* details)
{
    const StatusCode status_code = hr == E_INVALIDARG ? INVALID_ARGUMENT : FAIL;
    return Status(ONNXRUNTIME, status_code,
                  onnxruntime::MakeString(operation, " failed with HRESULT ", hr, ": ", details));
}
Comment on lines +13 to +18
#include <d3d12.h>
#include <gsl/gsl>
#include <wrl/client.h>

using Microsoft::WRL::ComPtr;

Comment on lines +41 to +45
TEST(DmlReadbackHeapTest, ComputeTotalReadbackSizeRejectsSizeTOverflow) {
  const std::array<size_t, 2> sizes = {std::numeric_limits<size_t>::max(), 1};

  EXPECT_ANY_THROW((void)Dml::detail::ComputeTotalReadbackSize(gsl::make_span(sizes.data(), sizes.size())));
}
Comment on lines +43 to +51

  EXPECT_ANY_THROW((void)Dml::detail::ComputeTotalReadbackSize(gsl::make_span(sizes.data(), sizes.size())));
}

TEST(DmlReadbackHeapTest, ComputeTotalReadbackSizeRejectsMidBatchOverflow) {
  const size_t half_max = std::numeric_limits<size_t>::max() / 2;
  const std::array<size_t, 2> sizes = {half_max + 1, half_max + 1};

  EXPECT_ANY_THROW((void)Dml::detail::ComputeTotalReadbackSize(gsl::make_span(sizes.data(), sizes.size())));
2 participants