Skip to content

Commit e39541c

Browse files
vraspar and Copilot authored
Add tensor size validation for MatMulBnb4 to prevent OOB read via K/N attribute mismatch (#27995)
### Description

Validates that b_quant and absmax tensor sizes are consistent with K/N/block_size attributes before dequantization in the MatMulBnb4 operator.

Fixes https://portal.microsofticm.com/imp/v5/incidents/details/31000000559964/summary

### Changes

- **Constructor**: Validate K > 0, N > 0, block_size > 0
- **Compute()**: Validate b_quant size >= (K*N+1)/2 and absmax size >= ceil(K*N/block_size)
- **Tests**: Two regression tests for undersized b_quant and absmax tensors

### Motivation and Context

MSRC case 109215: A crafted model can set K/N attributes larger than actual tensor sizes, causing OOB reads from b_quant_data and absmax_data during dequantization.

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 57161cc commit e39541c

3 files changed

Lines changed: 140 additions & 0 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ class MatMulBnb4 final : public OpKernel {
1919
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
2020
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2121
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
22+
ORT_ENFORCE(K_ > 0, "K must be positive, got ", K_);
23+
ORT_ENFORCE(N_ > 0, "N must be positive, got ", N_);
24+
ORT_ENFORCE(block_size_ > 0, "block_size must be positive, got ", block_size_);
2225
ORT_ENFORCE(
2326
quant_type_ == FP4 || quant_type_ == NF4,
2427
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
@@ -50,6 +53,32 @@ Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
5053
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
5154
const float* absmax_data = absmax->Data<float>();
5255

56+
// Overflow-safe computation of expected tensor sizes.
57+
// K_, N_, block_size_ are validated > 0 in the constructor.
58+
if (K_ > std::numeric_limits<int64_t>::max() / N_) {
59+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
60+
"Overflow computing K * N for K=", K_, ", N=", N_, ".");
61+
}
62+
const int64_t numel = K_ * N_;
63+
// Overflow-safe ceiling division: rewrite (a + b - 1) / b as ((a - 1) / b) + 1.
64+
// Safe because numel > 0 (K_ > 0 and N_ > 0 validated in constructor).
65+
const int64_t expected_b_quant_size = ((numel - 1) / 2) + 1;
66+
const int64_t expected_absmax_size = ((numel - 1) / block_size_) + 1;
67+
68+
if (b_quant->Shape().Size() < expected_b_quant_size) {
69+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
70+
"b_quant tensor size (", b_quant->Shape().Size(),
71+
") is too small for K=", K_, " and N=", N_,
72+
". Expected at least ", expected_b_quant_size, " elements.");
73+
}
74+
if (absmax->Shape().Size() < expected_absmax_size) {
75+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
76+
"absmax tensor size (", absmax->Shape().Size(),
77+
") is too small for K=", K_, ", N=", N_,
78+
", block_size=", block_size_,
79+
". Expected at least ", expected_absmax_size, " elements.");
80+
}
81+
5382
AllocatorPtr allocator;
5483
auto status = ctx->GetTempSpaceAllocator(&allocator);
5584
ORT_RETURN_IF_ERROR(status);

onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ class MatMulBnb4 final : public CudaKernel {
2222
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
2323
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
2424
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
25+
ORT_ENFORCE(K_ > 0, "K must be positive, got ", K_);
26+
ORT_ENFORCE(N_ > 0, "N must be positive, got ", N_);
27+
ORT_ENFORCE(block_size_ > 0, "block_size must be positive, got ", block_size_);
2528
ORT_ENFORCE(
2629
quant_type_ == FP4 || quant_type_ == NF4,
2730
"Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
@@ -51,6 +54,32 @@ Status MatMulBnb4<T>::ComputeInternal(OpKernelContext* ctx) const {
5154
const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
5255
const auto* absmax_data = absmax->Data<T>();
5356

57+
// Overflow-safe computation of expected tensor sizes.
58+
// K_, N_, block_size_ are validated > 0 in the constructor.
59+
if (K_ > std::numeric_limits<int64_t>::max() / N_) {
60+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
61+
"Overflow computing K * N for K=", K_, ", N=", N_, ".");
62+
}
63+
const int64_t numel = K_ * N_;
64+
// Overflow-safe ceiling division: rewrite (a + b - 1) / b as ((a - 1) / b) + 1.
65+
// Safe because numel > 0 (K_ > 0 and N_ > 0 validated in constructor).
66+
const int64_t expected_b_quant_size = ((numel - 1) / 2) + 1;
67+
const int64_t expected_absmax_size = ((numel - 1) / block_size_) + 1;
68+
69+
if (b_quant->Shape().Size() < expected_b_quant_size) {
70+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
71+
"b_quant tensor size (", b_quant->Shape().Size(),
72+
") is too small for K=", K_, " and N=", N_,
73+
". Expected at least ", expected_b_quant_size, " elements.");
74+
}
75+
if (absmax->Shape().Size() < expected_absmax_size) {
76+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
77+
"absmax tensor size (", absmax->Shape().Size(),
78+
") is too small for K=", K_, ", N=", N_,
79+
", block_size=", block_size_,
80+
". Expected at least ", expected_absmax_size, " elements.");
81+
}
82+
5483
typedef typename ToCudaType<T>::MappedType CudaT;
5584

5685
// TODO: find a better way to create the quant_map without using a buffer

onnxruntime/test/contrib_ops/matmul_bnb4_test.cc

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,88 @@ void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_
115115
}
116116
}
117117

