Skip to content

Commit 06ef3db

Browse files
committed
Extract IsFullShapeNode helper into graph_utils
1 parent 7941263 commit 06ef3db

5 files changed

Lines changed: 35 additions & 55 deletions

File tree

onnxruntime/core/graph/graph_utils.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "core/common/logging/logging.h"
99

1010
#include <algorithm>
11+
#include <limits>
1112
#include <queue>
1213
#include <string>
1314
#include <vector>
@@ -411,6 +412,14 @@ const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const s
411412
return iter == attrs.end() ? nullptr : &iter->second;
412413
}
413414

415+
bool IsFullShapeNode(const Node& node) {
416+
const auto* start_attr = GetNodeAttribute(node, "start");
417+
const auto* end_attr = GetNodeAttribute(node, "end");
418+
// end=INT64_MAX is the runtime default meaning "all dimensions" (full shape).
419+
return (!start_attr || start_attr->i() == 0) &&
420+
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max());
421+
}
422+
414423
static NodeArg& GetOrCreateNodeArg(Graph& graph, const ONNX_NAMESPACE::TensorProto& new_initializer) {
415424
ONNX_NAMESPACE::TypeProto new_type;
416425
auto* typeproto_tensor = new_type.mutable_tensor_type();

onnxruntime/core/graph/graph_utils.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "core/graph/onnx_protobuf.h"
99
#include "core/graph/graph.h"
1010

11+
#include <limits>
1112
#include <string>
1213
#include <vector>
1314

@@ -31,6 +32,10 @@ bool IsSupportedOptypeVersionAndDomain(const Node& node,
3132
/** Returns the attribute of a Node with a given name. */
3233
const ONNX_NAMESPACE::AttributeProto* GetNodeAttribute(const Node& node, const std::string& attr_name);
3334

35+
/** Checks whether a Shape node returns the full tensor shape (all dimensions).
36+
* Returns false if start/end attributes restrict the output to a subset of dimensions. */
37+
bool IsFullShapeNode(const Node& node);
38+
3439
/** Add a new initializer to 'graph'.
3540
Checks that new_initializer does not already exist in 'graph' before adding it.
3641
@returns The NodeArg for the new initializer.

onnxruntime/core/optimizer/attention_fusion_helper.h

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,12 @@ bool MatchGemmSubgraph(Graph& graph,
9898
const Node& slice = edges[6]->GetNode();
9999
const Node& shape_before_slice = edges[7]->GetNode();
100100

101-
// Opset 15+ added optional start/end attributes to Shape, allowing it to return only a
102-
// subset of dimensions ("partial shape"). The downstream Slice/Squeeze/Gather nodes assume
103-
// Shape returns the full tensor shape. If start != 0 or end is not the default (INT64_MAX),
104-
// the fusion would produce incorrect index mapping. Reject such cases.
105-
if (shape_before_slice.SinceVersion() >= 15) {
106-
const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape_before_slice, "start");
107-
const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape_before_slice, "end");
108-
// end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape).
109-
if (!((!start_attr || start_attr->i() == 0) &&
110-
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) {
111-
DEBUG_LOG("Shape node has non-default start/end attributes");
112-
return false;
113-
}
101+
// The downstream Slice/Squeeze/Gather nodes assume Shape returns the full tensor shape so
102+
// that indices map directly to tensor dimensions. A partial shape (opset 15+ start/end
103+
// attributes) would produce incorrect index mapping.
104+
if (!graph_utils::IsFullShapeNode(shape_before_slice)) {
105+
DEBUG_LOG("Shape node has non-default start/end attributes");
106+
return false;
114107
}
115108

116109
const auto& subgraph_input = shape_before_slice.InputDefs()[0];

onnxruntime/core/optimizer/embed_layer_norm_fusion.cc

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
// Licensed under the MIT License.
33
#include "core/optimizer/embed_layer_norm_fusion.h"
44

5-
#include <limits>
6-
75
#include "core/common/span_utils.h"
86
#include "core/optimizer/initializer.h"
97
#include "core/graph/contrib_ops/contrib_defs.h"
@@ -140,20 +138,12 @@ static bool MatchInputToConcatSubgraph(
140138
}
141139
}
142140

143-
// Opset 15+ added optional start/end attributes to Shape, allowing it to return only a
144-
// subset of dimensions ("partial shape"). The Gather(index=0) below assumes Shape returns
145-
// the full tensor shape. If start != 0 or end is not the default (INT64_MAX), the Gather
146-
// would pick the wrong dimension and fusion would be incorrect. Reject such cases.
141+
// The Gather(index=0) below assumes Shape returns the full tensor shape. A partial shape
142+
// (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension.
147143
const Node& shape_node_path1 = edges[shape_index]->GetNode();
148-
if (shape_node_path1.SinceVersion() >= 15) {
149-
const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape_node_path1, "start");
150-
const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape_node_path1, "end");
151-
// end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape).
152-
if (!((!start_attr || start_attr->i() == 0) &&
153-
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) {
154-
DEBUG_LOG("Shape node in path 1 has non-default start/end attributes.");
155-
return false;
156-
}
144+
if (!graph_utils::IsFullShapeNode(shape_node_path1)) {
145+
DEBUG_LOG("Shape node in path 1 has non-default start/end attributes.");
146+
return false;
157147
}
158148

159149
Node& concat_node = *graph.GetNode(edges[0]->GetNode().Index());
@@ -185,19 +175,11 @@ static bool MatchInputToConcatSubgraph(
185175
Node& gather_node_1 = *graph.GetNode(edges[1]->GetNode().Index());
186176
Node& shape_node_1 = *graph.GetNode(edges[2]->GetNode().Index());
187177

188-
// Opset 15+ added optional start/end attributes to Shape, allowing it to return only a
189-
// subset of dimensions ("partial shape"). The Gather(index=1) below assumes Shape returns
190-
// the full tensor shape. If start != 0 or end is not the default (INT64_MAX), the Gather
191-
// would pick the wrong dimension and fusion would be incorrect. Reject such cases.
192-
if (shape_node_1.SinceVersion() >= 15) {
193-
const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape_node_1, "start");
194-
const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape_node_1, "end");
195-
// end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape).
196-
if (!((!start_attr || start_attr->i() == 0) &&
197-
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) {
198-
DEBUG_LOG("Shape node in path 2 has non-default start/end attributes.");
199-
return false;
200-
}
178+
// The Gather(index=1) below assumes Shape returns the full tensor shape. A partial shape
179+
// (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension.
180+
if (!graph_utils::IsFullShapeNode(shape_node_1)) {
181+
DEBUG_LOG("Shape node in path 2 has non-default start/end attributes.");
182+
return false;
201183
}
202184

203185
// The gather node (with second input indices==1) is also shared by other subgraph

onnxruntime/core/optimizer/reshape_fusion.cc

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
// Licensed under the MIT License.
33

44
#include <algorithm>
5-
#include <limits>
65

76
#include "core/graph/graph_utils.h"
87
#include "core/optimizer/initializer.h"
@@ -173,19 +172,11 @@ bool ReshapeFusion::Match_One_Element_Output_Subgraph_1(Graph& graph, const Node
173172
const Node& gather = edges[1]->GetNode();
174173
const Node& shape = edges[2]->GetNode();
175174

176-
// Opset 15+ added optional start/end attributes to Shape, allowing it to return only a
177-
// subset of dimensions ("partial shape"). The fusion below assumes Shape returns the full
178-
// tensor shape so that Gather indices correspond directly to tensor dimensions. If start != 0
179-
// or end is set to something other than the default (INT64_MAX = all dims), the Gather index
180-
// semantics change and the fusion would produce incorrect results. Reject such cases.
181-
if (shape.SinceVersion() >= 15) {
182-
const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape, "start");
183-
const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape, "end");
184-
// end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape).
185-
if (!((!start_attr || start_attr->i() == 0) &&
186-
(!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) {
187-
return false;
188-
}
175+
// The fusion assumes Shape returns the full tensor shape so that Gather indices correspond
176+
// directly to tensor dimensions. A partial shape (opset 15+ start/end attributes) would shift
177+
// the index mapping and produce incorrect results.
178+
if (!graph_utils::IsFullShapeNode(shape)) {
179+
return false;
189180
}
190181

191182
InlinedVector<int64_t> axes;

0 commit comments

Comments
 (0)