Skip to content

Commit b6773d9

Browse files
committed
Add bounds validation for LinearClassifier coefficients to prevent OOB read in GEMM
1 parent aaa4944 commit b6773d9

2 files changed

Lines changed: 56 additions & 0 deletions

File tree

onnxruntime/core/providers/cpu/ml/linearclassifier.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ LinearClassifier::LinearClassifier(const OpKernelInfo& info)
3636

3737
using_strings_ = !classlabels_strings_.empty();
3838
class_count_ = static_cast<ptrdiff_t>(intercepts_.size());
39+
40+
ORT_ENFORCE(class_count_ > 0, "LinearClassifier: intercepts must not be empty.");
41+
ORT_ENFORCE(coefficients_.size() % static_cast<size_t>(class_count_) == 0,
42+
"LinearClassifier: coefficients size (", coefficients_.size(),
43+
") must be a multiple of the number of classes (", class_count_, ").");
44+
3945
SetupMlasBackendKernelSelectorFromConfigOptions(mlas_backend_kernel_selector_config_, info.GetConfigOptions());
4046
}
4147

@@ -146,6 +152,14 @@ Status LinearClassifier::Compute(OpKernelContext* ctx) const {
146152
input_shape[0])
147153
: narrow<ptrdiff_t>(input_shape[1]);
148154

155+
// Validate coefficients are large enough to prevent OOB read in GEMM.
156+
const size_t expected_coefficients_size = SafeInt<size_t>(class_count_) * SafeInt<size_t>(num_features);
157+
if (coefficients_.size() < expected_coefficients_size) {
158+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
159+
"LinearClassifier: coefficients length (", coefficients_.size(),
160+
") is less than classes (", class_count_, ") * features (", num_features, ")");
161+
}
162+
149163
Tensor* Y = ctx->Output(0, {num_batches});
150164

151165
int64_t output_classes = class_count_;

onnxruntime/test/providers/cpu/ml/linearclassifer_test.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,5 +166,47 @@ TEST(MLOpTest, LinearClassifierMulticlassInt32Input) {
166166
TEST(MLOpTest, LinearClassifierMulticlassDoubleInput) {
167167
LinearClassifierMulticlass<double>();
168168
}
169+
170+
// Regression test: coefficients size doesn't match class_count * num_features.
171+
TEST(MLOpTest, LinearClassifierInvalidCoefficientsSizeFails) {
172+
OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain);
173+
174+
// 3 intercepts => class_count = 3, input has 2 features => expects 6 coefficients.
175+
std::vector<float> coefficients = {-0.22562418f, 0.34188559f, 0.68346153f};
176+
std::vector<int64_t> classes = {1, 2, 3};
177+
std::vector<float> intercepts = {-3.91601811f, 0.42575697f, 0.13731251f};
178+
179+
test.AddAttribute("coefficients", coefficients);
180+
test.AddAttribute("intercepts", intercepts);
181+
test.AddAttribute("classlabels_ints", classes);
182+
183+
test.AddInput<float>("X", {1, 2}, {1.f, 0.f});
184+
test.AddOutput<int64_t>("Y", {1}, {0LL});
185+
test.AddOutput<float>("Z", {1, 3}, {0.f, 0.f, 0.f});
186+
187+
test.Run(OpTester::ExpectResult::kExpectFailure,
188+
"LinearClassifier: coefficients length (3) is less than classes (3) * features (2)");
189+
}
190+
191+
// Regression test: coefficients not divisible by class_count.
192+
TEST(MLOpTest, LinearClassifierCoefficientsSizeNotDivisibleByClassCountFails) {
193+
OpTester test("LinearClassifier", 1, onnxruntime::kMLDomain);
194+
195+
// 3 intercepts => class_count = 3, but 5 coefficients is not divisible by 3.
196+
std::vector<float> coefficients = {1.f, 2.f, 3.f, 4.f, 5.f};
197+
std::vector<int64_t> classes = {1, 2, 3};
198+
std::vector<float> intercepts = {0.1f, 0.2f, 0.3f};
199+
200+
test.AddAttribute("coefficients", coefficients);
201+
test.AddAttribute("intercepts", intercepts);
202+
test.AddAttribute("classlabels_ints", classes);
203+
204+
test.AddInput<float>("X", {1, 2}, {1.f, 0.f});
205+
test.AddOutput<int64_t>("Y", {1}, {0LL});
206+
test.AddOutput<float>("Z", {1, 3}, {0.f, 0.f, 0.f});
207+
208+
test.Run(OpTester::ExpectResult::kExpectFailure,
209+
"coefficients size (5) must be a multiple of the number of classes (3)");
210+
}
169211
} // namespace test
170212
} // namespace onnxruntime

0 commit comments

Comments
 (0)