From 3ef5882e37657bf4d39acc17d1411c35d831109d Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 6 Mar 2026 22:47:02 +0000 Subject: [PATCH 1/6] Validate g_idx values in MatMulNBits to prevent OOB read --- .../cpu/quantization/matmul_nbits_helper.h | 13 +++ .../test/contrib_ops/matmul_4bits_test.cc | 84 +++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h index 25bcb3932795b..34f5664cff597 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h @@ -67,6 +67,19 @@ Status CheckInputs(const T* /*activation*/, // Group_index shall be 1D of K, or K padded to multiple of block_size ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size)); + // Validate group_index values are within valid range [0, k_blocks) + if (group_index != nullptr) { + auto g_idx_data = group_index->template Data(); + auto g_idx_size = group_index->Shape().Size(); + for (int64_t i = 0; i < g_idx_size; ++i) { + if (g_idx_data[i] < 0 || g_idx_data[i] >= k_blocks) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "group_index value at index ", i, " is ", g_idx_data[i], + ", which is out of valid range [0, ", k_blocks, ")"); + } + } + } + ASSERT_TENSOR_SHAPE(bias, make_shape(n)); return Status::OK(); diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index fbbdc419118cd..4b530a65306cf 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -880,6 +880,90 @@ TEST(MatMulNBits, Basic_M10_N128_K512) { } #endif +// Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT. 
+TEST(MatMulNBits, InvalidGIdx_OutOfRange) { + constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 2 + constexpr int64_t blob_size = block_size * QBits / 8; // 8 + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + // A: [M, K] + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + // B: [N, k_blocks, blob_size] + std::vector b_data(N * k_blocks * blob_size, 0); + test.AddInput("B", {N, k_blocks, blob_size}, b_data, true); + + // scales: [N, k_blocks] + std::vector scales(N * k_blocks, 1.0f); + test.AddInput("scales", {N, k_blocks}, scales, true); + + // zero_points: optional (skip) + test.AddOptionalInputEdge(); + + // g_idx with out-of-range values (valid range is [0, k_blocks) = [0, 2)) + std::vector g_idx(K); + for (int64_t i = 0; i < K; i++) { + g_idx[i] = 99999; // way out of range + } + test.AddInput("g_idx", {K}, g_idx, true); + + // bias: optional (skip) + test.AddOptionalInputEdge(); + + // Output placeholder (won't actually be compared since we expect failure) + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value"); +} + +// Test that negative g_idx values are rejected. 
+TEST(MatMulNBits, InvalidGIdx_Negative) { + constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; + constexpr int64_t blob_size = block_size * QBits / 8; + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + std::vector b_data(N * k_blocks * blob_size, 0); + test.AddInput("B", {N, k_blocks, blob_size}, b_data, true); + + std::vector scales(N * k_blocks, 1.0f); + test.AddInput("scales", {N, k_blocks}, scales, true); + + test.AddOptionalInputEdge(); + + // g_idx with negative values + std::vector g_idx(K); + for (int64_t i = 0; i < K; i++) { + g_idx[i] = -1; + } + test.AddInput("g_idx", {K}, g_idx, true); + + test.AddOptionalInputEdge(); + + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value"); +} + } // namespace test } // namespace onnxruntime From a24153b76849d28b405bf0b6b7fea071b5ee52a9 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 11 Mar 2026 23:00:06 +0000 Subject: [PATCH 2/6] Enhance group_index validation in CheckInputs to ensure CPU device type and update test for out-of-range g_idx values --- .../contrib_ops/cpu/quantization/matmul_nbits_helper.h | 6 +++--- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 2 ++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h index 34f5664cff597..072471d3b83a3 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_helper.h @@ -68,10 +68,10 @@ Status CheckInputs(const T* 
/*activation*/, ASSERT_TENSOR_SHAPE_2(group_index, make_shape(k), make_shape(k_blocks * block_size)); // Validate group_index values are within valid range [0, k_blocks) - if (group_index != nullptr) { + if (group_index != nullptr && group_index->Location().device.Type() == OrtDevice::CPU) { auto g_idx_data = group_index->template Data(); - auto g_idx_size = group_index->Shape().Size(); - for (int64_t i = 0; i < g_idx_size; ++i) { + auto g_idx_size = static_cast<size_t>(group_index->Shape().Size()); + for (size_t i = 0; i < g_idx_size; ++i) { if (g_idx_data[i] < 0 || g_idx_data[i] >= k_blocks) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "group_index value at index ", i, " is ", g_idx_data[i], diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 4b530a65306cf..1aeaf59bc0632 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -880,6 +880,7 @@ TEST(MatMulNBits, Basic_M10_N128_K512) { } #endif +#if !defined(USE_DML) && !defined(USE_WEBGPU) // Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT. 
TEST(MatMulNBits, InvalidGIdx_OutOfRange) { constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; @@ -963,6 +964,7 @@ TEST(MatMulNBits, InvalidGIdx_Negative) { test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value"); } +#endif // !defined(USE_DML) && !defined(USE_WEBGPU) } // namespace test } // namespace onnxruntime From 9788fb8ea4b5378d5b65e8567d371f6ba3aceae4 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 25 Mar 2026 20:31:03 +0000 Subject: [PATCH 3/6] Add validation for group_index in Dequantize4BitsKernelReOrder and update tests for out-of-range g_idx values --- .../cuda/quantization/dequantize_blockwise_4bits.cu | 1 + onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 8 ++++---- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu index cbcd4ed2f54a0..087ab2a31cafd 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu @@ -112,6 +112,7 @@ __global__ void Dequantize4BitsKernelReOrder( const int32_t* reorder_idx_with_off = reorder_idx + kb_idx * block_size + ((threadIdx.x * element_per_thread) & (block_size - 1)); for (int i = 0; i < element_per_thread; i++) { int32_t rid = reorder_idx_with_off[i]; + CUDA_KERNEL_ASSERT(rid >= 0 && rid < groups_per_K); T scale = *(scale_data + n_idx * scales_shape_x + rid); uint8_t zp = 8; // Default zero point is 1 << (bits - 1) if (zero_points) { diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 1aeaf59bc0632..0c77eff407286 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -880,7 +880,6 @@ TEST(MatMulNBits, Basic_M10_N128_K512) { } #endif -#if !defined(USE_DML) && !defined(USE_WEBGPU) // Test that 
out-of-range g_idx values are rejected with INVALID_ARGUMENT. TEST(MatMulNBits, InvalidGIdx_OutOfRange) { constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; @@ -923,7 +922,8 @@ TEST(MatMulNBits, InvalidGIdx_OutOfRange) { std::vector y_data(M * N, 0.0f); test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value"); + test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", + {kDmlExecutionProvider, kWebGpuExecutionProvider}); } // Test that negative g_idx values are rejected. @@ -962,9 +962,9 @@ TEST(MatMulNBits, InvalidGIdx_Negative) { std::vector y_data(M * N, 0.0f); test.AddOutput("Y", {M, N}, y_data); - test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value"); + test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", + {kDmlExecutionProvider, kWebGpuExecutionProvider}); } -#endif // !defined(USE_DML) && !defined(USE_WEBGPU) } // namespace test } // namespace onnxruntime From 796470115956b325726f6bcc33510af46e8fab89 Mon Sep 17 00:00:00 2001 From: vraspar Date: Wed, 1 Apr 2026 22:18:34 +0000 Subject: [PATCH 4/6] Exclude CUDA EP and skip InvalidGIdx tests in debug builds --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 0c77eff407286..3fdf3138347a3 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -881,6 +881,8 @@ TEST(MatMulNBits, Basic_M10_N128_K512) { #endif // Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT. +// Skip in debug builds to avoid hitting CUDA_KERNEL_ASSERT which corrupts the CUDA device context. 
+#ifdef NDEBUG TEST(MatMulNBits, InvalidGIdx_OutOfRange) { constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 2 @@ -923,7 +925,7 @@ TEST(MatMulNBits, InvalidGIdx_OutOfRange) { test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", - {kDmlExecutionProvider, kWebGpuExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider}); } // Test that negative g_idx values are rejected. @@ -963,8 +965,9 @@ TEST(MatMulNBits, InvalidGIdx_Negative) { test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", - {kDmlExecutionProvider, kWebGpuExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider}); } +#endif // NDEBUG } // namespace test } // namespace onnxruntime From f8627a3607eb35f4d79ccd5f097d1d08a62e36c4 Mon Sep 17 00:00:00 2001 From: vraspar Date: Fri, 3 Apr 2026 21:53:18 +0000 Subject: [PATCH 5/6] Exclude OpenVINO EP from InvalidGIdx tests --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 3fdf3138347a3..29b7ed7961a93 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -925,7 +925,8 @@ TEST(MatMulNBits, InvalidGIdx_OutOfRange) { test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider, + kOpenVINOExecutionProvider}); } // Test that negative g_idx values are rejected. 
@@ -965,7 +966,8 @@ TEST(MatMulNBits, InvalidGIdx_Negative) { test.AddOutput("Y", {M, N}, y_data); test.Run(OpTester::ExpectResult::kExpectFailure, "group_index value", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider, + kOpenVINOExecutionProvider}); } #endif // NDEBUG From 456c6aa9d2ce033e4918e3d76eb362b4375cf966 Mon Sep 17 00:00:00 2001 From: vraspar Date: Mon, 6 Apr 2026 21:18:28 +0000 Subject: [PATCH 6/6] Add release-safe clamp in CUDA kernel and remove redundant NDEBUG guard - Add rid clamping after CUDA_KERNEL_ASSERT in Dequantize4BitsKernelReOrder to prevent OOB memory access in release builds where the assert is a no-op - Remove unnecessary #ifdef NDEBUG guard around InvalidGIdx tests since CUDA EP is already excluded via OpTester::Run() parameters Addresses review feedback from tianleiwu. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../cuda/quantization/dequantize_blockwise_4bits.cu | 1 + onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 4 +--- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu index 087ab2a31cafd..5c4501945cecf 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_4bits.cu @@ -113,6 +113,7 @@ __global__ void Dequantize4BitsKernelReOrder( for (int i = 0; i < element_per_thread; i++) { int32_t rid = reorder_idx_with_off[i]; CUDA_KERNEL_ASSERT(rid >= 0 && rid < groups_per_K); + rid = max(0, min(rid, groups_per_K - 1)); // Clamp for release safety T scale = *(scale_data + n_idx * scales_shape_x + rid); uint8_t zp = 8; // Default zero point is 1 << (bits - 1) if (zero_points) { diff --git 
a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 29b7ed7961a93..b463aa3a6c363 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -881,8 +881,7 @@ TEST(MatMulNBits, Basic_M10_N128_K512) { #endif // Test that out-of-range g_idx values are rejected with INVALID_ARGUMENT. -// Skip in debug builds to avoid hitting CUDA_KERNEL_ASSERT which corrupts the CUDA device context. -#ifdef NDEBUG +// CUDA EP is excluded from these tests, so no risk of hitting CUDA_KERNEL_ASSERT. TEST(MatMulNBits, InvalidGIdx_OutOfRange) { constexpr int64_t M = 2, N = 4, K = 32, block_size = 16; constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 2 @@ -969,7 +968,6 @@ TEST(MatMulNBits, InvalidGIdx_Negative) { {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider, kWebGpuExecutionProvider, kOpenVINOExecutionProvider}); } -#endif // NDEBUG } // namespace test } // namespace onnxruntime