[CPU] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel#28394
[CPU] Add fix for fastmath corner case with failing tests via sbgemm_neon_kernel#28394JonathanC-ARM wants to merge 6 commits into
Conversation
…ernel Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
…correctly Signed-off-by: Jonathan Clohessy <Jonathan.Clohessy@arm.com>
There was a problem hiding this comment.
Pull request overview
This PR addresses sporadic ARM64 Linux test failures when KleidiAI is enabled without SVE by preventing the BF16 fastmath (NEON SBGemm) MatMul path from running when the shared dimension K is not a multiple of 4, avoiding unsafe tail overreads and NaN propagation.
Changes:
- Gate the ARM64 BF16 fastmath path (SBGemm) on
K % 4 == 0in both pre-packing and compute. - Add an ARM64/Linux regression test that pads the A-buffer tail with NaNs to ensure K-tail cases don’t leak invalid values into outputs.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc |
Adds a regression test that detects NaN propagation from K-tail overreads when fastmath is enabled. |
onnxruntime/core/providers/cpu/math/matmul.h |
Introduces a K-alignment constant (4) for the ARM64 fastmath SBGemm gating logic. |
onnxruntime/core/providers/cpu/math/matmul.cc |
Applies K % 4 == 0 checks to avoid using SBGemm (and BF16 B prepacking) for misaligned K-tail cases. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| feeds.insert(std::make_pair(std::string("A"), input0)); | ||
|
|
||
| std::vector<OrtValue> fetches; | ||
| ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); |
| #include "core/session/onnxruntime_session_options_config_keys.h" | ||
| #include "gtest/gtest.h" | ||
| #include "test/providers/provider_test_utils.h" | ||
| #include "core/session/inference_session.h" | ||
| #include "test/common/dnnl_op_test_utils.h" |
|
Verdict: Correct and safe targeted fix; merge-ready modulo a small comment/scope concern. The underlying NEON SBGemm kernel has a contract bug (declares Things worth saying on the PR
Reviewer recommendation LGTM for merge after addressing item 5 (test K+3 sizing fragility) and item 2 (Compute/PrePack consistency assert). Items 1, 3, 4, 6, 7 are reasonable follow-ups but not blockers. |
Description
This change fixes sporadic
onnxruntime_test_allfailures observed when building with KleidiAI enabled and --no_sve.In this configuration, some MatMul cases use the ARM64 BF16 fastmath path backed by the NEON SBGemm kernel (non kleidiai).
That kernel consumes A in groups of 4 floats. For K-tail cases such as K=13, the final block contains only one valid A value, but the kernel can still read up to three additional floats beyond the logical A row.
B is already packed/padded for this path, but A is not. If the overread A values contain NaN or otherwise invalid data, the result can diverge because NaN * 0 still propagates NaN.
This change avoids the unsafe NEON SBGemm fastmath path when K is not a multiple of 4. Those K-tail cases fall back to the existing SGEMM path, while aligned-K cases continue to use BF16 fastmath.
Repro
Build with KleidiAI enabled and --no_sve, then run:
Example failure:
Validation
Validated with:
Also validated the previously failing QDQ MatMul tests over repeated runs.
Added a regression test to onnxruntime_provider_test which guards against this particular matmul corner case