@@ -52,16 +52,13 @@ def _pre_optimize(model: ir.Model) -> ir.Model:
5252 return model
5353
5454
55- def fuse_xformers (
56- model : ir .Model , debug : bool = False , apply_shape_inference : bool = False
57- ) -> tuple [ir .Model , dict [str , int ]]:
55+ def fuse_xformers (model : ir .Model , debug : bool = False ) -> tuple [ir .Model , dict [str , int ]]:
5856 """
5957 Apply transformer-specific fusions to the given model.
6058
6159 Args:
6260 model: The input ONNX model represented as an `ir.Model`.
6361 debug: If debug is True, enable pattern matching tracer for debugging.
64- apply_shape_inference: If True, apply shape inference after fusions.
6562
6663 Returns:
6764 A tuple containing:
@@ -72,7 +69,7 @@ def fuse_xformers(
7269
7370 model = _pre_optimize (model )
7471
75- def fuse (func ):
72+ def fuse (func , apply_shape_inference : bool = False ):
7673 return func (model , debug = debug , apply_shape_inference = apply_shape_inference )
7774
7875 fusion_count ["erf_gelu" ] = fuse (fuse_erfgelu )
@@ -82,7 +79,7 @@ def fuse(func):
8279 fusion_count ["rotary_embedding" ] = fuse (fuse_rotary_embedding )
8380 fusion_count ["partial_rotary_embedding" ] = fuse (fuse_partial_rotary_embedding )
8481 fusion_count ["cos_sin_cache" ] = fuse (fuse_cos_sin_cache )
85- fusion_count ["sdpa" ] = fuse (fuse_sdpa )
82+ fusion_count ["sdpa" ] = fuse (fuse_sdpa , apply_shape_inference = True )
8683 # Optimize to avoid trying multiple attention-based fusions
8784 fusion_count ["mha" ] = fuse (fuse_mha )
8885 if fusion_count ["mha" ] == 0 :
@@ -106,7 +103,6 @@ def optimize_for_ort(
106103 model : ir .Model ,
107104 config_name : str | None = None ,
108105 debug : bool = False ,
109- apply_shape_inference : bool = False ,
110106) -> tuple [ir .Model , dict [str , int ]]:
111107 """
112108 Optimize the model for ORT backend.
@@ -120,7 +116,6 @@ def optimize_for_ort(
120116 Typically it identifies the Execution Provider (EP) to optimize for.
121117 If None, the default configuration will be used.
122118 debug: If debug is True, enable pattern matching tracer for debugging.
123- apply_shape_inference: If True, apply shape inference after fusions.
124119
125120 Returns:
126121 A tuple containing:
@@ -129,7 +124,8 @@ def optimize_for_ort(
129124 """
130125
131126 model , fusion_count = fuse_xformers (
132- model , debug = debug , apply_shape_inference = apply_shape_inference
127+ model ,
128+ debug = debug ,
133129 )
134130 # Apply the ORT pattern rewrite rules.
135131 rewrite (model , ORT_PATTERN_REWRITE_RULES )
0 commit comments