Skip to content

[CPU] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel#28394

Open
JonathanC-ARM wants to merge 6 commits into
microsoft:mainfrom
JonathanC-ARM:jonclo01/sbgemm_neon_gating
Open

[CPU] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel#28394
JonathanC-ARM wants to merge 6 commits into
microsoft:mainfrom
JonathanC-ARM:jonclo01/sbgemm_neon_gating

Conversation

@JonathanC-ARM
Copy link
Copy Markdown
Contributor

Description

This change fixes sporadic onnxruntime_test_all failures observed when building with KleidiAI enabled and --no_sve.

In this configuration, some MatMul cases use the ARM64 BF16 fastmath path backed by the NEON SBGemm kernel (non kleidiai).
That kernel consumes A in groups of 4 floats. For K-tail cases such as K=13, the final block contains only one valid A value, but the kernel can still read up to three additional floats beyond the logical A row.
B is already packed/padded for this path, but A is not. If the overread A values contain NaN or otherwise invalid data, the result can diverge because NaN * 0 still propagates NaN.

This change avoids the unsafe NEON SBGemm fastmath path when K is not a multiple of 4. Those K-tail cases fall back to the existing SGEMM path, while aligned-K cases continue to use BF16 fastmath.
Repro

Build with KleidiAI enabled and --no_sve, then run:

./onnxruntime_test_all \
  --gtest_random_seed=2345 \
  --gtest_brief=1 \
  --gtest_filter="QDQTransformerTests*" 

Example failure:

[  FAILED  ] QDQTransformerTests.DQMatMulPerTensorWithBlockSizeOption
expected -146.433, got -146.436, diff: 0.00294495, tol=0.00147433

Validation

Validated with:

./onnxruntime_test_all \
  --gtest_random_seed=2345 \
  --gtest_filter="*QDQ*" \
  --gtest_repeat=10 \
  --gtest_break_on_failure \
  --gtest_brief=1

Also validated the previously failing QDQ MatMul tests over repeated runs.

Added a regression test to onnxruntime_provider_test which guards against this particular matmul corner case

./onnxruntime_provider_test --gtest_filter="MathOpTest.MatMulFloatTypeFastMathKTailFallsBackToSgemm" 
[ RUN      ] MathOpTest.MatMulFloatTypeFastMathKTailFallsBackToSgemm
[symbolize_elf.inc : 378] RAW: Unable to get high fd: rc=0, limit=1024
/home/jonclo01/kfi-devenv/repos/onnxruntime/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc:230: Failure
Value of: std::isfinite(output_data[i])
  Actual: false
Expected: true
Output 0 should not include padded A tail values.
Stack trace:
  0xaaaaab2f8428: onnxruntime::test::MathOpTest_MatMulFloatTypeFastMathKTailFallsBackToSgemm_Test::TestBody()
  0xaaaaac946c30: testing::internal::HandleExceptionsInMethodIfSupported<>()
  0xaaaaac947048: testing::Test::Run()
  0xaaaaac9474e0: testing::TestInfo::Run()
... Google Test internal frames ...

[  FAILED  ] MathOpTest.MatMulFloatTypeFastMathKTailFallsBackToSgemm (13 ms)
[----------] 1 test from MathOpTest (13 ms total)

[----------] Global test environment tear-down
[==========] 1 test from 1 test suite ran. (14 ms total)
[  PASSED  ] 0 tests.
[  FAILED  ] 1 test, listed below:
[  FAILED  ] MathOpTest.MatMulFloatTypeFastMathKTailFallsBackToSgemm

…ernel

Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
…correctly

Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR addresses sporadic ARM64 Linux test failures when KleidiAI is enabled without SVE by preventing the BF16 fastmath (NEON SBGemm) MatMul path from running when the shared dimension K is not a multiple of 4, avoiding unsafe tail overreads and NaN propagation.

Changes:

  • Gate the ARM64 BF16 fastmath path (SBGemm) on K % 4 == 0 in both pre-packing and compute.
  • Add an ARM64/Linux regression test that pads the A-buffer tail with NaNs to ensure K-tail cases don’t leak invalid values into outputs.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.

File Description
onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc Adds a regression test that detects NaN propagation from K-tail overreads when fastmath is enabled.
onnxruntime/core/providers/cpu/math/matmul.h Introduces a K-alignment constant (4) for the ARM64 fastmath SBGemm gating logic.
onnxruntime/core/providers/cpu/math/matmul.cc Applies K % 4 == 0 checks to avoid using SBGemm (and BF16 B prepacking) for misaligned K-tail cases.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

feeds.insert(std::make_pair(std::string("A"), input0));

std::vector<OrtValue> fetches;
ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches));
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 1 comment.

Comment on lines 5 to 9
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"
#include "core/session/inference_session.h"
#include "test/common/dnnl_op_test_utils.h"
@hariharans29 hariharans29 changed the title Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel [CPU] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel May 26, 2026
@hariharans29
Copy link
Copy Markdown
Member

