diff --git a/orttraining/orttraining/test/training_ops/cpu/loss/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cpu/loss/cross_entropy_test.cc
new file mode 100644
index 0000000000000..f8220dd1977c5
--- /dev/null
+++ b/orttraining/orttraining/test/training_ops/cpu/loss/cross_entropy_test.cc
@@ -0,0 +1,99 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "gtest/gtest.h"
+#include "test/providers/provider_test_utils.h"
+
+namespace onnxruntime {
+namespace test {
+
+// Regression tests for OOB reads when label values are outside [0, C).
+
+TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLarge) {
+  OpTester test("SoftmaxCrossEntropyLoss", 12);
+  test.AddAttribute("reduction", std::string("mean"));
+  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));
+
+  std::vector<float> X_data(3 * 5, 1.0f);
+  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)
+
+  test.AddInput<float>("X", {3, 5}, X_data);
+  test.AddInput<int64_t>("index", {3}, index_data);
+  test.AddOutput<float>("output", {}, {0.0f});
+  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
+}
+
+TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_NegativeLabel) {
+  OpTester test("SoftmaxCrossEntropyLoss", 12);
+  test.AddAttribute("reduction", std::string("mean"));
+  test.AddAttribute("ignore_index", static_cast<int64_t>(-100));
+
+  std::vector<float> X_data(3 * 5, 1.0f);
+  std::vector<int64_t> index_data = {0, -1, 2};  // -1 is out of range (and != ignore_index)
+
+  test.AddInput<float>("X", {3, 5}, X_data);
+  test.AddInput<int64_t>("index", {3}, index_data);
+  test.AddOutput<float>("output", {}, {0.0f});
+  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
+}
+
+TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLargeWithWeights) {
+  OpTester test("SoftmaxCrossEntropyLoss", 12);
+  test.AddAttribute("reduction", std::string("mean"));
+  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));
+
+  std::vector<float> X_data(3 * 5, 1.0f);
+  std::vector<int64_t> index_data = {0, 100, 2};  // 100 is out of range
+  std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+
+  test.AddInput<float>("X", {3, 5}, X_data);
+  test.AddInput<int64_t>("index", {3}, index_data);
+  test.AddInput<float>("weight", {5}, weight_data);
+  test.AddOutput<float>("output", {}, {0.0f});
+  test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f));
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
+}
+
+TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLarge) {
+  OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("reduction", std::string("mean"));
+  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));
+
+  std::vector<float> dY_data = {1.0f};
+  std::vector<float> log_prob_data(3 * 5, -1.6094f);
+  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)
+
+  test.AddInput<float>("dY", {}, dY_data);
+  test.AddInput<float>("log_prob", {3, 5}, log_prob_data);
+  test.AddInput<int64_t>("index", {3}, index_data);
+  test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f));
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
+}
+
+TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLargeWithWeights) {
+  OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
+  test.AddAttribute("reduction", std::string("mean"));
+  test.AddAttribute("ignore_index", static_cast<int64_t>(-1));
+
+  std::vector<float> dY_data = {1.0f};
+  std::vector<float> log_prob_data(3 * 5, -1.6094f);
+  std::vector<int64_t> index_data = {0, 5, 2};  // 5 is out of range [0, 5)
+  std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
+
+  test.AddInput<float>("dY", {}, dY_data);
+  test.AddInput<float>("log_prob", {3, 5}, log_prob_data);
+  test.AddInput<int64_t>("index", {3}, index_data);
+  test.AddInput<float>("weight", {5}, weight_data);
+  test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f));
+
+  test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
+}
+
+}  // namespace test
+}  // namespace onnxruntime
diff --git a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc
index c74bf06a77d6e..7cec8bd0b6946 100644
--- a/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc
+++ b/orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc
@@ -56,12 +56,17 @@ void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorSh
 void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape,
                                     const TensorShape& label_shape,
                                     const TensorShape* weight_shape) {
-  ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D.");
-
   const size_t label_dims = label_shape.NumDimensions();
+  ORT_ENFORCE(label_dims >= 1, "label must be at least 1-D.");
+  ORT_ENFORCE(logit_shape.NumDimensions() >= 2, "logit must be at least 2-D.");
   ORT_ENFORCE(logit_shape.NumDimensions() == label_dims + 1,
               "logit_shape must be (1 + label_shape)");
+  ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D.");
+  ORT_ENFORCE(nullptr == weight_shape || (*weight_shape)[0] == logit_shape[1],
+              "Weight tensor size (", (weight_shape ? (*weight_shape)[0] : 0),
+              ") must equal the number of classes (", logit_shape[1], ")");
+
   ORT_ENFORCE(label_shape[0] == logit_shape[0], "The shape of logit and label does not match");

   if (label_dims >= 2) {
@@ -147,6 +152,16 @@ Status SoftmaxCrossEntropyLoss<T1, T2>::Compute(OpKernelContext* context) const
   }

   const T2* label_data = label.template Data<T2>();
+
+  // Validate label values are within [0, C) to prevent out-of-bounds reads.
+  for (int64_t i = 0; i < N_D; i++) {
+    if (ignore_index != label_data[i]) {
+      ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C,
+                    "SoftmaxCrossEntropyLoss: label value ", label_data[i],
+                    " at index ", i, " is out of range [0, ", C, ")");
+    }
+  }
+
   T1* loss_data = loss->template MutableData<T1>();
   std::vector<T1> shifted_logit(narrow<size_t>(n_d_c));
   ORT_ENFORCE(n_d_c <= static_cast<int64_t>(std::numeric_limits<Eigen::Index>::max()));
@@ -267,6 +282,16 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* co
   const T1* dY_data = dY.template Data<T1>();
   const T1* log_prob_data = log_prob.template Data<T1>();
   const T2* label_data = label.template Data<T2>();
+
+  // Validate label values are within [0, C) to prevent out-of-bounds reads.
+  for (int64_t i = 0; i < N_D; i++) {
+    if (ignore_index != label_data[i]) {
+      ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C,
+                    "SoftmaxCrossEntropyLossGrad: label value ", label_data[i],
+                    " at index ", i, " is out of range [0, ", C, ")");
+    }
+  }
+
   Tensor* d_logit = context->Output(0, probability_shape);
   T1* d_logit_data = d_logit->template MutableData<T1>();
   std::memset(d_logit_data, 0, narrow<size_t>(sizeof(T1) * N_D));
@@ -299,11 +324,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* co
         int64_t row = index / C;
         int64_t col = index % C;
         T2 label_sample = label_data[row];
-        T1 weight_smaple = weight_data[label_sample] * dY_data[row];
         if (ignore_index == label_sample) {
           d_logit_data[index] = 0;
         } else {
-          d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple;
+          T1 weight_sample = weight_data[label_sample] * dY_data[row];
+          d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample;
         }
       }
     });
@@ -330,11 +355,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* co
         int64_t row = index / C;
         int64_t col = index % C;
         T2 label_sample = label_data[row];
-        T1 weight_smaple = weight_data[label_sample] * dY_scaled;
         if (ignore_index == label_sample) {
           d_logit_data[index] = 0;
         } else {
-          d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple;
+          T1 weight_sample = weight_data[label_sample] * dY_scaled;
+          d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample;
         }
       }
     });