Skip to content

Commit 127704c

Browse files
vraspar and Copilot authored
Validate g_idx values in MatMulNBits to prevent OOB read (#27582)
### Description

In `Dequantize4BitsKernelReOrder` (CPU and CUDA EP), values from the `g_idx` tensor are used directly as array indices into the `scales` and `zero_points` buffers without bounds checking. This PR adds value-range validation and tests for the `g_idx` input tensor in the `MatMulNBits` operator.

### Motivation and Context

<!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 9d7e6d5 commit 127704c

3 files changed

Lines changed: 104 additions & 0 deletions

File tree

onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,19 @@ Status CheckInputs(const T* /*activation*/,
6767
// Group_index shall be 1D of K, or K padded to multiple of block_size
6868
ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size));
6969

70+
// Validate group_index values are within valid range [0, k_blocks)
71+
if (group_index != nullptr && group_index->Location().device.Type() == OrtDevice::CPU) {
72+
auto g_idx_data = group_index->template Data<int32_t>();
73+
auto g_idx_size = static_cast<size_t>(group_index->Shape().Size());
74+
for (size_t i = 0; i < g_idx_size; ++i) {
75+
if (g_idx_data[i] < 0 || g_idx_data[i] >= k_blocks) {
76+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
77+
"group_index value at index ", i, " is ", g_idx_data[i],
78+
", which is out of valid range [0, ", k_blocks, ")");
79+
}
80+
}
81+
}
82+
7083
ASSERT_TENSOR_SHAPE(bias, make_shape(n));
7184

7285
return Status::OK();

onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ __global__ void Dequantize4BitsKernelReOrder(
111111
const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1));
112112
for (int i = 0; i < element_per_thread; i++) {
113113
int32_t rid = reorder_idx_with_off[i];
114+
CUDA_KERNEL_ASSERT(rid >= 0 && rid < groups_per_K);
115+
rid = max(0, min(rid, groups_per_K - 1)); // Clamp for release safety
114116
T scale = *(scale_data + n_idx * scales_shape_x + rid);
115117
uint8_t zp = 8; // Default zero point is 1 << (bits - 1)
116118
if (zero_points) {

onnxruntime/test/contrib_ops/matmul_4bits_test.cc

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,95 @@ TEST(MatMulNBits, Basic_M10_N128_K512) {
880880
}
881881
#endif
882882

883+
// Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT.
884+
// CUDA EP is excluded from these tests, so no risk of hitting CUDA_KERNEL_ASSERT.
885+
TEST(MatMulNBits, InvalidGIdx_OutOfRange) {
886+
constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
887+
constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 2
888+
constexpr int64_t blob_size = block_size * QBits / 8; // 8
889+
890+
OpTester test("MatMulNBits", 1, kMSDomain);
891+
test.AddAttribute<int64_t>("K", K);
892+
test.AddAttribute<int64_t>("N", N);
893+
test.AddAttribute<int64_t>("block_size", block_size);
894+
test.AddAttribute<int64_t>("bits", QBits);
895+
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
896+
897+
// A: [M, K]
898+
std::vector<float> a_data(M * K, 1.0f);
899+
test.AddInput<float>("A", {M, K}, a_data, false);
900+
901+
// B: [N, k_blocks, blob_size]
902+
std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
903+
test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
904+
905+
// scales: [N, k_blocks]
906+
std::vector<float> scales(N * k_blocks, 1.0f);
907+
test.AddInput<float>("scales", {N, k_blocks}, scales, true);
908+
909+
// zero_points: optional (skip)
910+
test.AddOptionalInputEdge<uint8_t>();
911+
912+
// g_idx with out-of-range values (valid range is [0, k_blocks) = [0, 2))
913+
std::vector<int32_t> g_idx(K);
914+
for (int64_t i = 0; i < K; i++) {
915+
g_idx[i] = 99999; // way out of range
916+
}
917+
test.AddInput<int32_t>("g_idx", {K}, g_idx, true);
918+
919+
// bias: optional (skip)
920+
test.AddOptionalInputEdge<float>();
921+
922+
// Output placeholder (won't actually be compared since we expect failure)
923+
std::vector<float> y_data(M * N, 0.0f);
924+
test.AddOutput<float>("Y", {M, N}, y_data);
925+
926+
test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
927+
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
928+
kOpenVINOExecutionProvider});
929+
}
930+
931+
// Test that negative g_idx values are rejected.
932+
TEST(MatMulNBits, InvalidGIdx_Negative) {
933+
constexpr int64_t M = 2, N = 4, K = 32, block_size = 16;
934+
constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
935+
constexpr int64_t blob_size = block_size * QBits / 8;
936+
937+
OpTester test("MatMulNBits", 1, kMSDomain);
938+
test.AddAttribute<int64_t>("K", K);
939+
test.AddAttribute<int64_t>("N", N);
940+
test.AddAttribute<int64_t>("block_size", block_size);
941+
test.AddAttribute<int64_t>("bits", QBits);
942+
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
943+
944+
std::vector<float> a_data(M * K, 1.0f);
945+
test.AddInput<float>("A", {M, K}, a_data, false);
946+
947+
std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
948+
test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
949+
950+
std::vector<float> scales(N * k_blocks, 1.0f);
951+
test.AddInput<float>("scales", {N, k_blocks}, scales, true);
952+
953+
test.AddOptionalInputEdge<uint8_t>();
954+
955+
// g_idx with negative values
956+
std::vector<int32_t> g_idx(K);
957+
for (int64_t i = 0; i < K; i++) {
958+
g_idx[i] = -1;
959+
}
960+
test.AddInput<int32_t>("g_idx", {K}, g_idx, true);
961+
962+
test.AddOptionalInputEdge<float>();
963+
964+
std::vector<float> y_data(M * N, 0.0f);
965+
test.AddOutput<float>("Y", {M, N}, y_data);
966+
967+
test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value",
968+
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider,
969+
kOpenVINOExecutionProvider});
970+
}
971+
883972
} // namespace test
884973
} // namespace onnxruntime
885974

0 commit comments

Comments
 (0)