Skip to content

Commit f2805e3

Browse files
committed
More tests about patches
1 parent c03eb5e commit f2805e3

7 files changed

Lines changed: 319 additions & 8 deletions

File tree

_unittests/ut_helpers/test_args_helper.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import unittest
22
from onnx_diagnostic.ext_test_case import ExtTestCase
3-
from onnx_diagnostic.helpers.args_helper import get_parsed_args, check_cuda_availability
3+
from onnx_diagnostic.helpers.args_helper import (
4+
get_parsed_args,
5+
check_cuda_availability,
6+
process_outputname,
7+
)
48

59

610
class TestHelpers(ExtTestCase):
@@ -52,6 +56,10 @@ def test_args_expose(self):
5256
self.assertEqual(args.repeat, 10)
5357
self.assertEqual(args.warmup, 5)
5458

59+
def test_process_outputname(self):
60+
self.assertEqual("ggg.g", process_outputname("ggg.g", "hhh.h"))
61+
self.assertEqual("hhh.ggg.h", process_outputname("+.ggg", "hhh.h"))
62+
5563

5664
if __name__ == "__main__":
5765
unittest.main(verbosity=2)

_unittests/ut_helpers/test_onnx_helper.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,66 @@ def test_make_model_with_local_functions_2(self):
860860

861861
check_model(new_model)
862862

863+
@hide_stdout()
864+
def test_make_model_with_local_functions_3(self):
865+
model = oh.make_model(
866+
oh.make_graph(
867+
[
868+
oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
869+
oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
870+
oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
871+
oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
872+
oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
873+
oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
874+
oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
875+
],
876+
"dummy",
877+
[oh.make_tensor_value_info("X", TFLOAT, [320, 1280])],
878+
[oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])],
879+
[
880+
onh.from_array(
881+
np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y"
882+
),
883+
onh.from_array(np.array([0], dtype=np.int64), name="zero"),
884+
onh.from_array(np.array([1], dtype=np.int64), name="un"),
885+
onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"),
886+
onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"),
887+
onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"),
888+
],
889+
),
890+
opset_imports=[oh.make_opsetid("", 18)],
891+
ir_version=9,
892+
)
893+
for i_node in range(len(model.graph.node) - 1):
894+
if i_node == 2:
895+
continue
896+
node = model.graph.node[i_node]
897+
meta = node.metadata_props.add()
898+
meta.key = f"source[{i_node}]"
899+
meta.value = "LLL"
900+
new_model = make_model_with_local_functions(
901+
model, "^LLL$", metadata_key_prefix="source[", verbose=1
902+
)
903+
check_model(model)
904+
self.assertEqual(len(new_model.functions), 1)
905+
p = pretty_onnx(new_model)
906+
self.assertIn("LLL0[local_function]", p)
907+
self.assertIn("LLL1[local_function]", p)
908+
909+
self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input)
910+
self.assertEqual(["xm1"], new_model.functions[0].output)
911+
self.assertEqual("LLL0", new_model.functions[0].name)
912+
self.assertEqual("local_function", new_model.functions[0].domain)
913+
self.assertEqual(len(new_model.functions[0].node), 3)
914+
915+
self.assertEqual(["Y", "shape2"], new_model.functions[1].input)
916+
self.assertEqual(["xm2c"], new_model.functions[1].output)
917+
self.assertEqual("LLL1", new_model.functions[1].name)
918+
self.assertEqual("local_function", new_model.functions[1].domain)
919+
self.assertEqual(len(new_model.functions[1].node), 1)
920+
921+
check_model(new_model)
922+
863923

864924
if __name__ == "__main__":
865925
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration
2020
from onnx_diagnostic.torch_export_patches import torch_export_patches
2121
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
22+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2223
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
2324
patch_qwen2_5,
2425
patch_funnel,
@@ -392,6 +393,20 @@ def forward(self, q, k, cos, sin):
392393
rtol=1,
393394
)
394395

396+
@requires_transformers("4.55")
397+
@requires_onnxscript("0.6.2")
398+
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
399+
def test_qwen_function_proto(self):
400+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
401+
LoopAttention23,
402+
LoopMHAAttention,
403+
PackedAttention,
404+
)
405+
406+
LoopMHAAttention.to_function_proto()
407+
LoopAttention23.to_function_proto()
408+
PackedAttention.to_function_proto()
409+
395410
@requires_transformers("4.55")
396411
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
397412
def test_patched_qwen2_5_vl_rot_pos_emb(self):
@@ -874,6 +889,166 @@ def test_model_funnel(self):
874889
got = patched.relative_positional_attention(**inputs)
875890
self.assertEqualArray(expected, got)
876891

