4 changes: 4 additions & 0 deletions onnxscript/rewriter/ort_fusions/_core.py
@@ -13,7 +13,9 @@
     softmax,
 )
 from onnxscript.rewriter.ort_fusions.attention import fuse_attention
+from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
 from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
+from onnxscript.rewriter.ort_fusions.erfgelu import fuse_erfgelu
 from onnxscript.rewriter.ort_fusions.fuse_packed_qkv_gqa import fuse_qkv_gqa
 from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
 from onnxscript.rewriter.ort_fusions.gqa import fuse_gqa
@@ -65,6 +67,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
     fusion_count = dict()
 
     model = _pre_optimize(model)
+    fusion_count["erf_gelu"] = fuse_erfgelu(model)
     fusion_count["rms_normalization"] = fuse_rms_normalization(model)
     fusion_count["skip_layer_normalization"] = fuse_skip_layer_normalization(model)
     fusion_count["skip_rms_normalization"] = fuse_skip_rms_normalization(model)
@@ -84,6 +87,7 @@ def fuse_xformers(model: ir.Model) -> tuple[ir.Model, dict[str, int]]:
fusion_count["attention"] = fuse_attention(model)
fusion_count["gqa"] = 0
fusion_count["gelu"] = fuse_gelu(model)
fusion_count["bias_gelu"] = fuse_bias_gelu(model)
# Finally: inline any intermediate fusion functions introduced that were not
# consumed by other fusions, and eliminate any remaining unused nodes.
optimize(model)
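Note: with these _core.py changes, fuse_xformers runs the two new fusions and reports their hit counts alongside the existing ones. A minimal usage sketch (the model path and loading steps are my own illustration; fuse_xformers and its return type come from this diff):

    import onnx
    import onnxscript.ir as ir
    from onnxscript.rewriter.ort_fusions._core import fuse_xformers

    model_proto = onnx.load("model.onnx")  # placeholder path, any ONNX model
    model = ir.serde.deserialize_model(model_proto)
    model, fusion_count = fuse_xformers(model)
    print(fusion_count["erf_gelu"], fusion_count["bias_gelu"])  # counters added in this PR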
2 changes: 1 addition & 1 deletion onnxscript/rewriter/ort_fusions/_test_utils.py
@@ -33,7 +33,7 @@ def ort_run(model_name: str, model, inputs):
     return session.run(None, inputs)
 
 
-def assert_allclose(outputs, expected_outputs, rtol=1e-4, atol=1e-4):
+def assert_allclose(outputs, expected_outputs, rtol=1e-3, atol=1e-3):
     for i, (baseline_output, optimized_output) in enumerate(zip(expected_outputs, outputs)):
         try:
             np.testing.assert_equal(baseline_output.shape, optimized_output.shape)
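The default tolerances are relaxed from 1e-4 to 1e-3, presumably to absorb the small numeric drift that fused kernels introduce. For reference, np.testing.assert_allclose passes when |actual - desired| <= atol + rtol * |desired| element-wise, so this admits roughly an order of magnitude more drift:

    import numpy as np

    desired = np.float32(1.0)
    actual = np.float32(1.0005)  # 5e-4 off
    np.testing.assert_allclose(actual, desired, rtol=1e-3, atol=1e-3)  # passes
    # np.testing.assert_allclose(actual, desired, rtol=1e-4, atol=1e-4)  # would raise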
22 changes: 22 additions & 0 deletions onnxscript/rewriter/ort_fusions/bias_gelu.py
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter import _fusion_utils, pattern
+
+
+class BiasGeluFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x, y):
+        gelu_add = op.Add(x, y)
+        return op.Gelu(gelu_add, _domain="com.microsoft")
+
+    def rewrite(self, op, x, y):
+        return op.BiasGelu(x, y, _domain="com.microsoft")
+
+
+_rule = BiasGeluFusion.rule()
+
+bias_gelu_rules = pattern.RewriteRuleSet([_rule])
+
+
+fuse_bias_gelu = _fusion_utils.apply_fusion_rules(bias_gelu_rules)
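The rule is purely structural: any Add feeding a com.microsoft Gelu collapses into a single BiasGelu node with the two addends as inputs. As a reference for the intended numerics (my own sketch, not part of the PR; scipy is used only for a vectorized erf), the fused op is expected to compute exact-erf GELU over x + bias:

    import numpy as np
    from scipy.special import erf

    def bias_gelu_reference(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
        # Gelu(x + bias), exact-erf formulation
        t = x + bias
        return 0.5 * t * (1.0 + erf(t / np.sqrt(2.0)))

One caveat worth noting: the pattern places no shape constraint on y, while the com.microsoft BiasGelu contract expects a 1-D bias broadcast over the last axis, so the match appears somewhat broader than the fused op it produces.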
52 changes: 52 additions & 0 deletions onnxscript/rewriter/ort_fusions/bias_gelu_test.py
@@ -0,0 +1,52 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+import unittest
+
+import numpy as np
+
+import onnxscript
+import onnxscript.ir as ir
+import onnxscript.rewriter.ort_fusions._test_utils as test_utils
+from onnxscript import FLOAT, script
+from onnxscript import opset18 as op
+from onnxscript.optimizer import optimize, remove_unused_nodes
+from onnxscript.rewriter.ort_fusions.bias_gelu import fuse_bias_gelu
+
+msft_op = onnxscript.values.Opset("com.microsoft", 1)
+
+
+class BiasGeluFusionTest(unittest.TestCase):
+    def test_bias_gelu_fusion(self):
+        @script()
+        def bias_gelu_model(x, y):
+            gelu_add = op.Add(x, y)
+            gelu = msft_op.Gelu(gelu_add)
+            return gelu
+
+        model_proto = bias_gelu_model.to_model_proto(
+            input_types=[FLOAT[10], FLOAT[10]],
+            output_types=[FLOAT[10]],
+            ir_version=10,
+        )
+        model = ir.serde.deserialize_model(model_proto)
+        optimize(model)
+
+        input = {
+            "x": np.random.randn(10).astype(np.float32),
+            "y": np.random.randn(10).astype(np.float32),
+        }
+        original_output = test_utils.ort_run("Original", model, input)
+
+        fuse_bias_gelu(model)
+        remove_unused_nodes(model)
+
+        self.assertEqual(len(model.graph), 1)
+        self.assertEqual(model.graph.node(0).op_type, "BiasGelu")
+
+        optimized_output = test_utils.ort_run("Optimized", model, input)
+        test_utils.assert_allclose(original_output, optimized_output)
+
+
+if __name__ == "__main__":
+    unittest.main()
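To run just this test locally, the standard unittest entry point should work (assuming an editable install of onnxscript plus onnxruntime, which the ort_run helper requires):

    python -m unittest onnxscript.rewriter.ort_fusions.bias_gelu_test -v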

onnxscript/rewriter/ort_fusions/erfgelu.py
@@ -2,11 +2,11 @@
 # Licensed under the MIT License.
 import math
 
-from onnxscript.rewriter import pattern
+from onnxscript.rewriter import _fusion_utils, pattern
 
 
 # Pattern to match against
-def erf_gelu_pattern(op, x):
+def erf_gelu_pattern_1(op, x):
     # erf_gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
     # half = pattern.Constant(0.5)
     # sqrt2 = pattern.Constant(1.4142)
@@ -19,9 +19,18 @@ def erf_gelu_pattern(op, x):
     return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
 
 
+def erf_gelu_pattern_2(op, x):
+    return x * (0.5 * (op.Erf(x / math.sqrt(2)) + 1.0))
+
+
 # Replacement
 def gelu(op, x):
     return op.Gelu(x, _domain="com.microsoft")
 
 
-rule = pattern.RewriteRule(erf_gelu_pattern, gelu)
+rule1 = pattern.RewriteRule(erf_gelu_pattern_1, gelu)
+rule2 = pattern.RewriteRule(erf_gelu_pattern_2, gelu)
+
+rules = pattern.RewriteRuleSet([rule1, rule2])
+
+fuse_erfgelu = _fusion_utils.apply_fusion_rules(rules)
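Both patterns describe the same exact-erf GELU, 0.5 * x * (1 + erf(x / sqrt(2))); they differ only in how the multiplications associate, and since the rewriter matches graph structure rather than algebra, each association needs its own rule. A quick numeric check of the equivalence (illustrative, not part of the PR):

    import math
    import numpy as np

    x = np.linspace(-4.0, 4.0, 9)
    erf = np.vectorize(math.erf)
    form_1 = 0.5 * (x * (erf(x / math.sqrt(2)) + 1.0))  # erf_gelu_pattern_1
    form_2 = x * (0.5 * (erf(x / math.sqrt(2)) + 1.0))  # erf_gelu_pattern_2
    np.testing.assert_allclose(form_1, form_2, rtol=1e-7)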
51 changes: 45 additions & 6 deletions onnxscript/rewriter/ort_fusions/skip_normalization.py
@@ -47,27 +47,66 @@
         epsilon=epsilon,
         stash_type=stash_type,
     )
-    return normalized
+    return normalized, skip_sum
 
 
 def _skip_layer_normalization(op, input, skip, gamma, beta, epsilon, stash_type):
     if stash_type.value != 1:  # FLOAT type
         return None
-    normalized, _mean, _inv_std_var = op.SkipLayerNormalization(
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
         input,
         skip,
         gamma,
         beta,
         epsilon=epsilon,
-        _outputs=3,
+        _outputs=4,
         _domain="com.microsoft",
     )
+    return normalized, skip_sum
+
+
+# Fusion rule for Add + SkipLayerNormalization
+def _skip_layer_norm_add_bias_pattern(op, input, skip, gamma, beta, bias, epsilon, stash_type):
+    bias_sum = op.Add(input, bias)
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+        bias_sum,
+        skip,
+        gamma,
+        beta,
+        epsilon=epsilon,
+        _outputs=4,
+        _domain="com.microsoft",
+    )
-    return normalized
+    return normalized, skip_sum


+def _skip_layer_normalization_add_bias(
+    op, input, skip, gamma, beta, bias, epsilon, stash_type
+):
+    normalized, _mean, _inv_std_var, skip_sum = op.SkipLayerNormalization(
+        input,
+        skip,
+        gamma,
+        beta,
+        bias,
+        epsilon=epsilon,
+        _outputs=4,
+        _domain="com.microsoft",
+    )
+    return normalized, skip_sum



+_skip_layer_rule = pattern.RewriteRule(
+    _skip_layer_norm_pattern, _skip_layer_normalization, name="SkipLayerNorm"
+)
+_skip_layer_add_bias_rule = pattern.RewriteRule(
+    _skip_layer_norm_add_bias_pattern,
+    _skip_layer_normalization_add_bias,
+    name="SkipLayerNormAddBias",
+)
 
-_skip_layer_rule = pattern.RewriteRule(_skip_layer_norm_pattern, _skip_layer_normalization)
-
-skip_layer_normalization_rules = [_skip_layer_rule]
+skip_layer_normalization_rules = [_skip_layer_rule, _skip_layer_add_bias_rule]
 skip_layer_normalization_ruleset = pattern.RewriteRuleSet(skip_layer_normalization_rules)
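Net effect of this file's changes: both SkipLayerNormalization rules now expose the op's fourth output (the input + skip sum) so downstream fusions can consume it, and a new rule folds a preceding Add into the fused op by passing the addend as the optional bias input. A before/after sketch of the new SkipLayerNormAddBias rule (pseudo-graph, names illustrative):

    # matched subgraph:
    #   bias_sum = Add(input, bias)
    #   normalized, _mean, _inv_std_var, skip_sum =
    #       com.microsoft.SkipLayerNormalization(bias_sum, skip, gamma, beta)
    # replacement:
    #   normalized, _mean, _inv_std_var, skip_sum =
    #       com.microsoft.SkipLayerNormalization(input, skip, gamma, beta, bias)
    #   # bias moves into the fused op as its optional fifth input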

