Skip to content

Commit e8c90e9

Browse files
committed
Reject QDQ Gemm→QGemm fusion when alpha != 1 with bias (#28130)
The Gemm→QGemm QDQ fusion selector only validated beta == 1, letting Gemms with alpha != 1 and a bias through. QGemm broadcasts the int32 bias into the accumulator before applying the alpha*sa*sb output scale, so the bias ends up scaled by alpha too — producing incorrect outputs when alpha != 1 (bias == 0 masks the issue). Add an alpha == 1 check alongside the existing beta == 1 check in GemmNodeGroupSelector::Check (only when bias is present — without bias the fused path is still exact). Extend QDQTransformerGemmTests and the fastmath variant with an alpha_not_one parameter so the regression is covered. Follow-up tracked in the issue: absorb alpha into the int32 bias in GemmReplaceWithQuant so alpha != 1 cases can keep the fusion.
1 parent 0e72188 commit e8c90e9

3 files changed

Lines changed: 31 additions & 4 deletions

File tree

onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,13 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n
835835
return true;
836836
}
837837

838+
// When bias is present, QGemm folds bias into the int32 accumulator before
839+
// applying the alpha*sa*sb output scale, which would incorrectly scale the
840+
// bias by alpha. Require alpha==1 and beta==1 so the fused path is exact.
841+
if (node.GetAttributes().at("alpha").f() != 1.0) {
842+
return false;
843+
}
844+
838845
if (node.GetAttributes().at("beta").f() != 1.0) { // beta needs to be 1.0
839846
return false;
840847
}

onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) {
323323
}
324324

325325
template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType = int32_t>
326-
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) {
326+
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false,
327+
bool disable_fastmath = false, bool alpha_not_one = false) {
327328
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape,
328329
bool use_contrib_qdq = false) {
329330
auto build_test_case = [&](ModelTestBuilder& builder) {
@@ -396,12 +397,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
396397
if (beta_not_one) {
397398
gemm_node->AddAttribute("beta", 2.0f);
398399
}
400+
401+
if (alpha_not_one) {
402+
gemm_node->AddAttribute("alpha", 2.0f);
403+
}
399404
};
400405

401406
auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
402407
auto op_to_count = CountOpsInGraph(session.GetGraph());
403408
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
404-
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) && (!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
409+
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) &&
410+
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one && !alpha_not_one)) &&
405411
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
406412
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
407413
EXPECT_EQ(op_to_count["Gemm"], 0);
@@ -490,6 +496,10 @@ void QDQTransformerGemmTests() {
490496
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, true);
491497
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, true);
492498
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true);
499+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false, false, false, true);
500+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, false, false, true);
501+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, false, false, true);
502+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, false, false, true);
493503
// dummy test to disable the fastmath session
494504
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true, true);
495505
}

onnxruntime/test/optimizer/qdq_transformer_test.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8) {
718718
}
719719

720720
template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType = int32_t>
721-
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false) {
721+
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false,
722+
bool alpha_not_one = false) {
722723
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape,
723724
bool use_contrib_qdq = false) {
724725
auto build_test_case = [&](ModelTestBuilder& builder) {
@@ -791,12 +792,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
791792
if (beta_not_one) {
792793
gemm_node->AddAttribute("beta", 2.0f);
793794
}
795+
796+
if (alpha_not_one) {
797+
gemm_node->AddAttribute("alpha", 2.0f);
798+
}
794799
};
795800

796801
auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
797802
auto op_to_count = CountOpsInGraph(session.GetGraph());
798803
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
799-
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) && (!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
804+
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) &&
805+
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one && !alpha_not_one)) &&
800806
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
801807
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
802808
EXPECT_EQ(op_to_count["Gemm"], 0);
@@ -860,6 +866,10 @@ void QDQTransformerGemmTests() {
860866
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, true);
861867
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, true);
862868
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true);
869+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false, false, true);
870+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, false, true);
871+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, false, true);
872+
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, false, true);
863873
}
864874

865875
TEST(QDQTransformerTests, Gemm_U8U8U8) {

0 commit comments

Comments
 (0)