892+
def test_cache_dependant_input_preparation_exporting(self):
893+
from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import ( # noqa: E501
894+
patched_GenerationMixin as GenerationMixin,
895+
)
896+
897+
with self.subTest(case="case1"):
898+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)[:, :0]
899+
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
900+
cache_position = torch.arange(0, 8, dtype=torch.int64)
901+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
902+
input_ids, inputs_embeds, cache_position
903+
)
904+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
905+
input_ids, inputs_embeds, cache_position
906+
)
907+
torch.testing.assert_close(eager1, export1)
908+
torch.testing.assert_close(eager2, export2)
909+
910+
with self.subTest(case="case2"):
911+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
912+
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
913+
cache_position = torch.arange(0, 8, dtype=torch.int64)
914+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
915+
input_ids, inputs_embeds, cache_position
916+
)
917+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
918+
input_ids, inputs_embeds, cache_position
919+
)
920+
torch.testing.assert_close(eager1, export1)
921+
torch.testing.assert_close(eager2, export2)
922+
923+
with self.subTest(case="case3"):
924+
input_ids = torch.randint(0, 16, (2, 12), dtype=torch.int64)
925+
inputs_embeds = None
926+
cache_position = torch.arange(0, 8, dtype=torch.int64)
927+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
928+
input_ids, inputs_embeds, cache_position
929+
)
930+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
931+
input_ids, inputs_embeds, cache_position
932+
)
933+
torch.testing.assert_close(eager1, export1)
934+
torch.testing.assert_close(eager2, export2)
935+
936+
with self.subTest(case="case4"):
937+
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
938+
inputs_embeds = None
939+
cache_position = torch.arange(0, 8, dtype=torch.int64)
940+
eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
941+
input_ids, inputs_embeds, cache_position
942+
)
943+
export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
944+
input_ids, inputs_embeds, cache_position
945+
)
946+
torch.testing.assert_close(eager1, export1)
947+
torch.testing.assert_close(eager2, export2)
948+
949+
def test_prepare_inputs_for_generation_decoder_llm(self):
950+
data = get_untrained_model_with_inputs(
951+
"hf-internal-testing/tiny-random-LlamaForCausalLM"
952+
)
953+
model = data["model"]
954+
config = model.config
955+
torch_device = "cpu"
956+
957+
with torch_export_patches(patch_transformers=True):
958+
with self.subTest(case="case1"):
959+
self.assertTrue("GenerationMixin" in str(model.prepare_inputs_for_generation))
960+
961+
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
962+
cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)
963+
964+
with self.subTest(case="case2"):
965+
input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
966+
model_inputs = model.prepare_inputs_for_generation(
967+
input_ids, cache_position=cache_position
968+
)
969+
self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids))
970+
971+
with self.subTest(case="case3"):
972+
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
973+
model_inputs = model.prepare_inputs_for_generation(
974+
input_ids, attention_mask=attention_mask, cache_position=cache_position
975+
)
976+
self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask))
977+
self.assertTrue(model_inputs["position_ids"].shape == input_ids.shape)
978+
979+
with self.subTest(case="case4"):
980+
self.assertFalse("use_cache" in model_inputs)
981+
model_inputs = model.prepare_inputs_for_generation(
982+
input_ids, use_cache=True, foo="bar", cache_position=cache_position
983+
)
984+
self.assertTrue(model_inputs["use_cache"] is True)
985+
self.assertTrue(model_inputs["foo"] == "bar")
986+
987+
with self.subTest(case="case5"):
988+
init_input_ids = input_ids[:, :2]
989+
dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
990+
dynamic_cache = model(
991+
init_input_ids, past_key_values=dynamic_cache
992+
).past_key_values
993+
with self.assertRaises((AttributeError, TypeError)):
994+
model_inputs = model.prepare_inputs_for_generation(
995+
input_ids, past_key_values=dynamic_cache
996+
)
997+
998+
with self.subTest(case="case6"):
999+
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
1000+
torch_device
1001+
)
1002+
cache_position = cache_position[dynamic_cache.get_seq_length() :]
1003+
model_inputs = model.prepare_inputs_for_generation(
1004+
input_ids,
1005+
past_key_values=dynamic_cache,
1006+
cache_position=cache_position,
1007+
attention_mask=attention_mask,
1008+
)
1009+
self.assertTrue("past_key_values" in model_inputs)
1010+
self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position))
1011+
self.assertTrue(
1012+
model_inputs["input_ids"].shape[-1] == 1
1013+
) # 1 = 3 fed tokens - 2 tokens in the cache
1014+
self.assertTrue(model_inputs["position_ids"].shape[-1] == 1)
1015+
self.assertTrue(
1016+
model_inputs["attention_mask"].shape[-1] == 3
1017+
) # we still need the full attention mask!
1018+
1019+
with self.subTest(case="case6.2"):
1020+
max_cache_len = 10
1021+
batch_size = 2
1022+
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
1023+
static_cache = transformers.cache_utils.StaticCache(
1024+
config=config, max_cache_len=max_cache_len
1025+
)
1026+
static_cache = model(
1027+
init_input_ids, past_key_values=static_cache
1028+
).past_key_values
1029+
model_inputs = model.prepare_inputs_for_generation(
1030+
input_ids,
1031+
past_key_values=static_cache,
1032+
cache_position=cache_position,
1033+
attention_mask=attention_mask,
1034+
)
1035+
self.assertTrue("past_key_values" in model_inputs)
1036+
self.assertTrue(
1037+
list(model_inputs["attention_mask"].shape)
1038+
== [batch_size, 1, query_length, max_cache_len]
1039+
)
1040+
1041+
with self.subTest(case="case7"):
1042+
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
1043+
model_inputs = model.prepare_inputs_for_generation(
1044+
input_ids,
1045+
past_key_values=dynamic_cache,
1046+
inputs_embeds=init_inputs_embeds,
1047+
cache_position=cache_position,
1048+
)
1049+
self.assertTrue(model_inputs["input_ids"] is not None)
1050+
self.assertTrue(model_inputs["inputs_embeds"] is None)
1051+
8771052

