Skip to content

Commit 0b278bb

Browse files
authored
Update optimizer opset version checks for latest ONNX opset 26 (#28966)
This pull request expands support for additional ONNX opset versions in the attention fusion optimization code, making the optimizer compatible with newer and more diverse ONNX models. The changes primarily update the accepted opset versions for various operators such as `Transpose`, `Reshape`, `Squeeze`, `Unsqueeze`, `Shape`, and others across multiple functions. This ensures broader model compatibility and improves the robustness of the fusion logic.

**Expanded opset version support for attention fusion:**

* Updated accepted opset versions for key operators (`Transpose`, `Reshape`, `Squeeze`, `Unsqueeze`, `Shape`, `Add`, `Mul`, `Sub`, `Div`, `Cast`, etc.) in the main attention fusion logic (`attention_fusion.cc`), allowing matching and fusion of newer ONNX models using these operators at opsets up to 25. [[1]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L352-R367) [[2]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L382-R384) [[3]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L394-R395) [[4]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L405-R405) [[5]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L463-R471) [[6]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L500-R500) [[7]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L514-R514) [[8]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L923-R927) [[9]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L956-R958) [[10]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1073-R1074) [[11]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1166-R1166) [[12]](diffhunk://#diff-2d859229c1824649bd6a37eaefa52306394bc6c3aa341d6deff1d4f2fb9902f3L1268-R1275)

**Helper and mask subgraph matching improvements:**

* Broadened opset version checks for subgraph matching in helper functions, including those for Gemm subgraphs, unidirectional mask subgraphs, input mask subgraphs, and past subgraph matching, to support additional opset versions and operator variants. [[1]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L77-R84) [[2]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L169-R171) [[3]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L378-R379) [[4]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L395-R402) [[5]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L457-R458) [[6]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L485-R487) [[7]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L635-R637) [[8]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L769-R769) [[9]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L794-R796) [[10]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L812-R814) [[11]](diffhunk://#diff-97696a1ea660259af1c02da793abf7a807de115421a0ec32f1e36f39371e4e16L890-R890)

These changes collectively future-proof the attention fusion optimizer for a wider range of ONNX models and operator versions, reducing the likelihood of unsupported patterns during optimization.
1 parent 7db7893 commit 0b278bb

10 files changed

Lines changed: 743 additions & 127 deletions

onnxruntime/core/graph/graph_utils.cc

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include "core/common/logging/logging.h"
99

1010
#include <algorithm>
11+
#include <limits>
1112
#include <queue>
1213
#include <string>
1314
#include <vector>
@@ -411,6 +412,14 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
411412
return iter == attrs.end() ? nullptr : &iter->second;
412413
}
413414

415+
bool IsFullShapeNode(const Node& node) {
416+
const auto* start_attr = GetNodeAttribute(node, "start");
417+
const auto* end_attr = GetNodeAttribute(node, "end");
418+
// end=INT64_MAX is the runtime default meaning "all dimensions" (full shape).
419+
return (!start_attr || start_attr->i() == 0) &&
420+
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max());
421+
}
422+
414423
static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
415424
ONNX_NAMESPACE::TypeProto new_type;
416425
auto* typeproto_tensor = new_type.mutable_tensor_type();

onnxruntime/core/graph/graph_utils.h

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
#include "core/graph/onnx_protobuf.h"
99
#include "core/graph/graph.h"
1010

11+
#include <limits>
1112
#include <string>
1213
#include <vector>
1314

@@ -31,6 +32,10 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
3132
/** Returns the attribute of a Node with a given name. */
3233
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
3334

35+
/** Checks whether a Shape node returns the full tensor shape (all dimensions).
36+
* Returns false if start/end attributes restrict the output to a subset of dimensions. */
37+
bool IsFullShapeNode(const Node& node);
38+
3439
/** Add a new initializer to 'graph'.
3540
Checks that new_initializer does not already exist in 'graph' before adding it.
3641
@returns The NodeArg for the new initializer.

onnxruntime/core/optimizer/attention_fusion.cc

Lines changed: 29 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -349,22 +349,22 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
349349

350350
const Node* sequence_transpose = graph_utils::GetInputNode(qkv_matmul, 0);
351351
if (sequence_transpose == nullptr ||
352-
!graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
352+
!graph_utils::IsSupportedOptypeVersionAndDomain(*sequence_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
353353
!HasExpectedPerm(*sequence_transpose, {0, 2, 1}) ||
354354
!optimizer_utils::CheckOutputEdges(graph, *sequence_transpose, 1)) {
355355
return false;
356356
}
357357

358358
const Node* input_reshape = graph_utils::GetInputNode(*sequence_transpose, 0);
359359
if (input_reshape == nullptr ||
360-
!graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
360+
!graph_utils::IsSupportedOptypeVersionAndDomain(*input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
361361
!optimizer_utils::CheckOutputEdges(graph, *input_reshape, 1)) {
362362
return fail("missing input Reshape before sequence transpose");
363363
}
364364

365365
Node* qkv_reshape = GetOnlyChildByOutputIndex(graph, qkv_matmul, 0, "Reshape");
366366
if (qkv_reshape == nullptr ||
367-
!graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
367+
!graph_utils::IsSupportedOptypeVersionAndDomain(*qkv_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
368368
!optimizer_utils::CheckOutputEdges(graph, *qkv_reshape, 1)) {
369369
return fail("qkv Reshape after MatMul not matched");
370370
}
@@ -379,9 +379,9 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
379379
Node* k_squeeze = GetOnlyChildByOutputIndex(graph, *split, 1, "Squeeze");
380380
Node* v_transpose = GetOnlyChildByOutputIndex(graph, *split, 2, "Transpose");
381381
if (q_transpose == nullptr || k_squeeze == nullptr || v_transpose == nullptr ||
382-
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
383-
!graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13}, kOnnxDomain) ||
384-
!graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
382+
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
383+
!graph_utils::IsSupportedOptypeVersionAndDomain(*k_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
384+
!graph_utils::IsSupportedOptypeVersionAndDomain(*v_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
385385
!HasExpectedPerm(*q_transpose, {2, 0, 3, 1, 4}) ||
386386
!HasExpectedPerm(*v_transpose, {2, 0, 3, 1, 4}) ||
387387
!HasExpectedAxesInput(graph, *k_squeeze, {2})) {
@@ -391,8 +391,8 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
391391
Node* q_squeeze = GetOnlyChildByOutputIndex(graph, *q_transpose, 0, "Squeeze");
392392
Node* v_squeeze = GetOnlyChildByOutputIndex(graph, *v_transpose, 0, "Squeeze");
393393
if (q_squeeze == nullptr || v_squeeze == nullptr ||
394-
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13}, kOnnxDomain) ||
395-
!graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13}, kOnnxDomain) ||
394+
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
395+
!graph_utils::IsSupportedOptypeVersionAndDomain(*v_squeeze, "Squeeze", {13, 21, 23, 24, 25}, kOnnxDomain) ||
396396
!HasExpectedAxesInput(graph, *q_squeeze, {0}) ||
397397
!HasExpectedAxesInput(graph, *v_squeeze, {0})) {
398398
return fail("q/v squeeze pattern not matched");
@@ -402,7 +402,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
402402
Node* k_transpose = GetOnlyChildByOutputIndex(graph, *k_squeeze, 0, "Transpose");
403403
if (q_scale_mul == nullptr || k_transpose == nullptr ||
404404
!graph_utils::IsSupportedOptypeVersionAndDomain(*q_scale_mul, "Mul", {7, 13, 14}, kOnnxDomain) ||
405-
!graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13}, kOnnxDomain) ||
405+
!graph_utils::IsSupportedOptypeVersionAndDomain(*k_transpose, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
406406
!HasExpectedPerm(*k_transpose, {0, 2, 3, 1})) {
407407
return fail("q scale Mul or k Transpose(0,2,3,1) not matched");
408408
}
@@ -460,15 +460,15 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
460460

461461
Node* transpose_3 = GetOnlyChildByOutputIndex(graph, *qkv_matmul_1, 0, "Transpose");
462462
if (transpose_3 == nullptr ||
463-
!graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13}, kOnnxDomain) ||
463+
!graph_utils::IsSupportedOptypeVersionAndDomain(*transpose_3, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain) ||
464464
!HasExpectedPerm(*transpose_3, {0, 2, 1, 3}) ||
465465
!optimizer_utils::CheckOutputEdges(graph, *transpose_3, 1)) {
466466
return fail("output Transpose(0,2,1,3) not matched");
467467
}
468468

469469
Node* reshape_2 = GetOnlyChildByOutputIndex(graph, *transpose_3, 0, "Reshape");
470470
if (reshape_2 == nullptr ||
471-
!graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14}, kOnnxDomain) ||
471+
!graph_utils::IsSupportedOptypeVersionAndDomain(*reshape_2, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
472472
!optimizer_utils::CheckOutputEdges(graph, *reshape_2, 1)) {
473473
return fail("output Reshape not matched");
474474
}
@@ -497,7 +497,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
497497
if (proj_gemm == nullptr) {
498498
proj_gemm_input_reshape = GetOnlyChildByOutputIndex(graph, *reshape_2, 0, "Reshape");
499499
if (proj_gemm_input_reshape == nullptr ||
500-
!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
500+
!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_input_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
501501
!optimizer_utils::CheckOutputEdges(graph, *proj_gemm_input_reshape, 1)) {
502502
return fail("projection MatMul/Gemm not matched");
503503
}
@@ -511,7 +511,7 @@ static bool TryFuseMobileClipMHA(Node& qkv_matmul,
511511

512512
proj_gemm_output_reshape = GetOnlyChildByOutputIndex(graph, *proj_gemm, 0, "Reshape");
513513
if (proj_gemm_output_reshape == nullptr ||
514-
!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14}, kOnnxDomain) ||
514+
!graph_utils::IsSupportedOptypeVersionAndDomain(*proj_gemm_output_reshape, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain) ||
515515
!optimizer_utils::CheckOutputEdges(graph, *proj_gemm_output_reshape, 1)) {
516516
return fail("normalized projection Gemm output Reshape not matched");
517517
}
@@ -920,11 +920,11 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
920920
}
921921

922922
std::vector<graph_utils::EdgeEndToMatch> q_path{
923-
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
924-
{0, 0, "Reshape", {5, 13}, kOnnxDomain},
925-
{0, 0, "Add", {7, 13}, kOnnxDomain},
923+
{0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
924+
{0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
925+
{0, 0, "Add", {7, 13, 14}, kOnnxDomain},
926926
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
927-
{0, 0, "LayerNormalization", {1}, kOnnxDomain}};
927+
{0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}};
928928
if (!graph_utils::FindPath(edges[edges.size() - 1]->GetNode(), true, q_path, edges, logger)) {
929929
DEBUG_LOG("Failed to find path for q");
930930
return false;
@@ -953,9 +953,9 @@ static bool FuseSubGraphQKImpl(Node& layer_norm,
953953
}
954954

955955
std::vector<graph_utils::EdgeEndToMatch> k_path{
956-
{0, 1, "Transpose", {1, 13}, kOnnxDomain},
957-
{0, 0, "Reshape", {5, 13}, kOnnxDomain},
958-
{0, 0, "Add", {7, 13}, kOnnxDomain},
956+
{0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
957+
{0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
958+
{0, 0, "Add", {7, 13, 14}, kOnnxDomain},
959959
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
960960
{0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}};
961961

@@ -1070,8 +1070,8 @@ static bool FuseSubGraphQK(Node& layer_norm,
10701070
const logging::Logger& logger) {
10711071
// path to q
10721072
std::vector<graph_utils::EdgeEndToMatch> q_varience_path{
1073-
{0, 0, "Div", {7, 13}, kOnnxDomain},
1074-
{0, 0, "MatMul", {1, 9}, kOnnxDomain}};
1073+
{0, 0, "Div", {7, 13, 14}, kOnnxDomain},
1074+
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain}};
10751075
std::vector<const Node::EdgeEnd*> edges;
10761076
if (!graph_utils::FindPath(*(mask_nodes.add), true, q_varience_path, edges, logger)) {
10771077
DEBUG_LOG("Failed to find path for q");
@@ -1163,7 +1163,7 @@ static bool FuseSubGraphQKDistilBert(Node& layer_norm,
11631163
// path to q
11641164
std::vector<graph_utils::EdgeEndToMatch> q_varience_path{
11651165
{0, 2, "MatMul", {1, 9, 13}, kOnnxDomain},
1166-
{0, 0, "Div", {7, 13}, kOnnxDomain}};
1166+
{0, 0, "Div", {7, 13, 14}, kOnnxDomain}};
11671167
std::vector<const Node::EdgeEnd*> edges;
11681168
if (!graph_utils::FindPath(*(mask_nodes.where), true, q_varience_path, edges, logger)) {
11691169
DEBUG_LOG("Failed to find path for q");
@@ -1265,14 +1265,14 @@ bool AttentionFusion::FuseSubGraph(Node& layer_norm,
12651265
std::map<std::string, NodeArg*>& mask_int32_map,
12661266
const logging::Logger& logger) {
12671267
std::vector<graph_utils::EdgeEndToMatch> parent_path{
1268-
{0, 0, "Add", {7, 13}, kOnnxDomain},
1268+
{0, 0, "Add", {7, 13, 14}, kOnnxDomain},
12691269
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
1270-
{0, 0, "Reshape", {5, 13}, kOnnxDomain},
1271-
{0, 0, "Transpose", {1, 13}, kOnnxDomain},
1270+
{0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
1271+
{0, 0, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
12721272
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
1273-
{0, 1, "Transpose", {1, 13}, kOnnxDomain},
1274-
{0, 0, "Reshape", {5, 13}, kOnnxDomain},
1275-
{0, 0, "Add", {7, 13}, kOnnxDomain},
1273+
{0, 1, "Transpose", {1, 13, 21, 23, 24, 25}, kOnnxDomain},
1274+
{0, 0, "Reshape", {5, 13, 14, 19, 21, 23, 24, 25}, kOnnxDomain},
1275+
{0, 0, "Add", {7, 13, 14}, kOnnxDomain},
12761276
{0, 0, "MatMul", {1, 9, 13}, kOnnxDomain},
12771277
{0, 0, "LayerNormalization", {1, 17}, kOnnxDomain}};
12781278

0 commit comments

Comments
 (0)