Skip to content

Commit d58af8b

Browse files
authored
More tests about patches (#415)
* More tests about patches * remove one patch * fix * fix * verbose * improve algo * fix partition * fix exceptions * documentation
1 parent c03eb5e commit d58af8b

9 files changed

Lines changed: 590 additions & 92 deletions

File tree

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.9.2
55
+++++
66

7+
* :pr:`415`: improves function make_model_with_local_functions to support ill-defined partitions
78
* :pr:`413`: fix InputObserver in the generic case
89
* :pr:`412`: patches for ViTModel (through rewriting)
910

_unittests/ut_helpers/test_args_helper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import unittest
22
from onnx_diagnostic.ext_test_case import ExtTestCase
3-
from onnx_diagnostic.helpers.args_helper import get_parsed_args, check_cuda_availability
3+
from onnx_diagnostic.helpers.args_helper import (
4+
get_parsed_args,
5+
check_cuda_availability,
6+
process_outputname,
7+
)
48

59

610
class TestHelpers(ExtTestCase):
@@ -52,6 +56,10 @@ def test_args_expose(self):
5256
self.assertEqual(args.repeat, 10)
5357
self.assertEqual(args.warmup, 5)
5458

59+
def test_process_outputname(self):
60+
self.assertEqual("ggg.g", process_outputname("ggg.g", "hhh.h"))
61+
self.assertEqual("hhh.ggg.h", process_outputname("+.ggg", "hhh.h"))
62+
5563

5664
if __name__ == "__main__":
5765
unittest.main(verbosity=2)

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -794,11 +794,8 @@ def test_make_model_with_local_functions_bug(self):
794794
meta.key = "namespace"
795795
meta.value = "LLL"
796796
self.assertRaise(
797-
lambda: make_model_with_local_functions(model, "^LLL$"),
797+
lambda: make_model_with_local_functions(model, "^LLL$", allow_extensions=False),
798798
ValueError,
799-
"Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', "
800-
"'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] "
801-
"which is not allowed.",
802799
)
803800
check_model(model)
804801

@@ -860,6 +857,72 @@ def test_make_model_with_local_functions_2(self):
860857

861858
check_model(new_model)
862859

860+
@hide_stdout()
861+
def test_make_model_with_local_functions_3(self):
862+
model = oh.make_model(
863+
oh.make_graph(
864+
[
865+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
866+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
867+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
868+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
869+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
870+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
871+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
872+
],
873+
"dummy",
874+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
875+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
876+
[
877+
onh.from_array(
878+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
879+
),
880+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
881+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
882+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
883+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
884+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
885+
],
886+
),
887+
opset_imports=[oh.make_opsetid("", 18)],
888+
ir_version=9,
889+
)
890+
check_model(model)
891+
for i_node in range(len(model.graph.node) - 1):
892+
if i_node == 2:
893+
continue
894+
node = model.graph.node[i_node]
895+
meta = node.metadata_props.add()
896+
meta.key = f"source[{i_node}]"
897+
meta.value = "LLL"
898+
self.assertRaise(
899+
lambda: make_model_with_local_functions(
900+
model,
901+
"^LLL$",
902+
metadata_key_prefix="source[",
903+
verbose=1,
904+
allow_extensions=False,
905+
),
906+
ValueError,
907+
)
908+
new_model = make_model_with_local_functions(
909+
model, "^LLL$", metadata_key_prefix="source[", verbose=1
910+
)
911+
check_model(new_model)
912+
self.assertEqual(len(new_model.functions), 1)
913+
p = pretty_onnx(new_model)
914+
self.assertIn("LLL[local_function]", p)
915+
916+
self.assertEqual(
917+
["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input
918+
)
919+
self.assertEqual(["xm"], new_model.functions[0].output)
920+
self.assertEqual("LLL", new_model.functions[0].name)
921+
self.assertEqual("local_function", new_model.functions[0].domain)
922+
self.assertEqual(len(new_model.functions[0].node), 6)
923+
924+
check_model(new_model)
925+
863926

864927
if __name__ == "__main__":
865928
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,15 @@
1212
requires_torch,
1313
ignore_warnings,
1414
has_onnxscript,
15+
has_transformers,
1516
requires_onnxscript,
1617
)
1718
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting
1819
from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions
1920
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
2021
from onnx_diagnostic.torch_export_patches import torch_export_patches
2122
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
23+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2224
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
2325
patch_qwen2_5,
2426
patch_funnel,
@@ -392,6 +394,20 @@ def forward(self, q, k, cos, sin):
392394
rtol=1,
393395
)
394396

397+
@requires_transformers("4.55")
398+
@requires_onnxscript("0.6.2")
399+
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
400+
def test_qwen_function_proto(self):
401+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
402+
LoopAttention23,
403+
LoopMHAAttention,
404+
PackedAttention,
405+
)
406+
407+
LoopMHAAttention.to_function_proto()
408+
LoopAttention23.to_function_proto()
409+
PackedAttention.to_function_proto()
410+
395411
@requires_transformers("4.55")
396412
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
397413
def test_patched_qwen2_5_vl_rot_pos_emb(self):
@@ -874,6 +890,173 @@ def test_model_funnel(self):
874890
got = patched.relative_positional_attention(**inputs)
875891
self.assertEqualArray(expected, got)
876892

