|
2 | 2 | // Licensed under the MIT License. |
3 | 3 | #include "core/optimizer/embed_layer_norm_fusion.h" |
4 | 4 |
|
5 | | -#include <limits> |
6 | | - |
7 | 5 | #include "core/common/span_utils.h" |
8 | 6 | #include "core/optimizer/initializer.h" |
9 | 7 | #include "core/graph/contrib_ops/contrib_defs.h" |
@@ -140,20 +138,12 @@ static bool MatchInputToConcatSubgraph( |
140 | 138 | } |
141 | 139 | } |
142 | 140 |
|
143 | | - // Opset 15+ added optional start/end attributes to Shape, allowing it to return only a |
144 | | - // subset of dimensions ("partial shape"). The Gather(index=0) below assumes Shape returns |
145 | | - // the full tensor shape. If start != 0 or end is not the default (INT64_MAX), the Gather |
146 | | - // would pick the wrong dimension and fusion would be incorrect. Reject such cases. |
| 141 | + // The Gather(index=0) below assumes Shape returns the full tensor shape. A partial shape |
| 142 | + // (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension. |
147 | 143 | const Node& shape_node_path1 = edges[shape_index]->GetNode(); |
148 | | - if (shape_node_path1.SinceVersion() >= 15) { |
149 | | - const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape_node_path1, "start"); |
150 | | - const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape_node_path1, "end"); |
151 | | - // end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape). |
152 | | - if (!((!start_attr || start_attr->i() == 0) && |
153 | | - (!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) { |
154 | | - DEBUG_LOG("Shape node in path 1 has non-default start/end attributes."); |
155 | | - return false; |
156 | | - } |
| 144 | + if (!graph_utils::IsFullShapeNode(shape_node_path1)) { |
| 145 | + DEBUG_LOG("Shape node in path 1 has non-default start/end attributes."); |
| 146 | + return false; |
157 | 147 | } |
158 | 148 |
|
159 | 149 | Node& concat_node = *graph.GetNode(edges[0]->GetNode().Index()); |
@@ -185,19 +175,11 @@ static bool MatchInputToConcatSubgraph( |
185 | 175 | Node& gather_node_1 = *graph.GetNode(edges[1]->GetNode().Index()); |
186 | 176 | Node& shape_node_1 = *graph.GetNode(edges[2]->GetNode().Index()); |
187 | 177 |
|
188 | | - // Opset 15+ added optional start/end attributes to Shape, allowing it to return only a |
189 | | - // subset of dimensions ("partial shape"). The Gather(index=1) below assumes Shape returns |
190 | | - // the full tensor shape. If start != 0 or end is not the default (INT64_MAX), the Gather |
191 | | - // would pick the wrong dimension and fusion would be incorrect. Reject such cases. |
192 | | - if (shape_node_1.SinceVersion() >= 15) { |
193 | | - const ONNX_NAMESPACE::AttributeProto* start_attr = graph_utils::GetNodeAttribute(shape_node_1, "start"); |
194 | | - const ONNX_NAMESPACE::AttributeProto* end_attr = graph_utils::GetNodeAttribute(shape_node_1, "end"); |
195 | | - // end=INT64_MAX is the runtime default, meaning "all dimensions" (i.e. full shape). |
196 | | - if (!((!start_attr || start_attr->i() == 0) && |
197 | | - (!end_attr || end_attr->i() == std::numeric_limits<int64_t>::max()))) { |
198 | | - DEBUG_LOG("Shape node in path 2 has non-default start/end attributes."); |
199 | | - return false; |
200 | | - } |
| 178 | + // The Gather(index=1) below assumes Shape returns the full tensor shape. A partial shape |
| 179 | + // (opset 15+ start/end attributes) would cause Gather to pick the wrong dimension. |
| 180 | + if (!graph_utils::IsFullShapeNode(shape_node_1)) { |
| 181 | + DEBUG_LOG("Shape node in path 2 has non-default start/end attributes."); |
| 182 | + return false; |
201 | 183 | } |
202 | 184 |
|
203 | 185 | // The gather node (with second input indices==1) is also shared by other subgraph |
|
0 commit comments