From 803d565a54c424d54ad7eedaf277dae1b10fa239 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 23 Apr 2025 21:03:29 +0000 Subject: [PATCH 1/5] add additional gelu fusions --- onnxscript/rewriter/__init__.py | 2 + onnxscript/rewriter/erfgelu.py | 11 +++- onnxscript/rewriter/ort_fusions/_core.py | 2 + onnxscript/rewriter/ort_fusions/bias_gelu.py | 22 +++++++ .../rewriter/ort_fusions/bias_gelu_test.py | 62 +++++++++++++++++++ 5 files changed, 97 insertions(+), 2 deletions(-) create mode 100644 onnxscript/rewriter/ort_fusions/bias_gelu.py create mode 100644 onnxscript/rewriter/ort_fusions/bias_gelu_test.py diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 5efaf784b0..013b889f3f 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -18,6 +18,7 @@ broadcast_to_matmul, cast_constant_of_shape, collapse_slices, + erfgelu, gemm_to_matmul_add, llama_rule_sets, no_op, @@ -31,6 +32,7 @@ gemm_to_matmul_add.rule, # type: ignore[has-type] *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, + *erfgelu.rules.rules, *llama_rule_sets.llama_p0_rule_set().rules, ) diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/erfgelu.py index c821a79b3b..72adf9928b 100644 --- a/onnxscript/rewriter/erfgelu.py +++ b/onnxscript/rewriter/erfgelu.py @@ -6,7 +6,7 @@ # Pattern to match against -def erf_gelu_pattern(op, x): +def erf_gelu_pattern_1(op, x): # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))) # half = pattern.Constant(0.5) # sqrt2 = pattern.Constant(1.4142) @@ -19,9 +19,16 @@ def erf_gelu_pattern(op, x): return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0)) +def erf_gelu_pattern_2(op, x): + return x * (0.5 * (op.Erf(x / math.sqrt(2)) + 1.0)) + + # Replacement def gelu(op, x): return op.Gelu(x, _domain="com.microsoft") -rule = pattern.RewriteRule(erf_gelu_pattern, gelu) +rule1 = pattern.RewriteRule(erf_gelu_pattern_1, gelu) +rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu) + +rules = pattern.RewriteRuleSet([rule1, rule2]) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index a72b107eea..cdb3489e07 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -13,6 +13,7 @@ softmax, ) from onnxscript.rewriter.ort_fusions.attention import fuse_attention +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu @@ -84,6 +85,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]: fusion_count["attention"] = fuse_attention(model) fusion_count["gqa"] = 0 fusion_count["gelu"] = fuse_gelu(model) + fusion_count["bias_gelu"] = fuse_bias_gelu(model) # Finally: inline any intermediate fusion functions introduced that were not # consumed by other fusions, and eliminate any remaining unused nodes. optimize(model) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu.py b/onnxscript/rewriter/ort_fusions/bias_gelu.py new file mode 100644 index 0000000000..472e3be167 --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu.py @@ -0,0 +1,22 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter import _fusion_utils, pattern + + +class BiasGeluFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, y): + gelu_add = op.Add(x, y) + return op.Gelu(gelu_add, _domain="com.microsoft") + + def rewrite(self, op, x, y): + return op.BiasGelu(x, y, _domain="com.microsoft") + + +_rule = BiasGeluFusion.rule() + +bias_gelu_rules = pattern.RewriteRuleSet([_rule]) + + +fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py new file mode 100644 index 0000000000..47598b5e4e --- /dev/null +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np + +import onnx +import onnxscript.ir as ir +import onnxscript.rewriter.ort_fusions._test_utils as test_utils + +from onnxscript import opset18 as op +from onnxscript.optimizer import optimize, remove_unused_nodes +from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu + +FLOAT = onnx.TensorProto.FLOAT + + +class BiasGeluFusionTest(unittest.TestCase): + def test_gelu_fusion(self): + + model_proto = onnx.helper.make_model( + onnx.helper.make_graph( + [ + onnx.helper.make_node("Add", ["X", "Y"], ["gelu_add"]), + onnx.helper.make_node( + "Gelu", + ["gelu_add"], + ["Z"], + domain="com.microsoft", + ), + ], + "name", + [ + onnx.helper.make_tensor_value_info("X", FLOAT, [10]), + onnx.helper.make_tensor_value_info("Y", FLOAT, [10]), + ], + [onnx.helper.make_tensor_value_info("Z", FLOAT, [10])], + ), + opset_imports=[ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ], + ) + model = ir.serde.deserialize_model(model_proto) + optimize(model) + + input = {"X": np.random.randn(10).astype(np.float32), "Y": np.random.randn(10).astype(np.float32)} + original_output = test_utils.ort_run("Original", model, input) + + fuse_bias_gelu(model) + remove_unused_nodes(model) + + self.assertEqual(len(model.graph), 1) + self.assertEqual(model.graph.node(0).op_type, "BiasGelu") + + optimized_output = test_utils.ort_run("Optimized", model, input) + test_utils.assert_allclose(original_output, optimized_output) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From 5d4a01978ea05a918c8a04b4b63efa86bd88cf56 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Wed, 23 Apr 2025 21:09:16 +0000 Subject: [PATCH 2/5] add skip_normalization with bias --- .../rewriter/ort_fusions/bias_gelu_test.py | 54 +++++++++---------- .../ort_fusions/skip_normalization.py | 51 +++++++++++++++--- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py index 47598b5e4e..918feae12f 100644 --- a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -4,12 +4,10 @@ import unittest import numpy as np - import onnx + import onnxscript.ir as ir import onnxscript.rewriter.ort_fusions._test_utils as test_utils - -from onnxscript import opset18 as op from onnxscript.optimizer import optimize, remove_unused_nodes from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu @@ -18,34 +16,36 @@ class BiasGeluFusionTest(unittest.TestCase): def test_gelu_fusion(self): - model_proto = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["X", "Y"], ["gelu_add"]), - onnx.helper.make_node( - "Gelu", - ["gelu_add"], - ["Z"], - domain="com.microsoft", - ), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [10]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [10]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [10])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), + onnx.helper.make_graph( + [ + onnx.helper.make_node("Add", ["X", "Y"], ["gelu_add"]), + onnx.helper.make_node( + "Gelu", + ["gelu_add"], + ["Z"], + domain="com.microsoft", + ), + ], + "name", + [ + onnx.helper.make_tensor_value_info("X", FLOAT, [10]), + onnx.helper.make_tensor_value_info("Y", FLOAT, [10]), ], - ) + [onnx.helper.make_tensor_value_info("Z", FLOAT, [10])], + ), + opset_imports=[ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ], + ) model = ir.serde.deserialize_model(model_proto) optimize(model) - input = {"X": np.random.randn(10).astype(np.float32), "Y": np.random.randn(10).astype(np.float32)} + input = { + "X": np.random.randn(10).astype(np.float32), + "Y": np.random.randn(10).astype(np.float32), + } original_output = test_utils.ort_run("Original", model, input) fuse_bias_gelu(model) @@ -59,4 +59,4 @@ def test_gelu_fusion(self): if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/onnxscript/rewriter/ort_fusions/skip_normalization.py b/onnxscript/rewriter/ort_fusions/skip_normalization.py index 9ae731d3d0..d4eca4c45d 100644 --- a/onnxscript/rewriter/ort_fusions/skip_normalization.py +++ b/onnxscript/rewriter/ort_fusions/skip_normalization.py @@ -47,27 +47,66 @@ def _skip_layer_norm_pattern(op, input, skip, gamma, beta, epsilon, stash_type): epsilon=epsilon, stash_type=stash_type, ) - return normalized + return normalized, skip_sum def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type): if stash_type.value != 1: # FLOAT type return None - normalized, _mean, _inv_std_var = op.SkipLayerNormalization( + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( input, skip, gamma, beta, epsilon=epsilon, - _outputs=3, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +# Fusion rule for Add + SkipLayerNormalization +def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type): + bias_sum = op.Add(input, bias) + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( + bias_sum, + skip, + gamma, + beta, + epsilon=epsilon, + _outputs=4, _domain="com.microsoft", ) - return normalized + return normalized, skip_sum + +def _skip_layer_normalization_add_bias( + op, input, skip, gamma, beta, bias, epsilon, stash_type +): + normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization( + input, + skip, + gamma, + beta, + bias, + epsilon=epsilon, + _outputs=4, + _domain="com.microsoft", + ) + return normalized, skip_sum + + +_skip_layer_rule = pattern.RewriteRule( + _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm" +) +_skip_layer_add_bias_rule = pattern.RewriteRule( + _skip_layer_norm_add_bias_pattern, + _skip_layer_normalization_add_bias, + name="SkipLayerNormAddBias", +) -_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization) -skip_layer_normalization_rules = [_skip_layer_rule] +skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule] skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules) From 4e8243eb96f8ac1144c3234548ba378cccf128da Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 24 Apr 2025 04:57:03 +0000 Subject: [PATCH 3/5] reorder file --- onnxscript/rewriter/__init__.py | 2 -- onnxscript/rewriter/ort_fusions/_core.py | 2 ++ onnxscript/rewriter/{ => ort_fusions}/erfgelu.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) rename onnxscript/rewriter/{ => ort_fusions}/erfgelu.py (88%) diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index 013b889f3f..5efaf784b0 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -18,7 +18,6 @@ broadcast_to_matmul, cast_constant_of_shape, collapse_slices, - erfgelu, gemm_to_matmul_add, llama_rule_sets, no_op, @@ -32,7 +31,6 @@ gemm_to_matmul_add.rule, # type: ignore[has-type] *cast_constant_of_shape.rules.rules, *collapse_slices.rules.rules, - *erfgelu.rules.rules, *llama_rule_sets.llama_p0_rule_set().rules, ) diff --git a/onnxscript/rewriter/ort_fusions/_core.py b/onnxscript/rewriter/ort_fusions/_core.py index cdb3489e07..52deb6c1b0 100644 --- a/onnxscript/rewriter/ort_fusions/_core.py +++ b/onnxscript/rewriter/ort_fusions/_core.py @@ -15,6 +15,7 @@ from onnxscript.rewriter.ort_fusions.attention import fuse_attention from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache +from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa @@ -66,6 +67,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]: fusion_count = dict() model = _pre_optimize(model) + fusion_count["erf_gelu"] = fuse_erfgelu(model) fusion_count["rms_normalization"] = fuse_rms_normalization(model) fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model) fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model) diff --git a/onnxscript/rewriter/erfgelu.py b/onnxscript/rewriter/ort_fusions/erfgelu.py similarity index 88% rename from onnxscript/rewriter/erfgelu.py rename to onnxscript/rewriter/ort_fusions/erfgelu.py index 72adf9928b..ba515a5572 100644 --- a/onnxscript/rewriter/erfgelu.py +++ b/onnxscript/rewriter/ort_fusions/erfgelu.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import math -from onnxscript.rewriter import pattern +from onnxscript.rewriter import _fusion_utils, pattern # Pattern to match against @@ -32,3 +32,5 @@ def gelu(op, x): rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu) rules = pattern.RewriteRuleSet([rule1, rule2]) + +fuse_erfgelu = _fusion_utils.apply_fusion_rules(rules) From c3a2762bb4fb0d63eb674fa7614413564c9b633a Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 24 Apr 2025 20:00:43 +0000 Subject: [PATCH 4/5] fix test failures --- onnxscript/rewriter/ort_fusions/bias_gelu_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py index 918feae12f..f52f6a8c48 100644 --- a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -38,6 +38,7 @@ def test_gelu_fusion(self): onnx.helper.make_opsetid("", 18), onnx.helper.make_opsetid("com.microsoft", 1), ], + ir_version=10, ) model = ir.serde.deserialize_model(model_proto) optimize(model) From d65da68e1d4dbb9e68d2a09e2c7def72b8a5d452 Mon Sep 17 00:00:00 2001 From: shubhambhokare1 Date: Thu, 24 Apr 2025 21:03:54 +0000 Subject: [PATCH 5/5] update test to onnxscript style --- .../rewriter/ort_fusions/_test_utils.py | 2 +- .../rewriter/ort_fusions/bias_gelu_test.py | 43 +++++++------------ 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/_test_utils.py b/onnxscript/rewriter/ort_fusions/_test_utils.py index e1a6be338d..4181fffbf4 100644 --- a/onnxscript/rewriter/ort_fusions/_test_utils.py +++ b/onnxscript/rewriter/ort_fusions/_test_utils.py @@ -33,7 +33,7 @@ def ort_run(model_name: str, model, inputs): return session.run(None, inputs) -def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4): +def assert_allclose(outputs, expected_outputs, rtol=1e-3, atol=1e-3): for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)): try: np.testing.assert_equal(baseline_output.shape, optimized_output.shape) diff --git a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py index f52f6a8c48..ce8c08cf4f 100644 --- a/onnxscript/rewriter/ort_fusions/bias_gelu_test.py +++ b/onnxscript/rewriter/ort_fusions/bias_gelu_test.py @@ -4,48 +4,37 @@ import unittest import numpy as np -import onnx +import onnxscript import onnxscript.ir as ir import onnxscript.rewriter.ort_fusions._test_utils as test_utils +from onnxscript import FLOAT, script +from onnxscript import opset18 as op from onnxscript.optimizer import optimize, remove_unused_nodes from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu -FLOAT = onnx.TensorProto.FLOAT +msft_op = onnxscript.values.Opset("com.microsoft", 1) class BiasGeluFusionTest(unittest.TestCase): - def test_gelu_fusion(self): - model_proto = onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Add", ["X", "Y"], ["gelu_add"]), - onnx.helper.make_node( - "Gelu", - ["gelu_add"], - ["Z"], - domain="com.microsoft", - ), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [10]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [10]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [10])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + def test_bias_gelu_fusion(self): + @script() + def bias_gelu_model(x, y): + gelu_add = op.Add(x, y) + gelu = msft_op.Gelu(gelu_add) + return gelu + + model_proto = bias_gelu_model.to_model_proto( + input_types=[FLOAT[10], FLOAT[10]], + output_types=[FLOAT[10]], ir_version=10, ) model = ir.serde.deserialize_model(model_proto) optimize(model) input = { - "X": np.random.randn(10).astype(np.float32), - "Y": np.random.randn(10).astype(np.float32), + "x": np.random.randn(10).astype(np.float32), + "y": np.random.randn(10).astype(np.float32), } original_output = test_utils.ort_run("Original", model, input)