Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,13 @@ bool GemmNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node& n
return true;
}

// When bias is present, QGemm folds bias into the int32 accumulator before
// applying the alpha*sa*sb output scale, which would incorrectly scale the
// bias by alpha. Require alpha==1 and beta==1 so the fused path is exact.
if (node.GetAttributes().at("alpha").f() != 1.0) {
return false;
}

if (node.GetAttributes().at("beta").f() != 1.0) { // beta needs to be 1.0
return false;
}
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) {
}

template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType = int32_t>
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) {
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false,
bool disable_fastmath = false, bool alpha_not_one = false) {
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape,
bool use_contrib_qdq = false) {
auto build_test_case = [&](ModelTestBuilder& builder) {
Expand Down Expand Up @@ -396,12 +397,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
if (beta_not_one) {
gemm_node->AddAttribute("beta", 2.0f);
}

if (alpha_not_one) {
gemm_node->AddAttribute("alpha", 2.0f);
}
};

auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) && (!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) &&
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one && !alpha_not_one)) &&
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
EXPECT_EQ(op_to_count["Gemm"], 0);
Expand Down Expand Up @@ -490,6 +496,10 @@ void QDQTransformerGemmTests() {
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false, false, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, false, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, false, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, false, false, true);
// dummy test to disable the fastmath session
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true, true);
}
Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/test/optimizer/qdq_transformer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,8 @@ TEST(QDQTransformerTests, MatMul_S8S8U8) {
}

template <typename Input1Type, typename Input2Type, typename OutputType, typename BiasType = int32_t>
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false) {
void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false,
bool alpha_not_one = false) {
auto test_case = [&](const std::vector<int64_t>& input1_shape, const std::vector<int64_t>& input2_shape,
bool use_contrib_qdq = false) {
auto build_test_case = [&](ModelTestBuilder& builder) {
Expand Down Expand Up @@ -791,12 +792,17 @@ void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one
if (beta_not_one) {
gemm_node->AddAttribute("beta", 2.0f);
}

if (alpha_not_one) {
gemm_node->AddAttribute("alpha", 2.0f);
}
};

auto check_binary_op_graph = [&](InferenceSessionWrapper& session) {
auto op_to_count = CountOpsInGraph(session.GetGraph());
const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq);
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) && (!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one)) &&
if ((!has_output_q || std::is_same_v<Input1Type, OutputType>) &&
(!has_bias || (std::is_same_v<BiasType, int32_t> && !beta_not_one && !alpha_not_one)) &&
(std::is_same_v<Input1Type, uint8_t> || std::is_same_v<Input2Type, int8_t>)) {
EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1);
EXPECT_EQ(op_to_count["Gemm"], 0);
Expand Down Expand Up @@ -860,6 +866,10 @@ void QDQTransformerGemmTests() {
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, false, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(false, true, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, false, false, true);
QDQTransformerGemmTests<Input1Type, Input2Type, OutputType, BiasType>(true, true, false, true);
}

TEST(QDQTransformerTests, Gemm_U8U8U8) {
Expand Down