-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Fix OOB reads in SoftmaxCrossEntropyLoss via label bounds validation #28004
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,99 @@ | ||
| // Copyright (c) Microsoft Corporation. All rights reserved. | ||
| // Licensed under the MIT License. | ||
|
|
||
| #include "gtest/gtest.h" | ||
| #include "test/providers/provider_test_utils.h" | ||
|
|
||
| namespace onnxruntime { | ||
| namespace test { | ||
|
|
||
| // Regression tests for OOB reads when label values are outside [0, C). | ||
|
|
||
| TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLarge) { | ||
| OpTester test("SoftmaxCrossEntropyLoss", 12); | ||
| test.AddAttribute("reduction", std::string("mean")); | ||
| test.AddAttribute("ignore_index", static_cast<int64_t>(-1)); | ||
|
|
||
|
vraspar marked this conversation as resolved.
|
||
| std::vector<float> X_data(3 * 5, 1.0f); | ||
| std::vector<int64_t> index_data = {0, 5, 2}; // 5 is out of range [0, 5) | ||
|
|
||
| test.AddInput<float>("X", {3, 5}, X_data); | ||
| test.AddInput<int64_t>("index", {3}, index_data); | ||
| test.AddOutput<float>("output", {}, {0.0f}); | ||
| test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f)); | ||
|
|
||
| test.Run(OpTester::ExpectResult::kExpectFailure, "out of range"); | ||
| } | ||
|
vraspar marked this conversation as resolved.
|
||
|
|
||
| TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_NegativeLabel) { | ||
| OpTester test("SoftmaxCrossEntropyLoss", 12); | ||
| test.AddAttribute("reduction", std::string("mean")); | ||
| test.AddAttribute("ignore_index", static_cast<int64_t>(-100)); | ||
|
|
||
| std::vector<float> X_data(3 * 5, 1.0f); | ||
| std::vector<int64_t> index_data = {0, -1, 2}; // -1 is out of range (and != ignore_index) | ||
|
|
||
| test.AddInput<float>("X", {3, 5}, X_data); | ||
| test.AddInput<int64_t>("index", {3}, index_data); | ||
| test.AddOutput<float>("output", {}, {0.0f}); | ||
| test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f)); | ||
|
|
||
| test.Run(OpTester::ExpectResult::kExpectFailure, "out of range"); | ||
| } | ||
|
|
||
| TEST(CrossEntropyTest, SoftmaxCrossEntropyLoss_LabelTooLargeWithWeights) { | ||
| OpTester test("SoftmaxCrossEntropyLoss", 12); | ||
| test.AddAttribute("reduction", std::string("mean")); | ||
| test.AddAttribute("ignore_index", static_cast<int64_t>(-1)); | ||
|
|
||
| std::vector<float> X_data(3 * 5, 1.0f); | ||
| std::vector<int64_t> index_data = {0, 100, 2}; // 100 is out of range | ||
| std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | ||
|
|
||
| test.AddInput<float>("X", {3, 5}, X_data); | ||
| test.AddInput<int64_t>("index", {3}, index_data); | ||
| test.AddInput<float>("weight", {5}, weight_data); | ||
| test.AddOutput<float>("output", {}, {0.0f}); | ||
| test.AddOutput<float>("log_prob", {3, 5}, std::vector<float>(15, 0.0f)); | ||
|
|
||
| test.Run(OpTester::ExpectResult::kExpectFailure, "out of range"); | ||
| } | ||
|
|
||
| TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLarge) { | ||
| OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain); | ||
| test.AddAttribute("reduction", std::string("mean")); | ||
| test.AddAttribute("ignore_index", static_cast<int64_t>(-1)); | ||
|
|
||
| std::vector<float> dY_data = {1.0f}; | ||
| std::vector<float> log_prob_data(3 * 5, -1.6094f); | ||
| std::vector<int64_t> index_data = {0, 5, 2}; // 5 is out of range [0, 5) | ||
|
|
||
| test.AddInput<float>("dY", {}, dY_data); | ||
| test.AddInput<float>("log_prob", {3, 5}, log_prob_data); | ||
| test.AddInput<int64_t>("index", {3}, index_data); | ||
| test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f)); | ||
|
|
||
| test.Run(OpTester::ExpectResult::kExpectFailure, "out of range"); | ||
| } | ||
|
|
||
| TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLargeWithWeights) { | ||
| OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain); | ||
| test.AddAttribute("reduction", std::string("mean")); | ||
| test.AddAttribute("ignore_index", static_cast<int64_t>(-1)); | ||
|
|
||
| std::vector<float> dY_data = {1.0f}; | ||
| std::vector<float> log_prob_data(3 * 5, -1.6094f); | ||
| std::vector<int64_t> index_data = {0, 5, 2}; // 5 is out of range [0, 5) | ||
| std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}; | ||
|
|
||
| test.AddInput<float>("dY", {}, dY_data); | ||
| test.AddInput<float>("log_prob", {3, 5}, log_prob_data); | ||
| test.AddInput<int64_t>("index", {3}, index_data); | ||
| test.AddInput<float>("weight", {5}, weight_data); | ||
| test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f)); | ||
|
|
||
| test.Run(OpTester::ExpectResult::kExpectFailure, "out of range"); | ||
| } | ||
|
|
||
| } // namespace test | ||
| } // namespace onnxruntime | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,12 +56,17 @@ void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorSh | |
| void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape, | ||
| const TensorShape& label_shape, | ||
| const TensorShape* weight_shape) { | ||
| ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D."); | ||
|
|
||
| const size_t label_dims = label_shape.NumDimensions(); | ||
| ORT_ENFORCE(label_dims >= 1, "label must be at least 1-D."); | ||
| ORT_ENFORCE(logit_shape.NumDimensions() >= 2, "logit must be at least 2-D."); | ||
| ORT_ENFORCE(logit_shape.NumDimensions() == label_dims + 1, | ||
| "logit_shape must be (1 + label_shape)"); | ||
|
|
||
| ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D."); | ||
| ORT_ENFORCE(nullptr == weight_shape || (*weight_shape)[0] == logit_shape[1], | ||
| "Weight tensor size (", (weight_shape ? (*weight_shape)[0] : 0), | ||
| ") must equal the number of classes (", logit_shape[1], ")"); | ||
|
vraspar marked this conversation as resolved.
vraspar marked this conversation as resolved.
|
||
|
|
||
| ORT_ENFORCE(label_shape[0] == logit_shape[0], "The shape of logit and label does not match"); | ||
|
|
||
| if (label_dims >= 2) { | ||
|
|
@@ -147,6 +152,16 @@ Status SoftmaxCrossEntropyLoss<T1, T2>::Compute(OpKernelContext* context) const | |
| } | ||
|
|
||
| const T2* label_data = label.template Data<T2>(); | ||
|
|
||
| // Validate label values are within [0, C) to prevent out-of-bounds reads. | ||
| for (int64_t i = 0; i < N_D; i++) { | ||
| if (ignore_index != label_data[i]) { | ||
| ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we move that check into one of the loops below, where the label is already compared against ignore_index a second time?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think the separate validation pass is the better approach here, for a few reasons:
That said, I'm open to folding it in if you feel strongly about it. Just wanted to flag the tradeoffs. |
||
| "SoftmaxCrossEntropyLoss: label value ", label_data[i], | ||
| " at index ", i, " is out of range [0, ", C, ")"); | ||
|
vraspar marked this conversation as resolved.
|
||
| } | ||
| } | ||
|
vraspar marked this conversation as resolved.
|
||
|
|
||
| T1* loss_data = loss->template MutableData<T1>(); | ||
| std::vector<T1> shifted_logit(narrow<size_t>(n_d_c)); | ||
| ORT_ENFORCE(n_d_c <= static_cast<uint64_t>(std::numeric_limits<Eigen::Index>::max())); | ||
|
|
@@ -267,6 +282,16 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co | |
| const T1* dY_data = dY.template Data<T1>(); | ||
| const T1* log_prob_data = log_prob.template Data<T1>(); | ||
| const T2* label_data = label.template Data<T2>(); | ||
|
|
||
| // Validate label values are within [0, C) to prevent out-of-bounds reads. | ||
| for (int64_t i = 0; i < N_D; i++) { | ||
| if (ignore_index != label_data[i]) { | ||
| ORT_RETURN_IF(label_data[i] < 0 || label_data[i] >= C, | ||
| "SoftmaxCrossEntropyLossGrad: label value ", label_data[i], | ||
| " at index ", i, " is out of range [0, ", C, ")"); | ||
|
vraspar marked this conversation as resolved.
|
||
| } | ||
| } | ||
|
vraspar marked this conversation as resolved.
|
||
|
|
||
| Tensor* d_logit = context->Output(0, probability_shape); | ||
| T1* d_logit_data = d_logit->template MutableData<T1>(); | ||
| std::memset(d_logit_data, 0, narrow<size_t>(sizeof(T1) * N_D)); | ||
|
|
@@ -299,11 +324,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co | |
| int64_t row = index / C; | ||
| int64_t col = index % C; | ||
| T2 label_sample = label_data[row]; | ||
| T1 weight_smaple = weight_data[label_sample] * dY_data[row]; | ||
| if (ignore_index == label_sample) { | ||
| d_logit_data[index] = 0; | ||
| } else { | ||
| d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple; | ||
| T1 weight_sample = weight_data[label_sample] * dY_data[row]; | ||
| d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample; | ||
| } | ||
|
vraspar marked this conversation as resolved.
|
||
| } | ||
| }); | ||
|
|
@@ -330,11 +355,11 @@ Status SoftmaxCrossEntropyLossGrad<T1, T2>::Compute(OpKernelContext* context) co | |
| int64_t row = index / C; | ||
| int64_t col = index % C; | ||
| T2 label_sample = label_data[row]; | ||
| T1 weight_smaple = weight_data[label_sample] * dY_scaled; | ||
| if (ignore_index == label_sample) { | ||
| d_logit_data[index] = 0; | ||
| } else { | ||
| d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_smaple; | ||
| T1 weight_sample = weight_data[label_sample] * dY_scaled; | ||
| d_logit_data[index] = (exp(log_prob_data[index]) - (label_sample == col)) * weight_sample; | ||
| } | ||
|
vraspar marked this conversation as resolved.
|
||
| } | ||
| }); | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.