Skip to content

Commit c368515

Browse files
add decoder model
1 parent 5c35ff0 commit c368515

7 files changed

Lines changed: 390 additions & 149 deletions

File tree

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -123,12 +123,13 @@ def optimize_for_ort(
123123
rewrite(model, ORT_PATTERN_REWRITE_RULES)
124124
return model, fusion_count
125125

126-
'''
126+
127+
"""
127128
from onnxscript import ir, rewriter
128129
import onnxscript.rewriter.ort_fusions as ort_fusions
129130
model_ir = ir.serde.deserialize_model(model)
130131
model_ir, count = ort_fusions.optimize_for_ort(model_ir)
131132
print("Applied fusions", count)
132133
print("\n\n\n\n\n\n\n\n\n\n\n")
133134
model = ir.serde.serialize_model(model_ir)
134-
'''
135+
"""

onnxscript/rewriter/ort_fusions/_whisper_tiny.py

Lines changed: 0 additions & 122 deletions
This file was deleted.

onnxscript/rewriter/ort_fusions/attention_test.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -58,8 +58,7 @@ def create_model(self, with_past=False):
5858

5959
@script()
6060
def model_with_mha(input, weight, bias):
61-
qkv_no_bias = op.MatMul(input, weight)
62-
qkv = op.Add(qkv_no_bias, bias)
61+
qkv = op.MatMul(input, weight)
6362

6463
query_BSDh = op.Slice(qkv, [0], [160], [2])
6564
key_BSDh = op.Slice(qkv, [160], [320], [2])
@@ -69,14 +68,18 @@ def model_with_mha(input, weight, bias):
6968
query_BSDh,
7069
key_BSDh,
7170
value_BSDh,
71+
bias,
72+
None,
73+
None,
74+
None,
75+
None,
7276
num_heads=self.num_heads,
7377
)
7478
return mha
7579

7680
@script()
7781
def model_with_mha_past(input, weight, bias, past):
78-
qkv_no_bias = op.MatMul(input, weight)
79-
qkv = op.Add(qkv_no_bias, bias)
82+
qkv = op.MatMul(input, weight)
8083

8184
query_BSDh = op.Slice(qkv, [0], [160], [2])
8285
key_BSDh = op.Slice(qkv, [160], [320], [2])
@@ -91,7 +94,7 @@ def model_with_mha_past(input, weight, bias, past):
9194
query_BSDh,
9295
key_BSDh,
9396
value_BSDh,
94-
None,
97+
bias,
9598
None,
9699
None,
97100
past_key,

onnxscript/rewriter/ort_fusions/fuse_xformers_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ def test_fuse_xformers(self):
2727
self.assertEqual(fusion_count["partial_rotary_embedding"], 0)
2828
self.assertEqual(fusion_count["cos_sin_cache"], 2)
2929
self.assertEqual(fusion_count["sdpa"], 1)
30-
self.assertEqual(fusion_count["mha"], 0)
30+
self.assertEqual(fusion_count["mha"], 1)
3131
self.assertEqual(fusion_count["attention"], 0)
3232
self.assertEqual(fusion_count["gqa"], 0)
3333
self.assertEqual(fusion_count["gelu"], 0)

onnxscript/rewriter/ort_fusions/mha_test.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,11 @@
88

99
import onnxscript.optimizer
1010
import onnxscript.rewriter.ort_fusions._core as xformers
11+
from onnxscript.ir.passes.common import shape_inference
1112
from onnxscript.rewriter.ort_fusions._test_utils import ORT_VERSION, assert_allclose, ort_run
1213
from onnxscript.rewriter.ort_fusions.models._smollm_2 import smollm_test_2
14+
from onnxscript.rewriter.ort_fusions.models._whisper_decoder import whisper_decoder_test
15+
from onnxscript.rewriter.ort_fusions.models._whisper_encoder import whisper_encoder_test
1316

1417

1518
class TestMultiHeadAttention(unittest.TestCase):
@@ -40,6 +43,32 @@ def test_smollm(self):
4043
new_outputs = ort_run("optimized", model, inputs)
4144
assert_allclose(new_outputs, original_outputs)
4245

46+
def test_whisper_encoder(self):
47+
# Generate model
48+
whisper_encoder = whisper_encoder_test()
49+
model = whisper_encoder.get_onnx_model()
50+
onnxscript.optimizer.optimize(model)
51+
52+
# Fuse SDPA and MHA
53+
sdpa_count = xformers.fuse_sdpa(model)
54+
self.assertGreater(sdpa_count, 0)
55+
model = shape_inference.infer_shapes(model)
56+
mha_count = xformers.fuse_mha(model)
57+
self.assertGreater(mha_count, 0)
58+
59+
def test_whisper_decoder(self):
60+
# Generate model
61+
whisper_decoder = whisper_decoder_test()
62+
model = whisper_decoder.get_onnx_model()
63+
onnxscript.optimizer.optimize(model)
64+
65+
# Fuse SDPA and MHA
66+
sdpa_count = xformers.fuse_sdpa(model)
67+
self.assertGreater(sdpa_count, 0)
68+
model = shape_inference.infer_shapes(model)
69+
mha_count = xformers.fuse_mha(model)
70+
self.assertGreater(mha_count, 0)
71+
4372

4473
if __name__ == "__main__":
4574
unittest.main()

0 commit comments

Comments
 (0)