From 449be5c2e23dd2f7440a4d81c2ca0efb7cdd5abd Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Thu, 7 May 2026 10:56:42 +0100 Subject: [PATCH 1/6] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel Signed-off-by: Jonathan Clohessy --- .../mlas/lib/kleidiai/convolve_kleidiai.cpp | 9 +-- onnxruntime/core/providers/cpu/math/matmul.cc | 11 +++- onnxruntime/core/providers/cpu/math/matmul.h | 2 + .../cpu/math/matmul_fastmath_test.cc | 59 +++++++++++++++++++ 4 files changed, 73 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index 8dcaff50ef5d0..bc093be2abf61 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -44,14 +44,12 @@ struct LhsCacheKey { size_t padding, sh, sw; size_t kh, kw; size_t dilationh, dilationw; - size_t data_hash; bool operator==(const LhsCacheKey& other) const { return ci == other.ci && ih == other.ih && iw == other.iw && padding == other.padding && sh == other.sh && sw == other.sw && kh == other.kh && kw == other.kw && - dilationh == other.dilationh && dilationw == other.dilationw && - data_hash == other.data_hash; + dilationh == other.dilationh && dilationw == other.dilationw; } }; @@ -88,7 +86,7 @@ namespace std { template<> struct hash { size_t operator()(const LhsCacheKey& k) const { - return k.data_hash ^ + return (std::hash()(k.ci) << 1) ^ (std::hash()(k.ih) << 2) ^ (std::hash()(k.iw) << 3) ^ @@ -506,8 +504,7 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s ci, ih, iw, padding, sh, sw, kh, kw, - 1, 1, - HashWeights(in) + 1, 1 }; auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr]; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ca91f46db93da..8d087a2f1482c 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -253,7 +253,11 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc dim2 = static_cast(b_shape[1]); } - if (use_fastmath_mode_ && (trans_a_attr_ == 0) && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + if (use_fastmath_mode_ && + (trans_a_attr_ == 0) && + (trans_b_attr_ == 0) && + ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold) && + ((dim1 % kFastMathModeKAlignment) == 0)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif @@ -323,7 +327,10 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) - if (use_fastmath_mode_ && !trans_a && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + if (use_fastmath_mode_ && + !trans_a && !trans_b && + ((N * K) >= kFastMathModeKernelsizeThreshold) && + ((K % kFastMathModeKAlignment) == 0)) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { data[i].BIsfp32 = !(bool(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index d1c0df19f924e..c3e93846f670f 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -102,6 +102,8 @@ class MatMul final : public OpKernel { bool use_fastmath_mode_; // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead + // The NEON SBGemm kernel consumes A in 4-float groups. Keep K tails on SGEMM + const size_t kFastMathModeKAlignment = 4; const size_t kFastMathModeKernelsizeThreshold = 32; #endif }; diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 8e6ff1f387bf1..1ff67e1948f8b 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -5,9 +5,11 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/util/include/test_environment.h" #include "default_providers.h" #if defined(__aarch64__) && defined(__linux__) @@ -168,6 +170,63 @@ void RunMatMulTest(int32_t opset_version) { RunMatMulTest(opset_version, false, false, false); } +TEST(MathOpTest, MatMulFloatTypeFastMathKTailFallsBackToSgemm) { + constexpr int64_t M = 1; + constexpr int64_t N = 8; + constexpr int64_t K = 13; + + OpTester test("MatMul", 7); + + std::vector input0_vals(K); + std::iota(input0_vals.begin(), input0_vals.end(), 1.0f); + test.AddInput("A", {M, K}, input0_vals); + + std::vector input1_vals(K * N, 1.0f); + test.AddInput("B", {K, N}, input1_vals, true); + + std::vector expected_vals(N, 91.0f); + test.AddOutput("Y", {M, N}, expected_vals); + + Model& model = test.BuildModel(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + std::string serialized_model; + ASSERT_TRUE(model.ToProto().SerializeToString(&serialized_model)); + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCpuExecutionProvider())); + + std::stringstream model_stream(serialized_model); + ASSERT_STATUS_OK(session_object.Load(model_stream)); + ASSERT_STATUS_OK(session_object.Initialize()); + + std::vector input0_backing(input0_vals.begin(), input0_vals.end()); + input0_backing.resize(K + 3, std::numeric_limits::quiet_NaN()); + + OrtValue input0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({M, K}), input0_backing.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), input0); + + NameMLValMap feeds; + feeds.insert(std::make_pair(std::string("A"), input0)); + + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); + + const auto& output_tensor = fetches[0].Get(); + ASSERT_EQ(output_tensor.Shape(), TensorShape({M, N})); + + const auto* output_data = output_tensor.Data(); + for (int64_t i = 0; i < N; ++i) { + ASSERT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " should not include padded A tail values."; + ASSERT_EQ(output_data[i], expected_vals[i]) << "Output " << i; + } +} + TEST(MathOpTest, MatMulFloatType_FastMath) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { From e1b1f233a78ea629b9dce60c764f82805985c4ff Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Thu, 7 May 2026 11:32:42 +0100 Subject: [PATCH 2/6] Revert convolve change and update matmul condition to deal with kdim correctly Signed-off-by: Jonathan Clohessy --- onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp | 9 ++++++--- onnxruntime/core/providers/cpu/math/matmul.cc | 7 +++++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index bc093be2abf61..8dcaff50ef5d0 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -44,12 +44,14 @@ struct LhsCacheKey { size_t padding, sh, sw; size_t kh, kw; size_t dilationh, dilationw; + size_t data_hash; bool operator==(const LhsCacheKey& other) const { return ci == other.ci && ih == other.ih && iw == other.iw && padding == other.padding && sh == other.sh && sw == other.sw && kh == other.kh && kw == other.kw && - dilationh == other.dilationh && dilationw == other.dilationw; + dilationh == other.dilationh && dilationw == other.dilationw && + data_hash == other.data_hash; } }; @@ -86,7 +88,7 @@ namespace std { template<> struct hash { size_t operator()(const LhsCacheKey& k) const { - return + return k.data_hash ^ (std::hash()(k.ci) << 1) ^ (std::hash()(k.ih) << 2) ^ (std::hash()(k.iw) << 3) ^ @@ -504,7 +506,8 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s ci, ih, iw, padding, sh, sw, kh, kw, - 1, 1 + 1, 1, + HashWeights(in) }; auto& lhs_ptrs_cache = lhs_ptrs_cache_by_pad[cur_pad_ptr]; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 8d087a2f1482c..4402a4b1526d0 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -253,11 +253,14 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc dim2 = static_cast(b_shape[1]); } + const size_t k_dim = b_shape.NumDimensions() >= 2 + ? static_cast(b_shape[b_shape.NumDimensions() - 2]) + : dim1; if (use_fastmath_mode_ && (trans_a_attr_ == 0) && (trans_b_attr_ == 0) && - ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold) && - ((dim1 % kFastMathModeKAlignment) == 0)) { + ((k_dim * dim2) >= kFastMathModeKernelsizeThreshold) && + ((k_dim % kFastMathModeKAlignment) == 0)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif From ae96d11bb36e29b0446dcffb50317fcc3a031a36 Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Sat, 23 May 2026 08:40:34 +0100 Subject: [PATCH 3/6] Update test with guards for tensor validity Signed-off-by: Jonathan Clohessy --- onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 1ff67e1948f8b..dc5b42a517ec8 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -216,6 +216,8 @@ TEST(MathOpTest, MatMulFloatTypeFastMathKTailFallsBackToSgemm) { std::vector fetches; ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); + ASSERT_EQ(fetches.size(), 1u); + ASSERT_TRUE(fetches[0].IsTensor()); const auto& output_tensor = fetches[0].Get(); ASSERT_EQ(output_tensor.Shape(), TensorShape({M, N})); From c766837a474e60f4d0118da0866b3b9ac55ab239 Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Thu, 28 May 2026 15:12:19 +0100 Subject: [PATCH 4/6] Address review comments Signed-off-by: Jonathan Clohessy --- onnxruntime/core/mlas/lib/sbgemm.h | 9 +- .../core/mlas/lib/sbgemm_kernel_neon.cpp | 47 +++++++++- onnxruntime/core/providers/cpu/math/matmul.cc | 33 +++---- onnxruntime/core/providers/cpu/math/matmul.h | 14 ++- .../test/mlas/unittest/test_sbgemm.cpp | 1 + onnxruntime/test/mlas/unittest/test_sbgemm.h | 4 + .../cpu/math/matmul_fastmath_test.cc | 88 +++++++++++-------- 7 files changed, 130 insertions(+), 66 deletions(-) diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 35520c7a4226a..99e3912910b16 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -134,7 +134,8 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size bool ZeroMode = (k == 0) && InitialZeroMode; CountK = std::min(K - k, PackedStrideK); - const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + const size_t AlignedCountK = (CountK + KernelType::PackedK - 1) & ~(KernelType::PackedK - 1); + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + AlignedCountK * SliceStartN; float* c = C + n; const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); @@ -319,7 +320,7 @@ MlasSBGemmPackBSize( BIsfp32) { size_t bytes_required; bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(TransA, TransB, N, K); - if (bytes_required != 0){ // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion + if (bytes_required != 0) { // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion return bytes_required; } } @@ -365,7 +366,7 @@ MlasSBGemmConvertPackB( TransA == CBLAS_TRANSPOSE::CblasNoTrans && TransB == CBLAS_TRANSPOSE::CblasNoTrans && BIsfp32 && - GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)) { return; } #endif @@ -400,7 +401,7 @@ MlasSBGemmBatch( TransB == CBLAS_TRANSPOSE::CblasNoTrans && Data->AIsfp32 && (Data->BIsPacked || Data->BIsfp32) && - GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)){ + GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)) { return; } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp index a6a73996c548b..00abcb31e284f 100644 --- a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -17,6 +17,10 @@ Module Name: #if defined(__aarch64__) && defined(__linux__) +#include +#include +#include + #include "arm_neon.h" #include "mlasi.h" #include "sbgemm.h" @@ -29,6 +33,25 @@ struct MLAS_SBGEMM_KERNEL_NEON { static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K }; +MLAS_FORCEINLINE size_t +MlasSBGemmKernelRows(size_t CountM) +{ + return CountM >= 8 ? 8 : CountM >= 4 ? 4 + : CountM >= 2 ? 2 + : 1; +} + +MLAS_FORCEINLINE void +MlasSBGemmCopyPackA(float* PackedA, const float* A, size_t lda, size_t Rows, size_t CountK, size_t AlignedCountK) +{ + for (size_t m = 0; m < Rows; ++m) { + float* dst = PackedA + m * AlignedCountK; + const float* src = A + m * lda; + std::memcpy(dst, src, CountK * sizeof(float)); + std::fill_n(dst + CountK, AlignedCountK - CountK, 0.0f); + } +} + bool MLASCALL MlasBf16AccelerationSupported() { @@ -338,12 +361,32 @@ template <> MLAS_FORCEINLINE void MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) { + constexpr size_t PackedK = MLAS_SBGEMM_KERNEL_NEON::PackedK; + const size_t AlignedCountK = (CountK + PackedK - 1) & ~(PackedK - 1); + const bool PackATail = AlignedCountK != CountK; + std::vector PackedA; + if (PackATail) { + PackedA.resize(MLAS_SBGEMM_KERNEL_NEON::KernelMaxM * AlignedCountK); + } + while (CountM > 0) { + const float* KernelA = A; + size_t KernelLda = lda; + size_t KernelCountK = CountK; + + if (PackATail) { + const size_t RowsToPack = MlasSBGemmKernelRows(CountM); + MlasSBGemmCopyPackA(PackedA.data(), A, lda, RowsToPack, CountK, AlignedCountK); + KernelA = PackedA.data(); + KernelLda = AlignedCountK; + KernelCountK = AlignedCountK; + } + size_t RowsHandled; if (ZeroMode) { - RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + RowsHandled = MlasSbgemmKernelZero(KernelA, B, C, KernelCountK, CountM, CountN, KernelLda, ldc, Bias); } else { - RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + RowsHandled = MlasSbgemmKernelAdd(KernelA, B, C, KernelCountK, CountM, CountN, KernelLda, ldc, Bias); } C += ldc * RowsHandled; A += lda * RowsHandled; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index 4402a4b1526d0..6fa3e0d9a4827 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -244,23 +244,9 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (input_idx == 1) { size_t packed_b_size; #if defined(__aarch64__) && defined(__linux__) - size_t dim1 = 0; - size_t dim2 = 0; TensorShape b_shape = tensor.Shape(); - if (b_shape.NumDimensions() == 2) { - dim1 = static_cast(b_shape[0]); - dim2 = static_cast(b_shape[1]); - } - - const size_t k_dim = b_shape.NumDimensions() >= 2 - ? static_cast(b_shape[b_shape.NumDimensions() - 2]) - : dim1; - if (use_fastmath_mode_ && - (trans_a_attr_ == 0) && - (trans_b_attr_ == 0) && - ((k_dim * dim2) >= kFastMathModeKernelsizeThreshold) && - ((k_dim % kFastMathModeKAlignment) == 0)) { + if (CanPackBForFastMathModeSBGemm(b_shape)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif @@ -330,10 +316,19 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) - if (use_fastmath_mode_ && - !trans_a && !trans_b && - ((N * K) >= kFastMathModeKernelsizeThreshold) && - ((K % kFastMathModeKAlignment) == 0)) { + const bool can_use_fastmath_sbgemm = CanUseFastMathModeSBGemm(N, K); + if (packed_b_) { + const bool packed_b_can_use_fastmath_sbgemm = CanPackBForFastMathModeSBGemm(b_shape); + if (packed_b_can_use_fastmath_sbgemm) { + ORT_ENFORCE(K == static_cast(b_shape[0]), + "MatMul fastmath PrePack/Compute K mismatch: packed B K=", + b_shape[0], ", Compute K=", K); + } + ORT_ENFORCE(can_use_fastmath_sbgemm == packed_b_can_use_fastmath_sbgemm, + "MatMul fastmath PrePack/Compute eligibility mismatch."); + } + + if (can_use_fastmath_sbgemm) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { data[i].BIsfp32 = !(bool(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index c3e93846f670f..7f7c06e8eedb0 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -102,9 +102,19 @@ class MatMul final : public OpKernel { bool use_fastmath_mode_; // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead - // The NEON SBGemm kernel consumes A in 4-float groups. Keep K tails on SGEMM - const size_t kFastMathModeKAlignment = 4; const size_t kFastMathModeKernelsizeThreshold = 32; + + bool CanUseFastMathModeSBGemm(size_t n, size_t k) const { + return use_fastmath_mode_ && + (trans_a_attr_ == 0) && + (trans_b_attr_ == 0) && + ((n * k) >= kFastMathModeKernelsizeThreshold); + } + + bool CanPackBForFastMathModeSBGemm(const TensorShape& b_shape) const { + return b_shape.NumDimensions() == 2 && + CanUseFastMathModeSBGemm(static_cast(b_shape[1]), static_cast(b_shape[0])); + } #endif }; diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index 2a8e01b9dda3a..1a402ac72456a 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -78,6 +78,7 @@ class SBGemmShortExecuteTest : public MlasTestFixture + template void SmallFloatFill(T* start, size_t size) { constexpr float MinimumFillValue = -11.0f; @@ -169,6 +171,7 @@ class MlasSBGemmTest : public MlasTestBase { void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + std::fill_n(A + K * M * BatchSize, 16, std::numeric_limits::quiet_NaN()); AType Atail[16]; std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); @@ -219,6 +222,7 @@ class MlasSBGemmTest : public MlasTestBase { void TestAccumulate(size_t M, size_t N, size_t K, size_t BatchSize) { AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + std::fill_n(A + K * M * BatchSize, 16, std::numeric_limits::quiet_NaN()); AType Atail[16]; std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index dc5b42a517ec8..18fb9dcdd23b8 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -12,6 +12,11 @@ #include "test/util/include/test_environment.h" #include "default_providers.h" +#include +#include +#include +#include + #if defined(__aarch64__) && defined(__linux__) namespace onnxruntime { @@ -170,62 +175,67 @@ void RunMatMulTest(int32_t opset_version) { RunMatMulTest(opset_version, false, false, false); } -TEST(MathOpTest, MatMulFloatTypeFastMathKTailFallsBackToSgemm) { +TEST(MathOpTest, MatMulFloatTypeFastMathKTailDoesNotReadPaddedA) { constexpr int64_t M = 1; constexpr int64_t N = 8; - constexpr int64_t K = 13; - OpTester test("MatMul", 7); + for (const int64_t K : std::array{13, 14, 15, 16, 17}) { + SCOPED_TRACE("K=" + std::to_string(K)); - std::vector input0_vals(K); - std::iota(input0_vals.begin(), input0_vals.end(), 1.0f); - test.AddInput("A", {M, K}, input0_vals); + OpTester test("MatMul", 7); - std::vector input1_vals(K * N, 1.0f); - test.AddInput("B", {K, N}, input1_vals, true); + std::vector input0_vals(K); + std::iota(input0_vals.begin(), input0_vals.end(), 1.0f); + test.AddInput("A", {M, K}, input0_vals); - std::vector expected_vals(N, 91.0f); - test.AddOutput("Y", {M, N}, expected_vals); + std::vector input1_vals(K * N, 1.0f); + test.AddInput("B", {K, N}, input1_vals, true); - Model& model = test.BuildModel(); - ASSERT_STATUS_OK(model.MainGraph().Resolve()); + std::vector expected_vals(N, std::accumulate(input0_vals.begin(), input0_vals.end(), 0.0f)); + test.AddOutput("Y", {M, N}, expected_vals); - std::string serialized_model; - ASSERT_TRUE(model.ToProto().SerializeToString(&serialized_model)); + Model& model = test.BuildModel(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); - SessionOptions so; - ASSERT_STATUS_OK(so.config_options.AddConfigEntry( - kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + std::string serialized_model; + ASSERT_TRUE(model.ToProto().SerializeToString(&serialized_model)); - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCpuExecutionProvider())); + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); - std::stringstream model_stream(serialized_model); - ASSERT_STATUS_OK(session_object.Load(model_stream)); - ASSERT_STATUS_OK(session_object.Initialize()); + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCpuExecutionProvider())); - std::vector input0_backing(input0_vals.begin(), input0_vals.end()); - input0_backing.resize(K + 3, std::numeric_limits::quiet_NaN()); + std::stringstream model_stream(serialized_model); + ASSERT_STATUS_OK(session_object.Load(model_stream)); + ASSERT_STATUS_OK(session_object.Initialize()); - OrtValue input0; - Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({M, K}), input0_backing.data(), - OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), input0); + std::vector input0_backing(input0_vals.begin(), input0_vals.end()); + // Poison a full 4-float NEON load past the logical A row to verify that + // SBGemm K-tail handling does not consume overread values. + input0_backing.resize(K + 4, std::numeric_limits::quiet_NaN()); - NameMLValMap feeds; - feeds.insert(std::make_pair(std::string("A"), input0)); + OrtValue input0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({M, K}), input0_backing.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), input0); - std::vector fetches; - ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); - ASSERT_EQ(fetches.size(), 1u); - ASSERT_TRUE(fetches[0].IsTensor()); + NameMLValMap feeds; + feeds.insert(std::make_pair(std::string("A"), input0)); - const auto& output_tensor = fetches[0].Get(); - ASSERT_EQ(output_tensor.Shape(), TensorShape({M, N})); + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); + ASSERT_EQ(fetches.size(), 1u); + ASSERT_TRUE(fetches[0].IsTensor()); - const auto* output_data = output_tensor.Data(); - for (int64_t i = 0; i < N; ++i) { - ASSERT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " should not include padded A tail values."; - ASSERT_EQ(output_data[i], expected_vals[i]) << "Output " << i; + const auto& output_tensor = fetches[0].Get(); + ASSERT_EQ(output_tensor.Shape(), TensorShape({M, N})); + + const auto* output_data = output_tensor.Data(); + for (int64_t i = 0; i < N; ++i) { + ASSERT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " should not include padded A tail values."; + ASSERT_EQ(output_data[i], expected_vals[i]) << "Output " << i; + } } } From 610f377b89a35e42b409b985b8e89321e50d8778 Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Tue, 2 Jun 2026 22:42:08 +0100 Subject: [PATCH 5/6] Rework test to not use optester Signed-off-by: Jonathan Clohessy --- .../cpu/math/matmul_fastmath_test.cc | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 18fb9dcdd23b8..144569f35235b 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -194,11 +194,45 @@ TEST(MathOpTest, MatMulFloatTypeFastMathKTailDoesNotReadPaddedA) { std::vector expected_vals(N, std::accumulate(input0_vals.begin(), input0_vals.end(), 0.0f)); test.AddOutput("Y", {M, N}, expected_vals); - Model& model = test.BuildModel(); - ASSERT_STATUS_OK(model.MainGraph().Resolve()); + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + model.add_opset_import()->set_version(7); + + auto* graph = model.mutable_graph(); + graph->set_name("MatMulKTailGraph"); + + auto* input = graph->add_input(); + input->set_name("A"); + auto* input_tensor_type = input->mutable_type()->mutable_tensor_type(); + input_tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + input_tensor_type->mutable_shape()->add_dim()->set_dim_value(M); + input_tensor_type->mutable_shape()->add_dim()->set_dim_value(K); + + auto* output = graph->add_output(); + output->set_name("Y"); + auto* output_tensor_type = output->mutable_type()->mutable_tensor_type(); + output_tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_tensor_type->mutable_shape()->add_dim()->set_dim_value(M); + output_tensor_type->mutable_shape()->add_dim()->set_dim_value(N); + + auto* b_initializer = graph->add_initializer(); + b_initializer->set_name("B"); + b_initializer->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + b_initializer->add_dims(K); + b_initializer->add_dims(N); + for (float value : input1_vals) { + b_initializer->add_float_data(value); + } + + auto* matmul_node = graph->add_node(); + matmul_node->set_name("MatMulKTail"); + matmul_node->set_op_type("MatMul"); + matmul_node->add_input("A"); + matmul_node->add_input("B"); + matmul_node->add_output("Y"); std::string serialized_model; - ASSERT_TRUE(model.ToProto().SerializeToString(&serialized_model)); + ASSERT_TRUE(model.SerializeToString(&serialized_model)); SessionOptions so; ASSERT_STATUS_OK(so.config_options.AddConfigEntry( From 4c4811a1ecf27fe3c1037a7df20a8f96fab7192b Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Wed, 3 Jun 2026 11:45:58 +0100 Subject: [PATCH 6/6] Remove Dead code from test Signed-off-by: Jonathan Clohessy --- onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 144569f35235b..70ea9bb0579b9 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -182,17 +182,12 @@ TEST(MathOpTest, MatMulFloatTypeFastMathKTailDoesNotReadPaddedA) { for (const int64_t K : std::array{13, 14, 15, 16, 17}) { SCOPED_TRACE("K=" + std::to_string(K)); - OpTester test("MatMul", 7); - std::vector input0_vals(K); std::iota(input0_vals.begin(), input0_vals.end(), 1.0f); - test.AddInput("A", {M, K}, input0_vals); std::vector input1_vals(K * N, 1.0f); - test.AddInput("B", {K, N}, input1_vals, true); std::vector expected_vals(N, std::accumulate(input0_vals.begin(), input0_vals.end(), 0.0f)); - test.AddOutput("Y", {M, N}, expected_vals); ONNX_NAMESPACE::ModelProto model; model.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);