@@ -129,38 +129,70 @@ def pattern(self, op, input, skip, gamma, beta, bias, epsilon, stash_type):
129129 )
130130 return normalized , skip_sum
131131
132+ def check (
133+ self , op , input , skip , gamma , beta , bias , epsilon , stash_type
134+ ) -> pattern .MatchResult : # type: ignore[name-defined]
135+ """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
136+ check_result = pattern .MatchResult ()
137+ bindings : dict [str , Dim ] = {}
132138
133- def _skip_layer_normalization_add_bias (
134- op , input , skip , gamma , beta , bias , epsilon , stash_type
135- ):
136- normalized , _mean , _inv_std_var , skip_sum = op .SkipLayerNormalization (
137- input ,
138- skip ,
139- gamma ,
140- beta ,
141- bias ,
142- epsilon = epsilon ,
143- _outputs = 4 ,
144- _domain = "com.microsoft" ,
145- )
146- return normalized , skip_sum
139+ def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
140+ return not _fusion_utils ._check_shape (bindings , val , dims )
141+
142+ if no_match (input , ["B" , "S" , "D" ]):
143+ return check_result .fail (
144+ f"Shape mismatch: { input } does not match expected dimensions ['B', 'S', 'D']" ,
145+ input ,
146+ )
147+ if no_match (skip , ["B" , "S" , "D" ]):
148+ return check_result .fail (
149+ f"Shape mismatch: { skip } does not match expected dimensions ['B', 'S', 'D']" ,
150+ skip ,
151+ )
152+ if no_match (gamma , ["D" ]):
153+ return check_result .fail (
154+ f"Shape mismatch: { gamma } does not match expected dimensions ['D']" ,
155+ gamma ,
156+ )
157+ if no_match (beta , ["D" ]):
158+ return check_result .fail (
159+ f"Shape mismatch: { beta } does not match expected dimensions ['D']" ,
160+ beta ,
161+ )
162+ if self ._has_bias :
163+ if no_match (bias , ["D" ]):
164+ return check_result .fail (
165+ f"Shape mismatch: { bias } does not match expected dimensions ['D']" ,
166+ bias ,
167+ )
168+
169+ return check_result
170+
171+ def rewrite (self , op , input , skip , gamma , beta , bias , epsilon , stash_type ):
172+ normalized , _mean , _inv_std_var , skip_sum = op .SkipLayerNormalization (
173+ input ,
174+ skip ,
175+ gamma ,
176+ beta ,
177+ bias ,
178+ epsilon = epsilon ,
179+ _outputs = 4 ,
180+ _domain = "com.microsoft" ,
181+ )
182+ return normalized , skip_sum
147183
148184
# Rule variants of the fusion: bias added after the skip connection, bias added
# before it, and no bias at all.
_skip_layer_add_bias_rule = SkipLayerNormFusion.rule(
    "SkipLayerNormBias", has_bias=True, bias_pre_add=False
)
_skip_layer_pre_add_bias_rule = SkipLayerNormFusion.rule(
    "SkipLayerNormPreBias", has_bias=True, bias_pre_add=True
)
_skip_layer_rule = SkipLayerNormFusion.rule("SkipLayerNorm", has_bias=False)

# The more specific bias-bearing patterns are listed before the plain variant so
# they get a chance to match first.
skip_layer_normalization_ruleset = pattern.RewriteRuleSet(
    [_skip_layer_pre_add_bias_rule, _skip_layer_add_bias_rule, _skip_layer_rule]
)
164196
165197
166198fuse_skip_layer_normalization = _fusion_utils .apply_fusion_rules (
0 commit comments