Verdict: Correct and safe targeted fix; merge-ready modulo a small comment/scope concern. The underlying NEON SBGemm kernel has a contract bug (declares BufferOverReadBytes = 32 in MlasSBGemmDispatchNeon, sbgemm_kernel_neon.cpp:360, so callers are expected to back A with safe trailing bytes — but MatMul<float>::Compute hands the raw input tensor straight in), and this PR works around it by avoiding the path entirely when K is not a multiple of the kernel's A-stride. That's the right short-term call.

Things worth saying on the PR

  1. The check belongs in MLAS, not in the op kernel. Every caller of MlasSBGemm (today: this MatMul; tomorrow: anyone else who finds the API) is going to hit the same trap. The clean long-term fix is one of:

    • Mask the tail vld1q_f32 in the SBGemm NEON kernel with vandq_u32 / a per-tail load helper so the kernel is self-contained for any K.
    • Have MlasSBGemmConvertCopyPackB / a new MlasSBGemmPackA copy A row-by-row with the K-tail zero-padded, so the kernel-declared BufferOverReadBytes contract is satisfied internally.

    Either approach removes the perf loss for K=13/15/... shapes (which today silently regress to a plain SGEMM). Could be a follow-up issue tracked from this PR; please file one if you're not picking it up immediately.

  2. PrePack and Compute checks must agree, and they do — but please add an ORT_ENFORCE or assert. PrePack gates on b_shape[ndims-2] and Compute gates on helper.K(). They are guaranteed equal at runtime, but if any future MatMul-helper change ever desynced them you'd hit the worst case: B packed to bf16 in PrePack, then Compute decides to use the SGEMM path and feeds a bf16-packed B into a float kernel → silent garbage. A one-line ORT_ENFORCE(K == k_dim_from_packed_b) (or just assert) at the top of Compute would catch any future regression cheaply.

  3. Width of the gating predicate. The PR gates on K % 4 == 0. The kernel name says "4-float groups", which is correct for the NEON path used in --no_sve builds. If the SVE path or a future BF16 micro-kernel ever requires a different alignment (e.g., 8 for SVE 256-bit, or 2 for BF8×2), this constant becomes wrong silently. Two small changes would future-proof it:

    • Source kFastMathModeKAlignment from the dispatch struct (e.g., add AAlignment next to BufferOverReadBytes) rather than from a header-local constant.
    • Rename the constant to kSBGemmNeonAAlignment to make the kernel-specificity obvious.
  4. Test scope. The regression test only covers M=1, N=8, K=13. That's sufficient to prove the path falls back, but it would be nice to parametrize over K ∈ {13, 14, 15} (and at minimum one K ∈ {16, 17} to confirm K=16 still takes the fast path and K=17 falls back). A for (int64_t K : {13, 14, 15, 17}) loop inside the test body is enough.

  5. Test correctness nits (one is a real issue):

    • Real: input0_backing.resize(K + 3, NaN) is too small to fully demonstrate the bug across all K. The kernel may overread up to 4 floats (one vld1q_f32) past column K-K%4, which for K=13 is positions 13/14/15 — K+3 is exactly the minimum. For K=14 you'd need K+2, for K=15 K+1, for K=13 K+3. Safer to pad to K + 4 (or even align to 16 bytes) and document why.
    • You compute expected_vals as 91.0f (= sum of 1..13). Slightly more robust: derive it programmatically (std::accumulate(input0_vals.begin(), input0_vals.end(), 0.0f)) so the test stays correct if anyone tweaks the input.
    • The Copilot suggestion to ASSERT_EQ(fetches.size(), 1u) before fetches[0] is fine; the include/<numeric> / <limits> / <cmath> comment is also fine. Both are no-cost.
  6. PrePack guard against non-2D B. The b_shape.NumDimensions() >= 2 ? b_shape[ndims-2] : dim1 ternary papers over the case where B is 1-D, but in that case dim1 is N (not K), so the gate's k_dim % 4 == 0 check is meaningless. ONNX MatMul doesn't allow B to be 1-D after broadcast resolution in PrePack (where shape is already at least 2D for the packing path), but please either drop the ternary and just use b_shape[ndims-2] with an assert that ndims >= 2, or remove the fallback path entirely. Reading the current code makes you wonder when the fallback fires.

  7. Cosmetics:

    • kFastMathModeKAlignment comment in matmul.h reads "consumes A in 4-float groups. Keep K tails on SGEMM" — please add the root cause ("kernel can overread A by up to 3 floats; tail values may be NaN and propagate through bf16 mul") so future readers know why the fix exists rather than just what it does.
    • In matmul.cc, the new multi-line conditions might trip clang-format — please run lintrunner.

Reviewer recommendation

LGTM for merge after addressing item 5 (test K+3 sizing fragility) and item 2 (Compute/PrePack consistency assert). Items 1, 3, 4, 6, 7 are reasonable follow-ups but not blockers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants