diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index a72b107eea..52deb6c1b0 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -13,7 +13,9 @@ softmax, ) from onnxscript.rewriter.ort_fusions.attention import fuse_attention +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa @@ -65,6 +67,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]: fusion_count = dict() model = _pre_optimize(model) + fusion_count["erf_gelu"] = fuse_erfgelu(model) fusion_count["rms_normalization"] = fuse_rms_normalization(model) fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model) fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model) @@ -84,6 +87,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]: fusion_count["attention"] = fuse_attention(model) fusion_count["gqa"] = 0 fusion_count["gelu"] = fuse_gelu(model) + fusion_count["bias_gelu"] = fuse_bias_gelu(model) # Finally: inline any intermediate fusion functions introduced that were not # consumed by other fusions, and eliminate any remaining unused nodes. optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index e1a6be338d..4181fffbf4 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -33,7 +33,7 @@ def ort_run(model_name: str, model, inputs): return session.run(None, inputs) -def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4): +def assert_allclose(outputs, expected_outputs, rtol=1e-3, atol=1e-3): for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): try: np.testing.assert_equal(baseline_output.shape, optimized_output.shape) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu.py b/onnxscript/rewriter/ort_fusions/bias_gelu.py new file mode 100644 index 0000000000..472e3be167 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import _fusion_utils, pattern + + +class BiasGeluFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, y): + gelu_add = op.Add(x, y) + return op.Gelu(gelu_add, _domain="com.microsoft") + + def rewrite(self, op, x, y): + return op.BiasGelu(x, y, _domain="com.microsoft") + + +_rule = BiasGeluFusion.rule() + +bias_gelu_rules = pattern.RewriteRuleSet([_rule]) + + +fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py new file mode 100644 index 0000000000..ce8c08cf4f --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np + +import onnxscript +import onnxscript.ir as ir +import onnxscript.rewriter.ort_fusions._test_utils as test_utils +from onnxscript import FLOAT, script +from onnxscript import opset18 as op +from onnxscript.optimizer import optimize, remove_unused_nodes +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu + +msft_op = onnxscript.values.Opset("com.microsoft", 1) + + +class BiasGeluFusionTest(unittest.TestCase): + def test_bias_gelu_fusion(self): + @script() + def bias_gelu_model(x, y): + gelu_add = op.Add(x, y) + gelu = msft_op.Gelu(gelu_add) + return gelu + + model_proto = bias_gelu_model.to_model_proto( + input_types=[FLOAT[10], FLOAT[10]], + output_types=[FLOAT[10]], + ir_version=10, + ) + model = ir.serde.deserialize_model(model_proto) + optimize(model) + + input = { + "x": np.random.randn(10).astype(np.float32), + "y": np.random.randn(10).astype(np.float32), + } + original_output = test_utils.ort_run("Original", model, input) + + fuse_bias_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph.node(0).op_type, "BiasGelu") + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/ort_fusions/erfgelu.py similarity index 61% rename from onnxscript/rewriter/erfgelu.py rename to onnxscript/rewriter/ort_fusions/erfgelu.py index c821a79b3b..ba515a5572 100644 --- a/onnxscript/rewriter/erfgelu.py +++ b/onnxscript/rewriter/ort_fusions/erfgelu.py @@ -2,11 +2,11 @@ # Licensed under the MIT License. import math -from onnxscript.rewriter import pattern +from onnxscript.rewriter import _fusion_utils, pattern # Pattern to match against -def erf_gelu_pattern(op, x): +def erf_gelu_pattern_1(op, x): # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) # half = pattern.Constant(0.5) # sqrt2 = pattern.Constant(1.4142) @@ -19,9 +19,18 @@ def erf_gelu_pattern(op, x): return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) +def erf_gelu_pattern_2(op, x): + return x * (0.5 * (op.Erf(x / math.sqrt(2)) + 1.0)) + + # Replacement def gelu(op, x): return op.Gelu(x, _domain="com.microsoft") -rule = pattern.RewriteRule(erf_gelu_pattern, gelu) +rule1 = pattern.RewriteRule(erf_gelu_pattern_1, gelu) +rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu) + +rules = pattern.RewriteRuleSet([rule1, rule2]) + +fuse_erfgelu = _fusion_utils.apply_fusion_rules(rules) diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index 9ae731d3d0..d4eca4c45d 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -47,27 +47,66 @@ def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type): epsilon=epsilon, stash_type=stash_type, ) - return normalized + return normalized, skip_sum def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type): if stash_type.value != 1: # FLOAT type return None - normalized, _mean, _inv_std_var = op.SkipLayerNormalization( + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( input, skip, gamma, beta, epsilon=epsilon, - _outputs=3, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +# Fusion rule for Add + SkipLayerNormalization +def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type): + bias_sum = op.Add(input, bias) + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( + bias_sum, + skip, + gamma, + beta, + epsilon=epsilon, + _outputs=4, _domain="com.microsoft", ) - return normalized + return normalized, skip_sum + +def _skip_layer_normalization_add_bias( + op, input, skip, gamma, beta, bias, epsilon, stash_type +): + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( + input, + skip, + gamma, + beta, + bias, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +_skip_layer_rule = pattern.RewriteRule( + _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm" +) +_skip_layer_add_bias_rule = pattern.RewriteRule( + _skip_layer_norm_add_bias_pattern, + _skip_layer_normalization_add_bias, + name="SkipLayerNormAddBias", +) -_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization) -skip_layer_normalization_rules = [_skip_layer_rule] +skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule] skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)