-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_tasks.py
More file actions
262 lines (248 loc) · 11.8 KB
/
test_tasks.py
File metadata and controls
262 lines (248 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import os
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
class TestTasks(ExtTestCase):
    """End-to-end checks of ``get_untrained_model_with_inputs`` for several tasks.

    Each test builds an untrained model for a model id, checks the task and
    model size reported alongside it, runs the model eagerly on the generated
    inputs and exports it with ``torch.export.export`` while
    ``torch_export_patches(patch_transformers=True)`` is active.
    """

    # ------------------------------------------------------------------
    # shared helpers (not collected as tests: names do not start with "test")
    # ------------------------------------------------------------------

    def _assert_text_generation_inputs(self, data):
        """Check that every input set used by the text-generation tests is present."""
        for key in ("inputs", "inputs2", "inputs_batch1", "inputs_empty_cache"):
            self.assertIn(key, data)

    def _run_and_export(self, data):
        """Run the model on ``inputs`` and ``inputs2``, then export it on ``inputs``.

        The inputs are deliberately NOT copied: whatever the eager calls
        change in place (possibly a cache object — unverified) is what the
        export sees, exactly as in the original per-test code.

        :param data: dictionary returned by ``get_untrained_model_with_inputs``;
            reads ``model``, ``inputs``, ``inputs2`` and ``dynamic_shapes``
        """
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**inputs)
        model(**data["inputs2"])
        with torch_export_patches(patch_transformers=True, verbose=10):
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )

    def _check_export_matches_eager(self, data, variant_key):
        """Run the model on one input variant, then check export matches eager on ``inputs``.

        NOTE(review): only the eager model is run on ``data[variant_key]``;
        the export and the output comparison both use ``data["inputs"]``.

        :param data: dictionary returned by ``get_untrained_model_with_inputs``
        :param variant_key: key of the extra input set, which must be present
        """
        model, inputs = data["model"], data["inputs"]
        self.assertIn(variant_key, data)
        # Deep copies keep ``data`` untouched by anything the model mutates.
        model(**torch_deepcopy(data[variant_key]))
        expected = model(**torch_deepcopy(inputs))
        self.assertEqual(
            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
        )
        with torch_export_patches(patch_transformers=True, verbose=1):
            ep = torch.export.export(
                model,
                (),
                kwargs=torch_deepcopy(inputs),
                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
            )
        # The exported program is run after the patches have been removed.
        got = ep.module()(**torch_deepcopy(inputs))
        self.assertEqualArrayAny(expected, got)

    def _assert_whisper_dynamic_shapes(self, ds):
        """Check the dynamic shapes generated for ``openai/whisper-tiny``."""
        self.maxDiff = None
        self.assertIn("{0:DYN(batch),1:DYN(seq_length)}", self.string_type(ds))
        self.assertEqualAny(
            {
                "decoder_input_ids": {0: "batch", 1: "seq_length"},
                "cache_position": {0: "seq_length"},
                "encoder_outputs": [{0: "batch"}],
                "past_key_values": [
                    [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}],
                    [{0: "batch"}, {0: "batch"}, {0: "batch"}, {0: "batch"}],
                ],
            },
            ds,
        )

    def _check_cache_and_export(self, model, inputs, ds, expected_flat, run_model=False):
        """Inside the export patches, check the flattened cache, then export the model.

        :param model: model to export
        :param inputs: keyword inputs, must contain ``past_key_values``
        :param ds: dynamic shapes (strings are converted by ``use_dyn_not_str``)
        :param expected_flat: expected ``string_type`` of the flattened
            ``past_key_values``
        :param run_model: when True, run the model eagerly on ``inputs`` first,
            inside the patches
        """
        with torch_export_patches(patch_transformers=True, verbose=10):
            if run_model:
                model(**inputs)
            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
            self.assertIsInstance(flat, list)
            self.assertIsInstance(flat[0], torch.Tensor)
            self.assertEqual(expected_flat, self.string_type(flat))
            torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
            )

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------

    def test_unittest_going(self):
        """Fail unless the environment variable ``UNITTEST_GOING`` is set to ``1``."""
        # self.assertEqual instead of a bare assert, which ``python -O`` strips.
        self.assertEqual(
            "1",
            os.environ.get("UNITTEST_GOING", "0"),
            msg="UNITTEST_GOING=1 must be defined for these tests",
        )

    @hide_stdout()
    def test_text2text_generation(self):
        """Check task/size for ``sshleifer/tiny-marian-en-de`` (export currently skipped)."""
        mid = "sshleifer/tiny-marian-en-de"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text2text-generation")
        self.assertIn((data["size"], data["n_weights"]), [(473928, 118482)])
        self.skipTest(f"not working for {mid!r}")
        # Not reached while the skip above is in place; kept so the model is
        # run and exported again as soon as the skip is removed.
        self._run_and_export(data)

    @hide_stdout()
    def test_text_generation(self):
        """Run and export ``arnir0/Tiny-LLM`` for text generation."""
        mid = "arnir0/Tiny-LLM"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-generation")
        self._assert_text_generation_inputs(data)
        self.assertIn((data["size"], data["n_weights"]), [(51955968, 12988992)])
        self._run_and_export(data)

    @hide_stdout()
    def test_submodule(self):
        """Run and export the ``::model`` part of ``arnir0/Tiny-LLM``.

        The ``::model`` suffix presumably selects the ``model`` submodule of the
        full model (it reports a smaller size) — confirm in
        ``get_untrained_model_with_inputs``.
        """
        mid = "arnir0/Tiny-LLM::model"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-generation")
        self._assert_text_generation_inputs(data)
        self.assertIn((data["size"], data["n_weights"]), [(27379968, 6844992)])
        self._run_and_export(data)

    @hide_stdout()
    def test_text_generation_empty_cache(self):
        """Run Tiny-LLM on ``inputs_empty_cache``; export must match eager on ``inputs``."""
        mid = "arnir0/Tiny-LLM"
        data = get_untrained_model_with_inputs(mid, add_second_input=True)
        self._check_export_matches_eager(data, "inputs_empty_cache")

    @hide_stdout()
    def test_text_generation_batch1(self):
        """Run Tiny-LLM on ``inputs_batch1``; export must match eager on ``inputs``."""
        mid = "arnir0/Tiny-LLM"
        data = get_untrained_model_with_inputs(mid, add_second_input=True)
        self._check_export_matches_eager(data, "inputs_batch1")

    @hide_stdout()
    def test_automatic_speech_recognition_float32(self):
        """Run and export ``openai/whisper-tiny`` in float32, twice in a row."""
        mid = "openai/whisper-tiny"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "automatic-speech-recognition")
        self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model(**inputs)
        model(**data["inputs2"])
        self._assert_whisper_dynamic_shapes(ds)
        # Second eager call on the same uncopied inputs; kept because it may
        # change them in place before the export below (unverified).
        model(**inputs)
        self.assertEqual(
            "#1[T1r3]",
            self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
        )
        expected_flat = "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]"
        # Exported twice, each time under a fresh ``torch_export_patches``;
        # presumably checks the patches can be re-applied after removal — confirm.
        for _ in range(2):
            self._check_cache_and_export(model, inputs, ds, expected_flat)

    @hide_stdout()
    def test_automatic_speech_recognition_float16(self):
        """Convert ``openai/whisper-tiny`` and its inputs to float16, run and export it."""
        mid = "openai/whisper-tiny"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "automatic-speech-recognition")
        self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
        # The float16 conversion must keep encoder_outputs a BaseModelOutput.
        self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
        data["inputs"] = to_any(data["inputs"], torch.float16)
        self.assertIn("encoder_outputs:BaseModelOutput", self.string_type(data["inputs"]))
        data["inputs2"] = to_any(data["inputs2"], torch.float16)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        model = to_any(model, torch.float16)
        model(**data["inputs2"])
        self._assert_whisper_dynamic_shapes(ds)
        self.assertEqual(
            "#1[T10r3]",
            self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
        )
        expected_flat = "#8[T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4,T10r4]"
        # The first pass also runs the model eagerly on ``inputs`` inside the patches.
        self._check_cache_and_export(model, inputs, ds, expected_flat, run_model=True)
        self._check_cache_and_export(model, inputs, ds, expected_flat)

    @hide_stdout()
    def test_fill_mask(self):
        """Run and export ``google-bert/bert-base-multilingual-cased`` for fill-mask."""
        mid = "google-bert/bert-base-multilingual-cased"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "fill-mask")
        self.assertIn((data["size"], data["n_weights"]), [(428383212, 107095803)])
        self._run_and_export(data)

    @hide_stdout()
    def test_text_classification(self):
        """Run and export ``Intel/bert-base-uncased-mrpc`` for text classification."""
        mid = "Intel/bert-base-uncased-mrpc"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "text-classification")
        self.assertIn((data["size"], data["n_weights"]), [(154420232, 38605058)])
        self._run_and_export(data)

    @hide_stdout()
    def test_sentence_similary(self):
        """Run and export ``sentence-transformers/all-MiniLM-L6-v1`` for sentence similarity."""
        # The test name keeps its original spelling so test ids stay stable.
        mid = "sentence-transformers/all-MiniLM-L6-v1"
        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
        self.assertEqual(data["task"], "sentence-similarity")
        self.assertIn((data["size"], data["n_weights"]), [(62461440, 15615360)])
        self._run_and_export(data)
if __name__ == "__main__":
    # Run this module's tests directly, printing one line per test.
    unittest.main(verbosity=2)