Skip to content

Commit f2c15e3

Browse files
committed
Address review: rank guards, overflow-safe ceil-div, weighted grad test
1 parent 5fb8b6e commit f2c15e3

4 files changed

Lines changed: 29 additions & 8 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
5858
"K * N would overflow int64: K=", K_, ", N=", N_);
5959
}
6060
const int64_t numel = K_ * N_;
61-
const int64_t expected_b_quant_size = (numel + 1) / 2;
62-
const int64_t expected_absmax_size = (numel + block_size_ - 1) / block_size_;
61+
const int64_t expected_b_quant_size = ((numel - 1) / 2) + 1;
62+
const int64_t expected_absmax_size = ((numel - 1) / block_size_) + 1;
6363

6464
if (b_quant->Shape().Size() < expected_b_quant_size) {
6565
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
5959
"K * N would overflow int64: K=", K_, ", N=", N_);
6060
}
6161
const int64_t numel = K_ * N_;
62-
const int64_t expected_b_quant_size = (numel + 1) / 2;
63-
const int64_t expected_absmax_size = (numel + block_size_ - 1) / block_size_;
62+
const int64_t expected_b_quant_size = ((numel - 1) / 2) + 1;
63+
const int64_t expected_absmax_size = ((numel - 1) / block_size_) + 1;
6464

6565
if (b_quant->Shape().Size() < expected_b_quant_size) {
6666
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

orttraining/orttraining/test/training_ops/cpu/loss/cross_entropy_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,5 +76,24 @@ TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLarge) {
7676
test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
7777
}
7878

79+
// Negative test: runs SoftmaxCrossEntropyLossGrad with a weight input and a
// label value equal to the class count, and expects the run to fail with a
// message containing "out of range".
TEST(CrossEntropyTest, SoftmaxCrossEntropyLossGrad_LabelTooLargeWithWeights) {
80+
OpTester test("SoftmaxCrossEntropyLossGrad", 1, onnxruntime::kMSDomain);
81+
test.AddAttribute("reduction", std::string("mean"));
82+
// ignore_index is -1, so the label value 5 below does not match ignore_index.
test.AddAttribute("ignore_index", static_cast<int64_t>(-1));
83+
84+
// Single upstream gradient value; it is fed below with an empty (scalar) shape.
std::vector<float> dY_data = {1.0f};
85+
// 3 x 5 log-probabilities, all set to -1.6094f (approximately ln(1/5)).
std::vector<float> log_prob_data(3 * 5, -1.6094f);
86+
std::vector<int64_t> index_data = {0, 5, 2}; // 5 is out of range [0, 5)
87+
// One weight per class (5 entries), all 1.0f.
std::vector<float> weight_data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
88+
89+
test.AddInput<float>("dY", {}, dY_data);
90+
test.AddInput<float>("log_prob", {3, 5}, log_prob_data);
91+
test.AddInput<int64_t>("index", {3}, index_data);
92+
test.AddInput<float>("weight", {5}, weight_data);
93+
// NOTE(review): the zero-filled dX is presumably a placeholder, since the run
// is expected to fail before outputs are compared -- confirm against OpTester.
test.AddOutput<float>("dX", {3, 5}, std::vector<float>(15, 0.0f));
94+
95+
// The kernel must reject the out-of-range label rather than produce a gradient.
test.Run(OpTester::ExpectResult::kExpectFailure, "out of range");
96+
}
97+
7998
} // namespace test
8099
} // namespace onnxruntime

orttraining/orttraining/training_ops/cpu/loss/softmax_cross_entropy_loss.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,17 @@ void GetNDCFromLogitAndLabelShape(const TensorShape& logit_shape, const TensorSh
5656
void VerifyLogitWeightAndLabelShape(const TensorShape& logit_shape,
5757
const TensorShape& label_shape,
5858
const TensorShape* weight_shape) {
59+
const size_t label_dims = label_shape.NumDimensions();
60+
ORT_ENFORCE(label_dims >= 1, "label must be at least 1-D.");
61+
ORT_ENFORCE(logit_shape.NumDimensions() >= 2, "logit must be at least 2-D.");
62+
ORT_ENFORCE(logit_shape.NumDimensions() == label_dims + 1,
63+
"logit_shape must be (1 + label_shape)");
64+
5965
ORT_ENFORCE(nullptr == weight_shape || 1 == weight_shape->NumDimensions(), "Weights tensor is not 1-D.");
6066
ORT_ENFORCE(nullptr == weight_shape || (*weight_shape)[0] == logit_shape[1],
6167
"Weight tensor size (", (weight_shape ? (*weight_shape)[0] : 0),
6268
") must equal the number of classes (", logit_shape[1], ")");
6369

64-
const size_t label_dims = label_shape.NumDimensions();
65-
ORT_ENFORCE(logit_shape.NumDimensions() == label_dims + 1,
66-
"logit_shape must be (1 + label_shape)");
67-
6870
ORT_ENFORCE(label_shape[0] == logit_shape[0], "The shape of logit and label does not match");
6971

7072
if (label_dims >= 2) {

0 commit comments

Comments
 (0)