8781053
if __name__ == "__main__":
8791054
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def get_parser_dot() -> ArgumentParser:
4747

4848
def _cmd_dot(argv: List[Any]):
4949
import subprocess
50+
from .helpers.args_helper import process_outputname
5051
from .helpers.dot_helper import to_dot
5152

5253
parser = get_parser_dot()
@@ -58,15 +59,17 @@ def _cmd_dot(argv: List[Any]):
5859
print("-- converts into dot")
5960
dot = to_dot(onx)
6061
if args.output:
62+
outname = process_outputname(args.output, args.input)
6163
if args.verbose:
62-
print(f"-- saves into {args.output}")
63-
with open(args.output, "w") as f:
64+
print(f"-- saves into {outname!r}")
65+
with open(outname, "w") as f:
6466
f.write(dot)
6567
else:
6668
print(dot)
6769
if args.run:
6870
assert args.output, "Cannot run dot without an output file."
69-
cmds = ["dot", f"-T{args.run}", args.output, "-o", f"{args.output}.{args.run}"]
71+
outname = process_outputname(outname, args.input)
72+
cmds = ["dot", f"-T{args.run}", outname, "-o", f"{args.output}.{args.run}"]
7073
if args.verbose:
7174
print(f"-- run {' '.join(cmds)}")
7275
p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
@@ -1553,10 +1556,11 @@ def _cmd_optimize(argv: List[Any]):
15531556
parser = get_parser_optimize()
15541557
args = parser.parse_args(argv[1:])
15551558

1559+
from .helpers.args_helper import process_outputname
15561560
from .helpers.optim_helper import optimize_model
15571561

15581562
output = (
1559-
args.output
1563+
process_outputname(args.output, args.input)
15601564
if args.output
15611565
else f"{os.path.splitext(args.input)[0]}.o-{args.algorithm}.onnx"
15621566
)
@@ -1586,10 +1590,21 @@ def get_parser_partition() -> ArgumentParser:
15861590
The regular may match the following values,
15871591
'model.layers.0.forward', 'model.layers.1.forward', ...
15881592
A local function will be created for each distinct layer.
1593+
1594+
Example:
1595+
1596+
python -m onnx_diagnostic partition \\
1597+
model.onnx +.part -v 1 -r "model.layers.0.s.*"
15891598
"""),
15901599
)
15911600
parser.add_argument("input", help="input model")
1592-
parser.add_argument("output", help="output model")
1601+
parser.add_argument(
1602+
"output",
1603+
help=textwrap.dedent("""
1604+
output model, an expression like '+.part'
1605+
inserts '.part' just before the extension"
1606+
""").strip("\n"),
1607+
)
15931608
parser.add_argument(
15941609
"-r",
15951610
"--regex",
@@ -1619,6 +1634,7 @@ def get_parser_partition() -> ArgumentParser:
16191634

16201635

16211636
def _cmd_partition(argv: List[Any]):
1637+
from .helpers.args_helper import process_outputname
16221638
from .helpers.onnx_helper import make_model_with_local_functions
16231639

16241640
parser = get_parser_partition()
@@ -1635,9 +1651,10 @@ def _cmd_partition(argv: List[Any]):
16351651
metadata_key_prefix=tuple(args.meta_prefix.split(",")),
16361652
verbose=args.verbose,
16371653
)
1654+
outname = process_outputname(args.output, args.input)
16381655
if args.verbose:
1639-
print(f"-- save into {args.output!r}")
1640-
onnx.save(onx2, args.output)
1656+
print(f"-- save into {outname!r}")
1657+
onnx.save(onx2, outname)
16411658
if args.verbose:
16421659
print("-- done")
16431660

onnx_diagnostic/helpers/args_helper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
import subprocess
23
from argparse import ArgumentParser, Namespace
34
from typing import Dict, List, Optional, Tuple, Union
@@ -131,3 +132,14 @@ def get_parsed_args(
131132
if update:
132133
res.__dict__.update(update)
133134
return res
135+
136+
137+
def process_outputname(output_name: str, input_name: str) -> str:
138+
"""
139+
If 'output_name' starts with '+', then it is modified into
140+
``<input_name_no_extension><output_name>.extension``.
141+
"""
142+
if not output_name.startswith("+"):
143+
return output_name
144+
name, ext = os.path.splitext(input_name)
145+
return f"{name}{output_name[1:]}{ext}"

0 commit comments

Comments
 (0)