diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index 35520c7a4226a..99e3912910b16 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -134,7 +134,8 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size bool ZeroMode = (k == 0) && InitialZeroMode; CountK = std::min(K - k, PackedStrideK); - const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + const size_t AlignedCountK = (CountK + KernelType::PackedK - 1) & ~(KernelType::PackedK - 1); + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + AlignedCountK * SliceStartN; float* c = C + n; const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); @@ -319,7 +320,7 @@ MlasSBGemmPackBSize( BIsfp32) { size_t bytes_required; bytes_required = GetMlasPlatform().MlasSBGemmPackBSizeOverride(TransA, TransB, N, K); - if (bytes_required != 0){ // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion + if (bytes_required != 0) { // If ArmKleidiAI::MlasSBGemmPackBSize ran to completion return bytes_required; } } @@ -365,7 +366,7 @@ MlasSBGemmConvertPackB( TransA == CBLAS_TRANSPOSE::CblasNoTrans && TransB == CBLAS_TRANSPOSE::CblasNoTrans && BIsfp32 && - GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + GetMlasPlatform().MlasSBGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)) { return; } #endif @@ -400,7 +401,7 @@ MlasSBGemmBatch( TransB == CBLAS_TRANSPOSE::CblasNoTrans && Data->AIsfp32 && (Data->BIsPacked || Data->BIsfp32) && - GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)){ + GetMlasPlatform().MlasSBGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchN, ThreadPool)) { return; } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp index a6a73996c548b..00abcb31e284f 100644 --- a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -17,6 +17,10 @@ Module Name: #if defined(__aarch64__) && defined(__linux__) +#include +#include +#include + #include "arm_neon.h" #include "mlasi.h" #include "sbgemm.h" @@ -29,6 +33,25 @@ struct MLAS_SBGEMM_KERNEL_NEON { static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K }; +MLAS_FORCEINLINE size_t +MlasSBGemmKernelRows(size_t CountM) +{ + return CountM >= 8 ? 8 : CountM >= 4 ? 4 + : CountM >= 2 ? 2 + : 1; +} + +MLAS_FORCEINLINE void +MlasSBGemmCopyPackA(float* PackedA, const float* A, size_t lda, size_t Rows, size_t CountK, size_t AlignedCountK) +{ + for (size_t m = 0; m < Rows; ++m) { + float* dst = PackedA + m * AlignedCountK; + const float* src = A + m * lda; + std::memcpy(dst, src, CountK * sizeof(float)); + std::fill_n(dst + CountK, AlignedCountK - CountK, 0.0f); + } +} + bool MLASCALL MlasBf16AccelerationSupported() { @@ -338,12 +361,32 @@ template <> MLAS_FORCEINLINE void MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) { + constexpr size_t PackedK = MLAS_SBGEMM_KERNEL_NEON::PackedK; + const size_t AlignedCountK = (CountK + PackedK - 1) & ~(PackedK - 1); + const bool PackATail = AlignedCountK != CountK; + std::vector PackedA; + if (PackATail) { + PackedA.resize(MLAS_SBGEMM_KERNEL_NEON::KernelMaxM * AlignedCountK); + } + while (CountM > 0) { + const float* KernelA = A; + size_t KernelLda = lda; + size_t KernelCountK = CountK; + + if (PackATail) { + const size_t RowsToPack = MlasSBGemmKernelRows(CountM); + MlasSBGemmCopyPackA(PackedA.data(), A, lda, RowsToPack, CountK, AlignedCountK); + KernelA = PackedA.data(); + KernelLda = AlignedCountK; + KernelCountK = AlignedCountK; + } + size_t RowsHandled; if (ZeroMode) { - RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + RowsHandled = MlasSbgemmKernelZero(KernelA, B, C, KernelCountK, CountM, CountN, KernelLda, ldc, Bias); } else { - RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + RowsHandled = MlasSbgemmKernelAdd(KernelA, B, C, KernelCountK, CountM, CountN, KernelLda, ldc, Bias); } C += ldc * RowsHandled; A += lda * RowsHandled; diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ca91f46db93da..6fa3e0d9a4827 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -244,16 +244,9 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc if (input_idx == 1) { size_t packed_b_size; #if defined(__aarch64__) && defined(__linux__) - size_t dim1 = 0; - size_t dim2 = 0; TensorShape b_shape = tensor.Shape(); - if (b_shape.NumDimensions() == 2) { - dim1 = static_cast(b_shape[0]); - dim2 = static_cast(b_shape[1]); - } - - if (use_fastmath_mode_ && (trans_a_attr_ == 0) && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + if (CanPackBForFastMathModeSBGemm(b_shape)) { is_packed = GemmPackBBfloat16(alloc, tensor, trans_a_attr_ != 0, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_, &mlas_backend_kernel_selector_config_); } else #endif @@ -323,7 +316,19 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); #if defined(__aarch64__) && defined(__linux__) - if (use_fastmath_mode_ && !trans_a && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + const bool can_use_fastmath_sbgemm = CanUseFastMathModeSBGemm(N, K); + if (packed_b_) { + const bool packed_b_can_use_fastmath_sbgemm = CanPackBForFastMathModeSBGemm(b_shape); + if (packed_b_can_use_fastmath_sbgemm) { + ORT_ENFORCE(K == static_cast(b_shape[0]), + "MatMul fastmath PrePack/Compute K mismatch: packed B K=", + b_shape[0], ", Compute K=", K); + } + ORT_ENFORCE(can_use_fastmath_sbgemm == packed_b_can_use_fastmath_sbgemm, + "MatMul fastmath PrePack/Compute eligibility mismatch."); + } + + if (can_use_fastmath_sbgemm) { std::vector data(max_len); for (size_t i = 0; i < max_len; i++) { data[i].BIsfp32 = !(bool(packed_b_)); diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index d1c0df19f924e..7f7c06e8eedb0 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -103,6 +103,18 @@ class MatMul final : public OpKernel { // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead const size_t kFastMathModeKernelsizeThreshold = 32; + + bool CanUseFastMathModeSBGemm(size_t n, size_t k) const { + return use_fastmath_mode_ && + (trans_a_attr_ == 0) && + (trans_b_attr_ == 0) && + ((n * k) >= kFastMathModeKernelsizeThreshold); + } + + bool CanPackBForFastMathModeSBGemm(const TensorShape& b_shape) const { + return b_shape.NumDimensions() == 2 && + CanUseFastMathModeSBGemm(static_cast(b_shape[1]), static_cast(b_shape[0])); + } #endif }; diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp index 2a8e01b9dda3a..1a402ac72456a 100644 --- a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -78,6 +78,7 @@ class SBGemmShortExecuteTest : public MlasTestFixture + template void SmallFloatFill(T* start, size_t size) { constexpr float MinimumFillValue = -11.0f; @@ -169,6 +171,7 @@ class MlasSBGemmTest : public MlasTestBase { void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + std::fill_n(A + K * M * BatchSize, 16, std::numeric_limits::quiet_NaN()); AType Atail[16]; std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); @@ -219,6 +222,7 @@ class MlasSBGemmTest : public MlasTestBase { void TestAccumulate(size_t M, size_t N, size_t K, size_t BatchSize) { AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + std::fill_n(A + K * M * BatchSize, 16, std::numeric_limits::quiet_NaN()); AType Atail[16]; std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc index 8e6ff1f387bf1..70ea9bb0579b9 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -5,11 +5,18 @@ #include "core/session/onnxruntime_session_options_config_keys.h" #include "gtest/gtest.h" #include "test/providers/provider_test_utils.h" +#include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/util/include/test_environment.h" #include "default_providers.h" +#include +#include +#include +#include + #if defined(__aarch64__) && defined(__linux__) namespace onnxruntime { @@ -168,6 +175,99 @@ void RunMatMulTest(int32_t opset_version) { RunMatMulTest(opset_version, false, false, false); } +TEST(MathOpTest, MatMulFloatTypeFastMathKTailDoesNotReadPaddedA) { + constexpr int64_t M = 1; + constexpr int64_t N = 8; + + for (const int64_t K : std::array{13, 14, 15, 16, 17}) { + SCOPED_TRACE("K=" + std::to_string(K)); + + std::vector input0_vals(K); + std::iota(input0_vals.begin(), input0_vals.end(), 1.0f); + + std::vector input1_vals(K * N, 1.0f); + + std::vector expected_vals(N, std::accumulate(input0_vals.begin(), input0_vals.end(), 0.0f)); + + ONNX_NAMESPACE::ModelProto model; + model.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + model.add_opset_import()->set_version(7); + + auto* graph = model.mutable_graph(); + graph->set_name("MatMulKTailGraph"); + + auto* input = graph->add_input(); + input->set_name("A"); + auto* input_tensor_type = input->mutable_type()->mutable_tensor_type(); + input_tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + input_tensor_type->mutable_shape()->add_dim()->set_dim_value(M); + input_tensor_type->mutable_shape()->add_dim()->set_dim_value(K); + + auto* output = graph->add_output(); + output->set_name("Y"); + auto* output_tensor_type = output->mutable_type()->mutable_tensor_type(); + output_tensor_type->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + output_tensor_type->mutable_shape()->add_dim()->set_dim_value(M); + output_tensor_type->mutable_shape()->add_dim()->set_dim_value(N); + + auto* b_initializer = graph->add_initializer(); + b_initializer->set_name("B"); + b_initializer->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + b_initializer->add_dims(K); + b_initializer->add_dims(N); + for (float value : input1_vals) { + b_initializer->add_float_data(value); + } + + auto* matmul_node = graph->add_node(); + matmul_node->set_name("MatMulKTail"); + matmul_node->set_op_type("MatMul"); + matmul_node->add_input("A"); + matmul_node->add_input("B"); + matmul_node->add_output("Y"); + + std::string serialized_model; + ASSERT_TRUE(model.SerializeToString(&serialized_model)); + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(DefaultCpuExecutionProvider())); + + std::stringstream model_stream(serialized_model); + ASSERT_STATUS_OK(session_object.Load(model_stream)); + ASSERT_STATUS_OK(session_object.Initialize()); + + std::vector input0_backing(input0_vals.begin(), input0_vals.end()); + // Poison a full 4-float NEON load past the logical A row to verify that + // SBGemm K-tail handling does not consume overread values. + input0_backing.resize(K + 4, std::numeric_limits::quiet_NaN()); + + OrtValue input0; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({M, K}), input0_backing.data(), + OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), input0); + + NameMLValMap feeds; + feeds.insert(std::make_pair(std::string("A"), input0)); + + std::vector fetches; + ASSERT_STATUS_OK(session_object.Run(RunOptions{}, feeds, AsSpan({std::string("Y")}), &fetches)); + ASSERT_EQ(fetches.size(), 1u); + ASSERT_TRUE(fetches[0].IsTensor()); + + const auto& output_tensor = fetches[0].Get(); + ASSERT_EQ(output_tensor.Shape(), TensorShape({M, N})); + + const auto* output_data = output_tensor.Data(); + for (int64_t i = 0; i < N; ++i) { + ASSERT_TRUE(std::isfinite(output_data[i])) << "Output " << i << " should not include padded A tail values."; + ASSERT_EQ(output_data[i], expected_vals[i]) << "Output " << i; + } + } +} + TEST(MathOpTest, MatMulFloatType_FastMath) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) {