Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions onnxscript/rewriter/_fusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Callable, Sequence, Union

import onnxscript.ir as ir
from onnxscript.ir.passes.common import shape_inference
from onnxscript.rewriter import pattern

Dim = Union[int, ir.SymbolicDim]
Expand All @@ -26,11 +27,18 @@ def _check_shape(bindings: dict[str, Dim], val: ir.Value, shape: Sequence[str])
def apply_fusion_rules(rules: pattern.RewriteRule | pattern.RewriteRuleSet) -> Callable:
"""
Return a function that applies the given fusion rules to a model and returns the number of fusions applied.
If debug is True, enable pattern matching tracer for debugging.

model: The input ONNX model represented as an `ir.Model`.
debug: If True and no fusion was applied, re-run the rules with a pattern matching tracer for debugging.
apply_shape_inference: If True, apply shape inference after fusions.
"""

def apply_to(model: ir.Model, debug: bool = False) -> int:
def apply_to(
model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
) -> int:
count = rules.apply_to_model(model)
if apply_shape_inference:
shape_inference.infer_shapes(model)
if count == 0 and debug:
tracer = pattern.MatchingTracer()
rules.apply_to_model(model, tracer=tracer)
Expand Down
49 changes: 31 additions & 18 deletions onnxscript/rewriter/ort_fusions/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,17 +47,18 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
# TODO: Do we need this dependence on ONNX's partial-data-propagation? There are some
# extra shape-propagation and partial-data-propagation rules in ONNX that are not yet
# incorporated in our optimizer.
model = shape_inference.infer_shapes(model)
shape_inference.infer_shapes(model)
optimize(model)
return model


def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
def fuse_xformers(model: ir.Model, debug: bool = False) -> tuple[ir.Model, dict[str, int]]:
"""
Apply transformer-specific fusions to the given model.

Args:
model: The input ONNX model represented as an `ir.Model`.
debug: If debug is True, enable pattern matching tracer for debugging.

Returns:
A tuple containing:
Expand All @@ -67,35 +68,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
fusion_count = dict()

model = _pre_optimize(model)
fusion_count["erf_gelu"] = fuse_erfgelu(model)
fusion_count["rms_normalization"] = fuse_rms_normalization(model)
fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
fusion_count["sdpa"] = fuse_sdpa(model)

def fuse(func, apply_shape_inference: bool = False):
return func(model, debug=debug, apply_shape_inference=apply_shape_inference)

fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)
# Optimize to avoid trying multiple attention-based fusions
fusion_count["mha"] = fuse_mha(model)
fusion_count["mha"] = fuse(fuse_mha)
if fusion_count["mha"] == 0:
# If no MHA fusion was applied, we can try the GQA fusion
# and avoid trying the attention fusion.
fusion_count["gqa"] = fuse_gqa(model)
fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
fusion_count["gqa"] = fuse(fuse_gqa)
fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
fusion_count["attention"] = 0
else:
fusion_count["attention"] = fuse_attention(model)
fusion_count["attention"] = fuse(fuse_attention)
fusion_count["gqa"] = 0
fusion_count["gelu"] = fuse_gelu(model)
fusion_count["bias_gelu"] = fuse_bias_gelu(model)
fusion_count["gelu"] = fuse(fuse_gelu)
fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
# Finally: inline any intermediate fusion functions introduced that were not
# consumed by other fusions, and eliminate any remaining unused nodes.
optimize(model)
return model, fusion_count


def optimize_for_ort(
model: ir.Model, config_name: str | None = None
model: ir.Model,
config_name: str | None = None,
Comment thread
shubhambhokare1 marked this conversation as resolved.
*,
debug: bool = False,
) -> tuple[ir.Model, dict[str, int]]:
"""
Optimize the model for ORT backend.
Expand All @@ -108,13 +116,18 @@ def optimize_for_ort(
config_name: The name of the configuration to use for optimization.
Typically it identifies the Execution Provider (EP) to optimize for.
If None, the default configuration will be used.
debug: If debug is True, enable pattern matching tracer for debugging.

Returns:
A tuple containing:
- The optimized `ir.Model` after applying transformer-specific fusions.
- A dictionary with a count of each of the fusions applied.
"""

model, fusion_count = fuse_xformers(model)
model, fusion_count = fuse_xformers(
model,
debug=debug,
)
# Apply the ORT pattern rewrite rules.
rewrite(model, ORT_PATTERN_REWRITE_RULES)
return model, fusion_count
Loading