diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 2e9d46656b514..dcfad53c47e4b 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -835,6 +835,13 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n return true; } + // When bias is present, QGemm folds bias into the int32 accumulator before + // applying the alpha*sa*sb output scale, which would incorrectly scale the + // bias by alpha. Require alpha==1 and beta==1 so the fused path is exact. + if (node.GetAttributes().at("alpha").f() != 1.0) { + return false; + } + if (node.GetAttributes().at("beta").f() != 1.0) { // beta needs to be 1.0 return false; } diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc index 55f1d212a8034..bb319b785218e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -323,7 +323,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) { } template -void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) { +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, + bool disable_fastmath = false, bool alpha_not_one = false) { auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, bool use_contrib_qdq = false) { auto build_test_case = [&](ModelTestBuilder& builder) { @@ -396,12 +397,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one if (beta_not_one) { gemm_node->AddAttribute("beta", 2.0f); } + + if (alpha_not_one) { + gemm_node->AddAttribute("alpha", 2.0f); + } }; auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - if ((!has_output_q || std::is_same_v) && (!has_bias || (std::is_same_v && !beta_not_one)) && + if ((!has_output_q || std::is_same_v) && + (!has_bias || (std::is_same_v && !beta_not_one && !alpha_not_one)) && (std::is_same_v || std::is_same_v)) { EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); EXPECT_EQ(op_to_count["Gemm"], 0); @@ -490,6 +496,10 @@ void QDQTransformerGemmTests() { QDQTransformerGemmTests(false, true, true); QDQTransformerGemmTests(true, false, true); QDQTransformerGemmTests(true, true, true); + QDQTransformerGemmTests(false, false, false, false, true); + QDQTransformerGemmTests(false, true, false, false, true); + QDQTransformerGemmTests(true, false, false, false, true); + QDQTransformerGemmTests(true, true, false, false, true); // dummy test to disable the fastmath session QDQTransformerGemmTests(true, true, true, true); } diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index bdbd2c488584d..2e5c5a8f71be9 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -718,7 +718,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8) { } template -void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false) { +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, + bool alpha_not_one = false) { auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, bool use_contrib_qdq = false) { auto build_test_case = [&](ModelTestBuilder& builder) { @@ -791,12 +792,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one if (beta_not_one) { gemm_node->AddAttribute("beta", 2.0f); } + + if (alpha_not_one) { + gemm_node->AddAttribute("alpha", 2.0f); + } }; auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { auto op_to_count = CountOpsInGraph(session.GetGraph()); const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); - if ((!has_output_q || std::is_same_v) && (!has_bias || (std::is_same_v && !beta_not_one)) && + if ((!has_output_q || std::is_same_v) && + (!has_bias || (std::is_same_v && !beta_not_one && !alpha_not_one)) && (std::is_same_v || std::is_same_v)) { EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); EXPECT_EQ(op_to_count["Gemm"], 0); @@ -860,6 +866,10 @@ void QDQTransformerGemmTests() { QDQTransformerGemmTests(false, true, true); QDQTransformerGemmTests(true, false, true); QDQTransformerGemmTests(true, true, true); + QDQTransformerGemmTests(false, false, false, true); + QDQTransformerGemmTests(false, true, false, true); + QDQTransformerGemmTests(true, false, false, true); + QDQTransformerGemmTests(true, true, false, true); } TEST(QDQTransformerTests, Gemm_U8U8U8) {