893+
def test_cache_dependant_input_preparation_exporting(self):
894+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import ( # noqa: E501
895+
patched_GenerationMixin as GenerationMixin,
896+
)
897+
898+
with self.subTest(case="case1"):
899+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
900+
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
901+
cache_position = torch.arange(0, 8, dtype=torch.int64)
902+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
903+
input_ids, inputs_embeds, cache_position
904+
)
905+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
906+
input_ids, inputs_embeds, cache_position
907+
)
908+
torch.testing.assert_close(eager1, export1)
909+
torch.testing.assert_close(eager2, export2)
910+
911+
with self.subTest(case="case2"):
912+
raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
913+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
914+
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
915+
cache_position = torch.arange(0, 8, dtype=torch.int64)
916+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
917+
input_ids, inputs_embeds, cache_position
918+
)
919+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
920+
input_ids, inputs_embeds, cache_position
921+
)
922+
torch.testing.assert_close(eager1, export1)
923+
torch.testing.assert_close(eager2, export2)
924+
925+
with self.subTest(case="case3"):
926+
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
927+
inputs_embeds = None
928+
cache_position = torch.arange(0, 8, dtype=torch.int64)
929+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
930+
input_ids, inputs_embeds, cache_position
931+
)
932+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
933+
input_ids, inputs_embeds, cache_position
934+
)
935+
torch.testing.assert_close(eager1, export1)
936+
torch.testing.assert_close(eager2, export2)
937+
938+
with self.subTest(case="case4"):
939+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
940+
inputs_embeds = None
941+
cache_position = torch.arange(0, 8, dtype=torch.int64)
942+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
943+
input_ids, inputs_embeds, cache_position
944+
)
945+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
946+
input_ids, inputs_embeds, cache_position
947+
)
948+
torch.testing.assert_close(eager1, export1)
949+
torch.testing.assert_close(eager2, export2)
950+
951+
@requires_transformers("4.57")
952+
def test_prepare_inputs_for_generation_decoder_llm(self):
953+
data = get_untrained_model_with_inputs(
954+
"hf-internal-testing/tiny-random-LlamaForCausalLM"
955+
)
956+
model = data["model"]
957+
config = model.config
958+
torch_device = "cpu"
959+
960+
with torch_export_patches(patch_transformers=True):
961+
with self.subTest(case="case1"):
962+
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))
963+
964+
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
965+
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
966+
967+
with self.subTest(case="case2"):
968+
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
969+
model_inputs = model.prepare_inputs_for_generation(
970+
input_ids, cache_position=cache_position
971+
)
972+
self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids))
973+
974+
with self.subTest(case="case3"):
975+
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
976+
model_inputs = model.prepare_inputs_for_generation(
977+
input_ids, attention_mask=attention_mask, cache_position=cache_position
978+
)
979+
self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask))
980+
self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape)
981+
982+
with self.subTest(case="case4"):
983+
self.assertFalse("use_cache" in model_inputs)
984+
model_inputs = model.prepare_inputs_for_generation(
985+
input_ids, use_cache=True, foo="bar", cache_position=cache_position
986+
)
987+
self.assertTrue(model_inputs["use_cache"] is True)
988+
self.assertTrue(model_inputs["foo"] == "bar")
989+
990+
init_input_ids = input_ids[:, :2]
991+
dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
992+
dynamic_cache = model(
993+
init_input_ids, past_key_values=dynamic_cache
994+
).past_key_values
995+
996+
with self.subTest(case="case5"):
997+
if not has_transformers("4.57"):
998+
raise unittest.SkipTest("transformers 4.57+.")
999+
with self.assertRaises((AttributeError, TypeError)):
1000+
model_inputs = model.prepare_inputs_for_generation(
1001+
input_ids, past_key_values=dynamic_cache
1002+
)
1003+
1004+
with self.subTest(case="case6"):
1005+
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
1006+
torch_device
1007+
)
1008+
cache_position = cache_position[dynamic_cache.get_seq_length() :]
1009+
model_inputs = model.prepare_inputs_for_generation(
1010+
input_ids,
1011+
past_key_values=dynamic_cache,
1012+
cache_position=cache_position,
1013+
attention_mask=attention_mask,
1014+
)
1015+
self.assertTrue("past_key_values" in model_inputs)
1016+
self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position))
1017+
self.assertTrue(
1018+
model_inputs["input_ids"].shape[-1] == 1
1019+
) # 1 = 3 fed tokens - 2 tokens in the cache
1020+
self.assertTrue(model_inputs["position_ids"].shape[-1] == 1)
1021+
self.assertTrue(
1022+
model_inputs["attention_mask"].shape[-1] == 3
1023+
) # we still need the full attention mask!
1024+
1025+
with self.subTest(case="case6.2"):
1026+
max_cache_len = 10
1027+
batch_size = 2
1028+
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
1029+
static_cache = transformers.cache_utils.StaticCache(
1030+
config=config, max_cache_len=max_cache_len
1031+
)
1032+
static_cache = model(
1033+
init_input_ids, past_key_values=static_cache
1034+
).past_key_values
1035+
model_inputs = model.prepare_inputs_for_generation(
1036+
input_ids,
1037+
past_key_values=static_cache,
1038+
cache_position=cache_position,
1039+
attention_mask=attention_mask,
1040+
)
1041+
self.assertTrue("past_key_values" in model_inputs)
1042+
self.assertTrue(
1043+
list(model_inputs["attention_mask"].shape)
1044+
== [batch_size, 1, query_length, max_cache_len]
1045+
)
1046+
1047+
with self.subTest(case="case7"):
1048+
if not has_transformers("4.57"):
1049+
raise unittest.SkipTest("transformers 4.57+.")
1050+
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
1051+
model_inputs = model.prepare_inputs_for_generation(
1052+
input_ids,
1053+
past_key_values=dynamic_cache,
1054+
inputs_embeds=init_inputs_embeds,
1055+
cache_position=cache_position,
1056+
)
1057+
self.assertTrue(model_inputs["input_ids"] is not None)
1058+
self.assertTrue(model_inputs["inputs_embeds"] is None)
1059+
8771060

8781061
if __name__ == "__main__":
8791062
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)