118+
TEST(MatMulBnb4, RejectsUndersizedBQuantTensor) {
119+
// K=32, N=2 → numel=64, expected b_quant size = (64+1)/2 = 32
120+
// Provide only 4 bytes (valid for K=4, N=2) but claim K=32, N=2
121+
OpTester test("MatMulBnb4", 1, kMSDomain);
122+
test.AddAttribute<int64_t>("K", 32LL);
123+
test.AddAttribute<int64_t>("N", 2LL);
124+
test.AddAttribute<int64_t>("block_size", 32LL);
125+
test.AddAttribute<int64_t>("quant_type", 1LL); // NF4
126+
127+
test.AddInput<float>("A", {1, 32}, std::vector<float>(32, 0.0f));
128+
test.AddInput<uint8_t>("B", {4}, std::vector<uint8_t>(4, 0)); // too small
129+
test.AddInput<float>("absmax", {2}, std::vector<float>(2, 1.0f));
130+
test.AddOutput<float>("Y", {1, 2}, std::vector<float>(2, 0.0f));
131+
132+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
133+
execution_providers.push_back(DefaultCpuExecutionProvider());
134+
test.Run(OpTester::ExpectResult::kExpectFailure, "b_quant tensor size", {}, nullptr, &execution_providers);
135+
}
136+
137+
TEST(MatMulBnb4, RejectsUndersizedAbsmaxTensor) {
138+
// K=32, N=2, block_size=32 → numel=64, expected absmax size = (64+32-1)/32 = 2
139+
// Provide only 1 absmax element
140+
int64_t K = 32, N = 2, block_size = 32;
141+
int64_t numel = K * N;
142+
int64_t quantized_numel = (numel + 1) / 2;
143+
144+
OpTester test("MatMulBnb4", 1, kMSDomain);
145+
test.AddAttribute<int64_t>("K", K);
146+
test.AddAttribute<int64_t>("N", N);
147+
test.AddAttribute<int64_t>("block_size", block_size);
148+
test.AddAttribute<int64_t>("quant_type", 1LL); // NF4
149+
150+
test.AddInput<float>("A", {1, K}, std::vector<float>(K, 0.0f));
151+
test.AddInput<uint8_t>("B", {quantized_numel}, std::vector<uint8_t>(quantized_numel, 0));
152+
test.AddInput<float>("absmax", {1}, std::vector<float>(1, 1.0f)); // too small
153+
test.AddOutput<float>("Y", {1, N}, std::vector<float>(N, 0.0f));
154+
155+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
156+
execution_providers.push_back(DefaultCpuExecutionProvider());
157+
test.Run(OpTester::ExpectResult::kExpectFailure, "absmax tensor size", {}, nullptr, &execution_providers);
158+
}
159+
160+
#if defined(USE_CUDA)
161+
TEST(MatMulBnb4, RejectsUndersizedBQuantTensorCuda) {
162+
OpTester test("MatMulBnb4", 1, kMSDomain);
163+
test.AddAttribute<int64_t>("K", 32LL);
164+
test.AddAttribute<int64_t>("N", 2LL);
165+
test.AddAttribute<int64_t>("block_size", 32LL);
166+
test.AddAttribute<int64_t>("quant_type", 1LL); // NF4
167+
168+
test.AddInput<float>("A", {1, 32}, std::vector<float>(32, 0.0f));
169+
test.AddInput<uint8_t>("B", {4}, std::vector<uint8_t>(4, 0)); // too small
170+
test.AddInput<float>("absmax", {2}, std::vector<float>(2, 1.0f));
171+
test.AddOutput<float>("Y", {1, 2}, std::vector<float>(2, 0.0f));
172+
173+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
174+
execution_providers.push_back(DefaultCudaExecutionProvider());
175+
test.Run(OpTester::ExpectResult::kExpectFailure, "b_quant tensor size", {}, nullptr, &execution_providers);
176+
}
177+
178+
TEST(MatMulBnb4, RejectsUndersizedAbsmaxTensorCuda) {
179+
int64_t K = 32, N = 2, block_size = 32;
180+
int64_t numel = K * N;
181+
int64_t quantized_numel = (numel + 1) / 2;
182+
183+
OpTester test("MatMulBnb4", 1, kMSDomain);
184+
test.AddAttribute<int64_t>("K", K);
185+
test.AddAttribute<int64_t>("N", N);
186+
test.AddAttribute<int64_t>("block_size", block_size);
187+
test.AddAttribute<int64_t>("quant_type", 1LL); // NF4
188+
189+
test.AddInput<float>("A", {1, K}, std::vector<float>(K, 0.0f));
190+
test.AddInput<uint8_t>("B", {quantized_numel}, std::vector<uint8_t>(quantized_numel, 0));
191+
test.AddInput<float>("absmax", {1}, std::vector<float>(1, 1.0f)); // too small
192+
test.AddOutput<float>("Y", {1, N}, std::vector<float>(N, 0.0f));
193+
194+
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
195+
execution_providers.push_back(DefaultCudaExecutionProvider());
196+
test.Run(OpTester::ExpectResult::kExpectFailure, "absmax tensor size", {}, nullptr, &execution_providers);
197+
}
198+
#endif
199+
118200
TEST(MatMulBnb4, DISABLED_Float32) {
119201
for (auto qt : {0, 1}) {
120202
for (auto M : {1, 2, 100}) {

0 commit comments

Comments
 (0)