Skip to content

Commit 5c35ff0

Browse files
add one layer encoder
1 parent 975c52b commit 5c35ff0

2 files changed

Lines changed: 188 additions & 0 deletions

File tree

onnxscript/rewriter/ort_fusions/_whisper_tiny.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import onnxscript.optimizer
1515
import onnxscript.rewriter.ort_fusions._core as xformers
1616

17+
from onnxscript.rewriter.ort_fusions._whisper_tiny_encoder import whisper_encoder_test
18+
1719

1820
def make_encoder_model():
1921
pass
@@ -25,6 +27,14 @@ def make_decoder_model():
2527

2628
class TestMultiHeadAttention(unittest.TestCase):
2729
def test_whisper_tiny(self):
30+
31+
test = whisper_encoder_test()
32+
model = test.get_onnx_model()
33+
onnxscript.optimizer.optimize(model)
34+
model, fusion_count_m = xformers.fuse_xformers(model)
35+
print(f"Fused {fusion_count_m} ops")
36+
37+
2838
# Generate encoder model
2939
whisper_encoder_model = onnx.load(
3040
"/workspace/testing/whisper-opt/whisper-tiny-4.48/whisper-tiny_encoder.onnx"
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
"""
5+
A one-layer Whisper encoder model test case, with inputs: audio_features.
6+
This is an onnxscript version of the model.
7+
"""
8+
9+
import numpy as np
10+
11+
import onnxscript.ir as ir
12+
from onnxscript import script
13+
from onnxscript.onnx_opset import opset18
14+
from onnxscript.onnx_types import FLOAT
15+
16+
17+
def make_model(
18+
encoder_encoder_embed_positions_weight,
19+
encoder_encoder_conv1_weight,
20+
encoder_encoder_conv1_bias,
21+
encoder_encoder_conv2_weight,
22+
encoder_encoder_conv2_bias,
23+
encoder_encoder_layers_0_self_attn_layer_norm_weight,
24+
encoder_encoder_layers_0_self_attn_layer_norm_bias,
25+
encoder_encoder_layers_0_self_attn_q_proj_weight,
26+
encoder_encoder_layers_0_self_attn_q_proj_bias,
27+
encoder_encoder_layers_0_self_attn_k_proj_weight,
28+
encoder_encoder_layers_0_self_attn_v_proj_weight,
29+
encoder_encoder_layers_0_self_attn_v_proj_bias,
30+
encoder_encoder_layers_0_self_attn_out_proj_weight,
31+
encoder_encoder_layers_0_self_attn_out_proj_bias,
32+
encoder_encoder_layers_0_final_layer_norm_weight,
33+
encoder_encoder_layers_0_final_layer_norm_bias,
34+
encoder_encoder_layers_0_fc1_weight,
35+
encoder_encoder_layers_0_fc1_bias,
36+
encoder_encoder_layers_0_fc2_weight,
37+
encoder_encoder_layers_0_fc2_bias,
38+
encoder_encoder_layer_norm_weight,
39+
encoder_encoder_layer_norm_bias,
40+
41+
):
42+
@script()
43+
def main_graph(
44+
audio_features: FLOAT[1,80,3000]
45+
) -> (FLOAT[1,1500,384]):
46+
val_0 = opset18.Shape(audio_features, end=1, start=0)
47+
conv1d = opset18.Conv(audio_features, encoder_encoder_conv1_weight, encoder_encoder_conv1_bias, group=1, pads=[1, 1], auto_pad='NOTSET', strides=[1], dilations=[1])
48+
val_2 = opset18.Div(conv1d, 1.4142135)
49+
val_3 = opset18.Erf(val_2)
50+
val_5 = opset18.Add(val_3, 1.0)
51+
val_7 = opset18.Mul(0.5, val_5)
52+
gelu = opset18.Mul(conv1d, val_7)
53+
conv1d_1 = opset18.Conv(gelu, encoder_encoder_conv2_weight, encoder_encoder_conv2_bias, group=1, pads=[1, 1], auto_pad='NOTSET', strides=[2], dilations=[1])
54+
val_9 = opset18.Div(conv1d_1, 1.4142135)
55+
val_10 = opset18.Erf(val_9)
56+
val_12 = opset18.Add(val_10, 1.0)
57+
val_14 = opset18.Mul(0.5, val_12)
58+
gelu_1 = opset18.Mul(conv1d_1, val_14)
59+
permute = opset18.Transpose(gelu_1, perm=[0, 2, 1])
60+
add_20 = opset18.Add(permute, encoder_encoder_embed_positions_weight)
61+
layer_norm = opset18.LayerNormalization(add_20, encoder_encoder_layers_0_self_attn_layer_norm_weight, encoder_encoder_layers_0_self_attn_layer_norm_bias, stash_type=1, epsilon=9.999999747378752e-06, axis=-1)
62+
val_17 = opset18.Transpose(encoder_encoder_layers_0_self_attn_q_proj_weight, perm=[1, 0])
63+
val_18 = opset18.MatMul(layer_norm, val_17)
64+
linear = opset18.Add(val_18, encoder_encoder_layers_0_self_attn_q_proj_bias)
65+
mul_18 = opset18.Mul(linear, 0.125)
66+
val_25 = opset18.Concat(val_0, [1500], [6], [64], axis=0)
67+
view = opset18.Reshape(mul_18, val_25, allowzero=0)
68+
transpose = opset18.Transpose(view, perm=[0, 2, 1, 3])
69+
val_27 = opset18.Transpose(encoder_encoder_layers_0_self_attn_k_proj_weight, perm=[1, 0])
70+
linear_1 = opset18.MatMul(layer_norm, val_27)
71+
val_31 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
72+
view_1 = opset18.Reshape(linear_1, val_31, allowzero=0)
73+
val_33 = opset18.Transpose(encoder_encoder_layers_0_self_attn_v_proj_weight, perm=[1, 0])
74+
val_34 = opset18.MatMul(layer_norm, val_33)
75+
linear_2 = opset18.Add(val_34, encoder_encoder_layers_0_self_attn_v_proj_bias)
76+
val_37 = opset18.Concat(val_0, [-1], [6], [64], axis=0)
77+
view_2 = opset18.Reshape(linear_2, val_37, allowzero=0)
78+
transpose_2 = opset18.Transpose(view_2, perm=[0, 2, 1, 3])
79+
transpose_3 = opset18.Transpose(view_1, perm=[0, 2, 3, 1])
80+
matmul = opset18.MatMul(transpose, transpose_3)
81+
softmax = opset18.Softmax(matmul, axis=-1)
82+
matmul_1 = opset18.MatMul(softmax, transpose_2)
83+
transpose_4 = opset18.Transpose(matmul_1, perm=[0, 2, 1, 3])
84+
val_42 = opset18.Concat(val_0, [1500], [384], axis=0)
85+
_unsafe_view = opset18.Reshape(transpose_4, val_42, allowzero=0)
86+
val_44 = opset18.Transpose(encoder_encoder_layers_0_self_attn_out_proj_weight, perm=[1, 0])
87+
val_45 = opset18.MatMul(_unsafe_view, val_44)
88+
linear_3 = opset18.Add(val_45, encoder_encoder_layers_0_self_attn_out_proj_bias)
89+
add_141 = opset18.Add(add_20, linear_3)
90+
layer_norm_1 = opset18.LayerNormalization(add_141, encoder_encoder_layers_0_final_layer_norm_weight, encoder_encoder_layers_0_final_layer_norm_bias, stash_type=1, epsilon=9.999999747378752e-06, axis=-1)
91+
val_48 = opset18.Transpose(encoder_encoder_layers_0_fc1_weight, perm=[1, 0])
92+
val_49 = opset18.MatMul(layer_norm_1, val_48)
93+
linear_4 = opset18.Add(val_49, encoder_encoder_layers_0_fc1_bias)
94+
val_51 = opset18.Div(linear_4, 1.4142135)
95+
val_52 = opset18.Erf(val_51)
96+
val_54 = opset18.Add(val_52, 1.0)
97+
val_56 = opset18.Mul(0.5, val_54)
98+
gelu_2 = opset18.Mul(linear_4, val_56)
99+
val_57 = opset18.Transpose(encoder_encoder_layers_0_fc2_weight, perm=[1, 0])
100+
val_58 = opset18.MatMul(gelu_2, val_57)
101+
linear_5 = opset18.Add(val_58, encoder_encoder_layers_0_fc2_bias)
102+
add_170 = opset18.Add(add_141, linear_5)
103+
layer_norm_2 = opset18.LayerNormalization(add_170, encoder_encoder_layer_norm_weight, encoder_encoder_layer_norm_bias, stash_type=1, epsilon=9.999999747378752e-06, axis=-1)
104+
return layer_norm_2
105+
106+
model = main_graph.to_model_proto()
107+
return model
108+
109+
110+
def make_model_with_random_weights():
111+
encoder_encoder_embed_positions_weight = np.random.rand(1500, 384).astype(np.float32)
112+
encoder_encoder_conv1_weight = np.random.rand(384, 80, 3).astype(np.float32)
113+
encoder_encoder_conv1_bias = np.random.rand(384).astype(np.float32)
114+
encoder_encoder_conv2_weight = np.random.rand(384, 384, 3).astype(np.float32)
115+
encoder_encoder_conv2_bias = np.random.rand(384).astype(np.float32)
116+
encoder_encoder_layers_0_self_attn_layer_norm_weight = np.random.rand(384).astype(np.float32)
117+
encoder_encoder_layers_0_self_attn_layer_norm_bias = np.random.rand(384).astype(np.float32)
118+
encoder_encoder_layers_0_self_attn_q_proj_weight = np.random.rand(384, 384).astype(np.float32)
119+
encoder_encoder_layers_0_self_attn_q_proj_bias = np.random.rand(384).astype(np.float32)
120+
encoder_encoder_layers_0_self_attn_k_proj_weight = np.random.rand(384, 384).astype(np.float32)
121+
encoder_encoder_layers_0_self_attn_v_proj_weight = np.random.rand(384, 384).astype(np.float32)
122+
encoder_encoder_layers_0_self_attn_v_proj_bias = np.random.rand(384).astype(np.float32)
123+
encoder_encoder_layers_0_self_attn_out_proj_weight = np.random.rand(384, 384).astype(np.float32)
124+
encoder_encoder_layers_0_self_attn_out_proj_bias = np.random.rand(384).astype(np.float32)
125+
encoder_encoder_layers_0_final_layer_norm_weight = np.random.rand(384).astype(np.float32)
126+
encoder_encoder_layers_0_final_layer_norm_bias = np.random.rand(384).astype(np.float32)
127+
encoder_encoder_layers_0_fc1_weight = np.random.rand(1536, 384).astype(np.float32)
128+
encoder_encoder_layers_0_fc1_bias = np.random.rand(1536).astype(np.float32)
129+
encoder_encoder_layers_0_fc2_weight = np.random.rand(384, 1536).astype(np.float32)
130+
encoder_encoder_layers_0_fc2_bias = np.random.rand(384).astype(np.float32)
131+
encoder_encoder_layer_norm_weight = np.random.rand(384).astype(np.float32)
132+
encoder_encoder_layer_norm_bias = np.random.rand(384).astype(np.float32)
133+
model = make_model(
134+
encoder_encoder_embed_positions_weight,
135+
encoder_encoder_conv1_weight,
136+
encoder_encoder_conv1_bias,
137+
encoder_encoder_conv2_weight,
138+
encoder_encoder_conv2_bias,
139+
encoder_encoder_layers_0_self_attn_layer_norm_weight,
140+
encoder_encoder_layers_0_self_attn_layer_norm_bias,
141+
encoder_encoder_layers_0_self_attn_q_proj_weight,
142+
encoder_encoder_layers_0_self_attn_q_proj_bias,
143+
encoder_encoder_layers_0_self_attn_k_proj_weight,
144+
encoder_encoder_layers_0_self_attn_v_proj_weight,
145+
encoder_encoder_layers_0_self_attn_v_proj_bias,
146+
encoder_encoder_layers_0_self_attn_out_proj_weight,
147+
encoder_encoder_layers_0_self_attn_out_proj_bias,
148+
encoder_encoder_layers_0_final_layer_norm_weight,
149+
encoder_encoder_layers_0_final_layer_norm_bias,
150+
encoder_encoder_layers_0_fc1_weight,
151+
encoder_encoder_layers_0_fc1_bias,
152+
encoder_encoder_layers_0_fc2_weight,
153+
encoder_encoder_layers_0_fc2_bias,
154+
encoder_encoder_layer_norm_weight,
155+
encoder_encoder_layer_norm_bias
156+
)
157+
return model
158+
159+
160+
class _WhisperEncoderTest:
161+
def get_onnx_model(self):
162+
if not hasattr(self, "_onnx_model"):
163+
model_proto = make_model_with_random_weights()
164+
model = ir.serde.deserialize_model(model_proto)
165+
self._onnx_model = model
166+
return self._onnx_model
167+
168+
def get_ort_inputs(self):
169+
if not hasattr(self, "_ort_inputs"):
170+
inputs = {
171+
"audio_features": np.random.rand((1, 80, 3000)).astype(np.float32),
172+
}
173+
self._ort_inputs = inputs
174+
return self._ort_inputs
175+
176+
177+
def whisper_encoder_test():
178+
return _WhisperEncoderTest()

0 commit comments

Comments
 (0)