Skip to content

Commit 9f81b29

Browse files
add fuse helper
1 parent 1d7aea3 commit 9f81b29

2 files changed: 43 additions & 19 deletions

File tree

onnxscript/rewriter/_fusion_utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Callable, Sequence, Union
66

77
import onnxscript.ir as ir
8+
from onnxscript.ir.passes.common import shape_inference
89
from onnxscript.rewriter import pattern
910

1011
Dim = Union[int, ir.SymbolicDim]
@@ -26,11 +27,18 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
2627
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
2728
"""
2829
Apply the given fusion rules to the model and return the number of fusions applied.
29-
If debug is True, enable pattern matching tracer for debugging.
30+
31+
model: The input ONNX model represented as an `ir.Model`.
32+
debug: If debug is True, enable pattern matching tracer for debugging.
33+
apply_shape_inference: If True, apply shape inference after fusions.
3034
"""
3135

32-
def apply_to(model: ir.Model, debug: bool = False) -> int:
36+
def apply_to(
37+
model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
38+
) -> int:
3339
count = rules.apply_to_model(model)
40+
if apply_shape_inference:
41+
model = shape_inference.infer_shapes(model)
3442
if count == 0 and debug:
3543
tracer = pattern.MatchingTracer()
3644
rules.apply_to_model(model, tracer=tracer)

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,16 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5252
return model
5353

5454

55-
def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
55+
def fuse_xformers(
56+
model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
57+
) -> tuple[ir.Model, dict[str, int]]:
5658
"""
5759
Apply transformer-specific fusions to the given model.
5860
5961
Args:
6062
model: The input ONNX model represented as an `ir.Model`.
63+
debug: If debug is True, enable pattern matching tracer for debugging.
64+
apply_shape_inference: If True, apply shape inference after fusions.
6165
6266
Returns:
6367
A tuple containing:
@@ -67,35 +71,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
6771
fusion_count = dict()
6872

6973
model = _pre_optimize(model)
70-
fusion_count["erf_gelu"] = fuse_erfgelu(model)
71-
fusion_count["rms_normalization"] = fuse_rms_normalization(model)
72-
fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
73-
fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
74-
fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
75-
fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
76-
fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
77-
fusion_count["sdpa"] = fuse_sdpa(model)
74+
75+
def fuse(func):
76+
return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
77+
78+
fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
79+
fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
80+
fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
81+
fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
82+
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
83+
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
84+
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
85+
fusion_count["sdpa"] = fuse(fuse_sdpa)
7886
# Optimize to avoid trying multiple attention-based fusions
79-
fusion_count["mha"] = fuse_mha(model)
87+
fusion_count["mha"] = fuse(fuse_mha)
8088
if fusion_count["mha"] == 0:
8189
# If no MHA fusion was applied, we can try the GQA fusion.
8290
# and avoid trying the attention fusion.
83-
fusion_count["gqa"] = fuse_gqa(model)
84-
fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
91+
fusion_count["gqa"] = fuse(fuse_gqa)
92+
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
8593
fusion_count["attention"] = 0
8694
else:
87-
fusion_count["attention"] = fuse_attention(model)
95+
fusion_count["attention"] = fuse(fuse_attention)
8896
fusion_count["gqa"] = 0
89-
fusion_count["gelu"] = fuse_gelu(model)
90-
fusion_count["bias_gelu"] = fuse_bias_gelu(model)
97+
fusion_count["gelu"] = fuse(fuse_gelu)
98+
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
9199
# Finally: inline any intermediate fusion functions introduced that were not
92100
# consumed by other fusions, and eliminate any remaining unused nodes.
93101
optimize(model)
94102
return model, fusion_count
95103

96104

97105
def optimize_for_ort(
98-
model: ir.Model, config_name: str | None = None
106+
model: ir.Model,
107+
config_name: str | None = None,
108+
debug: bool = False,
109+
apply_shape_inference: bool = False,
99110
) -> tuple[ir.Model, dict[str, int]]:
100111
"""
101112
Optimize the model for ORT backend.
@@ -108,13 +119,18 @@ def optimize_for_ort(
108119
config_name: The name of the configuration to use for optimization.
109120
Typically it identifies the Execution Provider (EP) to optimize for.
110121
If None, the default configuration will be used.
122+
debug: If debug is True, enable pattern matching tracer for debugging.
123+
apply_shape_inference: If True, apply shape inference after fusions.
111124
112125
Returns:
113126
A tuple containing:
114127
- The optimized `ir.Model` after applying transformer-specific fusions.
115128
- A dictionary with a count of each of the fusions applied.
116129
"""
117130

118-
model, fusion_count = fuse_xformers(model)
131+
model, fusion_count = fuse_xformers(
132+
model, debug=debug, apply_shape_inference=apply_shape_inference
133+
)
134+
# Apply the ORT pattern rewrite rules.
119135
rewrite(model, ORT_PATTERN_REWRITE_RULES)
120136
return model, fusion_count

0 commit comments

Comments (0)