@@ -52,12 +52,16 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5252 return model
5353
5454
55- def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
55+ def fuse_xformers(
56+     model: ir.Model, debug: bool = False, apply_shape_inference: bool = False
57+ ) -> tuple[ir.Model, dict[str, int]]:
5658 """
5759 Apply transformer-specific fusions to the given model.
5860
5961 Args:
6062 model: The input ONNX model represented as an `ir.Model`.
63+ debug: If debug is True, enable pattern matching tracer for debugging.
64+ apply_shape_inference: If True, apply shape inference after fusions.
6165
6266 Returns:
6367 A tuple containing:
@@ -67,35 +71,42 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
6771 fusion_count = dict()
6872
6973 model = _pre_optimize(model)
70- fusion_count["erf_gelu"] = fuse_erfgelu(model)
71- fusion_count["rms_normalization"] = fuse_rms_normalization(model)
72- fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
73- fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
74- fusion_count["rotary_embedding"] = fuse_rotary_embedding(model)
75- fusion_count["partial_rotary_embedding"] = fuse_partial_rotary_embedding(model)
76- fusion_count["cos_sin_cache"] = fuse_cos_sin_cache(model)
77- fusion_count["sdpa"] = fuse_sdpa(model)
74+
75+ def fuse(func):
76+     return func(model, debug=debug, apply_shape_inference=apply_shape_inference)
77+
78+ fusion_count["erf_gelu"] = fuse(fuse_erfgelu)
79+ fusion_count["rms_normalization"] = fuse(fuse_rms_normalization)
80+ fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
81+ fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
82+ fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
83+ fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
84+ fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
85+ fusion_count["sdpa"] = fuse(fuse_sdpa)
7886 # Optimize to avoid trying multiple attention-based fusions
79- fusion_count["mha"] = fuse_mha(model)
87+ fusion_count["mha"] = fuse(fuse_mha)
8088 if fusion_count["mha"] == 0:
8189     # If no MHA fusion was applied, we can try the GQA fusion.
8290     # and avoid trying the attention fusion.
83-     fusion_count["gqa"] = fuse_gqa(model)
84-     fusion_count["packed_qkv_for_gqa"] = fuse_qkv_gqa(model)
91+     fusion_count["gqa"] = fuse(fuse_gqa)
92+     fusion_count["packed_qkv_for_gqa"] = fuse(fuse_qkv_gqa)
8593     fusion_count["attention"] = 0
8694 else:
87-     fusion_count["attention"] = fuse_attention(model)
95+     fusion_count["attention"] = fuse(fuse_attention)
8896     fusion_count["gqa"] = 0
89- fusion_count["gelu"] = fuse_gelu(model)
90- fusion_count["bias_gelu"] = fuse_bias_gelu(model)
97+ fusion_count["gelu"] = fuse(fuse_gelu)
98+ fusion_count["bias_gelu"] = fuse(fuse_bias_gelu)
9199 # Finally: inline any intermediate fusion functions introduced that were not
92100 # consumed by other fusions, and eliminate any remaining unused nodes.
93101 optimize(model)
94102 return model, fusion_count
95103
96104
97105 def optimize_for_ort(
98-     model: ir.Model, config_name: str | None = None
106+     model: ir.Model,
107+     config_name: str | None = None,
108+     debug: bool = False,
109+     apply_shape_inference: bool = False,
99110 ) -> tuple[ir.Model, dict[str, int]]:
100111 """
101112 Optimize the model for ORT backend.
@@ -108,13 +119,18 @@ def optimize_for_ort(
108119 config_name: The name of the configuration to use for optimization.
109120 Typically it identifies the Execution Provider (EP) to optimize for.
110121 If None, the default configuration will be used.
122+ debug: If debug is True, enable pattern matching tracer for debugging.
123+ apply_shape_inference: If True, apply shape inference after fusions.
111124
112125 Returns:
113126 A tuple containing:
114127 - The optimized `ir.Model` after applying transformer-specific fusions.
115128 - A dictionary with a count of each of the fusions applied.
116129 """
117130
118-     model, fusion_count = fuse_xformers(model)
131+     model, fusion_count = fuse_xformers(
132+         model, debug=debug, apply_shape_inference=apply_shape_inference
133+     )
134+     # Apply the ORT pattern rewrite rules.
119135     rewrite(model, ORT_PATTERN_REWRITE_RULES)
120136 return model , fusion_count
0 commit comments