Skip to content

Commit 2cd75a7

Browse files
Make the shape-inference flag internal
1 parent 9f81b29 commit 2cd75a7

1 file changed

Lines changed: 5 additions & 9 deletions

File tree

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -52,16 +52,13 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5252
return model
5353

5454

55-
def fuse_xformers(
56-
model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
57-
) -> tuple[ir.Model, dict[str, int]]:
55+
def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
5856
"""
5957
Apply transformer-specific fusions to the given model.
6058
6159
Args:
6260
model: The input ONNX model represented as an `ir.Model`.
6361
debug: If debug is True, enable pattern matching tracer for debugging.
64-
apply_shape_inference: If True, apply shape inference after fusions.
6562
6663
Returns:
6764
A tuple containing:
@@ -72,7 +69,7 @@ def fuse_xformers(
7269

7370
model = _pre_optimize(model)
7471

75-
def fuse(func):
72+
def fuse(func, apply_shape_inference: bool = False):
7673
return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
7774

7875
fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
@@ -82,7 +79,7 @@ def fuse(func):
8279
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
8380
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
8481
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
85-
fusion_count["sdpa"] = fuse(fuse_sdpa)
82+
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
8683
# Optimize to avoid trying multiple attention-based fusions
8784
fusion_count["mha"] = fuse(fuse_mha)
8885
if fusion_count["mha"] == 0:
@@ -106,7 +103,6 @@ def optimize_for_ort(
106103
model: ir.Model,
107104
config_name: str | None = None,
108105
debug: bool = False,
109-
apply_shape_inference: bool = False,
110106
) -> tuple[ir.Model, dict[str, int]]:
111107
"""
112108
Optimize the model for ORT backend.
@@ -120,7 +116,6 @@ def optimize_for_ort(
120116
Typically it identifies the Execution Provider (EP) to optimize for.
121117
If None, the default configuration will be used.
122118
debug: If debug is True, enable pattern matching tracer for debugging.
123-
apply_shape_inference: If True, apply shape inference after fusions.
124119
125120
Returns:
126121
A tuple containing:
@@ -129,7 +124,8 @@ def optimize_for_ort(
129124
"""
130125

131126
model, fusion_count = fuse_xformers(
132-
model, debug=debug, apply_shape_inference=apply_shape_inference
127+
model,
128+
debug=debug,
133129
)
134130
# Apply the ORT pattern rewrite rules.
135131
rewrite(model, ORT_PATTERN_REWRITE_RULES)

0 commit comments

Comments (0)