@@ -0,0 +1,99 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

// Regression tests for OOB reads when label values are outside [0, C).
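// Note: in each test below, kExpectFailure plus the "out of range" substring
// asserts that Run() fails with an error message containing that text, i.e.
// the kernel rejects the bad label instead of silently reading out of bounds.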

TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLarge) {
  OpTester test("SoftmaxCrossEntropyLoss", 12);
  test.AddAttribute("reduction", std::string("mean"));
  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));

  std::vector<float> X_data(3 * 5, 1.0f);
  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)

  test.AddInput<float>("X", {3, 5}, X_data);
  test.AddInput<int64_t>("index", {3}, index_data);
  test.AddOutput<float>("output", {}, {0.0f});
  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));

  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
}

TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_NegativeLabel) {
  OpTester test("SoftmaxCrossEntropyLoss", 12);
  test.AddAttribute("reduction", std::string("mean"));
  test.AddAttribute("ignore_index", static_cast<int64_t>(-100));

  std::vector<float> X_data(3 * 5, 1.0f);
  std::vector<int64_t> index_data = {0, -1, 2};  // -1 is out of range (and != ignore_index)

  test.AddInput<float>("X", {3, 5}, X_data);
  test.AddInput<int64_t>("index", {3}, index_data);
  test.AddOutput<float>("output", {}, {0.0f});
  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));

  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
}

TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLargeWithWeights) {
  OpTester test("SoftmaxCrossEntropyLoss", 12);
  test.AddAttribute("reduction", std::string("mean"));
  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));

  std::vector<float> X_data(3 * 5, 1.0f);
  std::vector<int64_t> index_data = {0, 100, 2};  // 100 is out of range
  std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

  test.AddInput<float>("X", {3, 5}, X_data);
  test.AddInput<int64_t>("index", {3}, index_data);
  test.AddInput<float>("weight", {5}, weight_data);
  test.AddOutput<float>("output", {}, {0.0f});
  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));

  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
}

TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLarge) {
  OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
  test.AddAttribute("reduction", std::string("mean"));
  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));

  std::vector<float> dY_data = {1.0f};
  std::vector<float> log_prob_data(3 * 5, -1.6094f);
  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)

  test.AddInput<float>("dY", {}, dY_data);
  test.AddInput<float>("log_prob", {3, 5}, log_prob_data);
  test.AddInput<int64_t>("index", {3}, index_data);
  test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f));

  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
}

TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLargeWithWeights) {
  OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
  test.AddAttribute("reduction", std::string("mean"));
  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));

  std::vector<float> dY_data = {1.0f};
  std::vector<float> log_prob_data(3 * 5, -1.6094f);
  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)
  std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};

  test.AddInput<float>("dY", {}, dY_data);
  test.AddInput<float>("log_prob", {3, 5}, log_prob_data);
  test.AddInput<int64_t>("index", {3}, index_data);
  test.AddInput<float>("weight", {5}, weight_data);
  test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f));

  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
}

} // namespace test
} // namespace onnxruntime
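A quick way to exercise just the new tests locally is a gtest filter. The runner binary name below assumes the default build output and may differ (e.g. for training builds):

./onnxruntime_test_all --gtest_filter='CrossEntropyTest.*'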
@@ -56,12 +56,17 @@ void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorSh
void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape,
                                    const TensorShape& label_shape,
                                    const TensorShape* weight_shape) {
-  ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D.");

  const size_t label_dims = label_shape.NumDimensions();
  ORT_ENFORCE(label_dims >= 1, "label must be at least 1-D.");
  ORT_ENFORCE(logit_shape.NumDimensions() >= 2, "logit must be at least 2-D.");
  ORT_ENFORCE(logit_shape.NumDimensions() == label_dims + 1,
              "logit_shape must be (1 + label_shape)");

+  ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D.");
+  ORT_ENFORCE(nullptr == weight_shape || (*weight_shape)[0] == logit_shape[1],
+              "Weight tensor size (", (weight_shape ? (*weight_shape)[0] : 0),
+              ") must equal the number of classes (", logit_shape[1], ")");

  ORT_ENFORCE(label_shape[0] == logit_shape[0], "The shape of logit and label does not match");

  if (label_dims >= 2) {
@@ -147,6 +152,16 @@ Status SoftmaxCrossEntropyLoss<T1, T2>::Compute(OpKernelContext* context) const
  }

  const T2* label_data = label.template Data<T2>();

+  // Validate label values are within [0, C) to prevent out-of-bounds reads.
+  for (int64_t i = 0; i < N_D; i++) {
+    if (ignore_index != label_data[i]) {
+      ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C,
+                    "SoftmaxCrossEntropyLoss: label value ", label_data[i],
+                    " at index ", i, " is out of range [0, ", C, ")");
+    }
+  }

Comment thread (resolved)

Member:
> Should we move that test in one of the loops below where the label is compared a second time to ignore_index?

Contributor Author:
I think the separate validation pass is the better approach here, for a few reasons:

1. Fail-fast: it validates all labels before any computation begins, so we never produce partial or corrupt output on bad input.
2. Parallel safety: the Grad path uses TryParallelFor, and propagating errors out of parallel lambdas is tricky. A separate upfront check avoids that complexity.
3. Negligible cost: the label scan is O(N_D) vs. O(N_D × C) for the main computation, so the extra pass should not be measurable in practice.

That said, I'm open to folding it in if you feel strongly about it. Just wanted to flag the tradeoffs.
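To make the "parallel safety" point concrete: a Status cannot be returned from inside a parallel-for lambda, so in-loop validation would have to funnel errors through shared state and check it after the loop. A minimal standalone sketch of that pattern, using plain std::thread rather than onnxruntime's TryParallelFor, with AllLabelsInRange as a hypothetical helper:

#include <atomic>
#include <cstdint>
#include <thread>
#include <vector>

// Hypothetical sketch: workers cannot return an error directly, so each one
// sets a shared atomic flag that the caller inspects after joining.
bool AllLabelsInRange(const std::vector<int64_t>& labels, int64_t C, int64_t ignore_index) {
  std::atomic<bool> ok{true};
  auto worker = [&](size_t begin, size_t end) {
    for (size_t i = begin; i < end; ++i) {
      const int64_t v = labels[i];
      if (v != ignore_index && (v < 0 || v >= C)) {
        ok.store(false, std::memory_order_relaxed);  // flag the error; no Status here
      }
    }
  };
  const size_t mid = labels.size() / 2;
  std::thread t1(worker, size_t{0}, mid);
  std::thread t2(worker, mid, labels.size());
  t1.join();
  t2.join();
  return ok.load(std::memory_order_relaxed);
}

The upfront O(N_D) scan in the actual change sidesteps this machinery entirely.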

  T1* loss_data = loss->template MutableData<T1>();
  std::vector<T1> shifted_logit(narrow<size_t>(n_d_c));
  ORT_ENFORCE(n_d_c <= static_cast<uint64_t>(std::numeric_limits<Eigen::Index>::max()));
@@ -267,6 +282,16 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co
  const T1* dY_data = dY.template Data<T1>();
  const T1* log_prob_data = log_prob.template Data<T1>();
  const T2* label_data = label.template Data<T2>();

+  // Validate label values are within [0, C) to prevent out-of-bounds reads.
+  for (int64_t i = 0; i < N_D; i++) {
+    if (ignore_index != label_data[i]) {
+      ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C,
+                    "SoftmaxCrossEntropyLossGrad: label value ", label_data[i],
+                    " at index ", i, " is out of range [0, ", C, ")");
+    }
+  }

  Tensor* d_logit = context->Output(0, probability_shape);
  T1* d_logit_data = d_logit->template MutableData<T1>();
  std::memset(d_logit_data, 0, narrow<size_t>(sizeof(T1) * N_D));
@@ -299,11 +324,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co
      int64_t row = index / C;
      int64_t col = index % C;
      T2 label_sample = label_data[row];
-      T1 weight_smaple = weight_data[label_sample] * dY_data[row];
      if (ignore_index == label_sample) {
        d_logit_data[index] = 0;
      } else {
-        d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple;
+        T1 weight_sample = weight_data[label_sample] * dY_data[row];
+        d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample;
      }
    }
  });
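Note: beyond fixing the weight_smaple typo, this hunk moves the weight_data[label_sample] read inside the else branch. Previously it executed before the ignore_index check, so a sentinel label such as -1 was itself used to index weight_data, another out-of-bounds read.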
@@ -330,11 +355,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co
      int64_t row = index / C;
      int64_t col = index % C;
      T2 label_sample = label_data[row];
-      T1 weight_smaple = weight_data[label_sample] * dY_scaled;
      if (ignore_index == label_sample) {
        d_logit_data[index] = 0;
      } else {
-        d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple;
+        T1 weight_sample = weight_data[label_sample] * dY_scaled;
+        d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample;
      }
    }
  });