|
12 | 12 | requires_torch, |
13 | 13 | ignore_warnings, |
14 | 14 | has_onnxscript, |
| 15 | + has_transformers, |
15 | 16 | requires_onnxscript, |
16 | 17 | ) |
17 | 18 | from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, fake_torchdynamo_exporting |
18 | 19 | from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions |
19 | 20 | from onnx_diagnostic.torch_models.hghub.hub_api import get_cached_configuration |
20 | 21 | from onnx_diagnostic.torch_export_patches import torch_export_patches |
21 | 22 | from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str |
| 23 | +from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs |
22 | 24 | from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( |
23 | 25 | patch_qwen2_5, |
24 | 26 | patch_funnel, |
@@ -392,6 +394,20 @@ def forward(self, q, k, cos, sin): |
392 | 394 | rtol=1, |
393 | 395 | ) |
394 | 396 |
|
@requires_transformers("4.55")
@requires_onnxscript("0.6.2")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
def test_qwen_function_proto(self):
    """
    Calls ``to_function_proto()`` on each Qwen2.5 attention variant
    (``LoopMHAAttention``, ``LoopAttention23``, ``PackedAttention``).
    The test passes as long as none of the calls raises.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
        LoopAttention23,
        LoopMHAAttention,
        PackedAttention,
    )

    # The returned values are not inspected, only the absence of an exception matters.
    for attention_variant in (LoopMHAAttention, LoopAttention23, PackedAttention):
        attention_variant.to_function_proto()
| 410 | + |
395 | 411 | @requires_transformers("4.55") |
396 | 412 | @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers") |
397 | 413 | def test_patched_qwen2_5_vl_rot_pos_emb(self): |
@@ -874,6 +890,173 @@ def test_model_funnel(self): |
874 | 890 | got = patched.relative_positional_attention(**inputs) |
875 | 891 | self.assertEqualArray(expected, got) |
876 | 892 |
|
def test_cache_dependant_input_preparation_exporting(self):
    """
    Checks that ``_cache_dependant_input_preparation_exporting`` returns the
    same pair of outputs as ``_cache_dependant_input_preparation`` when both
    receive the same ``input_ids``, ``inputs_embeds`` and ``cache_position``.

    Every scenario runs in its own ``subTest``. ``case2`` is skipped before
    its inputs are built, its definition is kept so that it can be enabled
    again by removing the skip reason.
    """
    from onnx_diagnostic.torch_export_patches.patches._patch_transformers_generation_mixin import (  # noqa: E501
        patched_GenerationMixin as GenerationMixin,
    )

    def random_ids(length):
        # random token ids in [0, 16) for a batch of 2 sequences
        return torch.randint(0, 16, (2, length), dtype=torch.int64)

    def random_embeds():
        return torch.rand((2, 8), dtype=torch.float32)

    def no_embeds():
        return None

    # (subtest name, input_ids factory, inputs_embeds factory, skip reason or None)
    # Factories keep the input creation lazy: a skipped scenario does not
    # consume any random number, exactly as if its code was never reached.
    scenarios = [
        # empty input_ids (all columns sliced away) combined with embeddings
        ("case1", lambda: random_ids(8)[:, :0], random_embeds, None),
        (
            "case2",
            lambda: random_ids(8),
            random_embeds,
            "torch 2.10+ has probably a bug here.",
        ),
        # input_ids longer than cache_position, no embeddings
        ("case3", lambda: random_ids(12), no_embeds, None),
        # input_ids as long as cache_position, no embeddings
        ("case4", lambda: random_ids(8), no_embeds, None),
    ]

    for name, make_input_ids, make_inputs_embeds, skip_reason in scenarios:
        with self.subTest(case=name):
            if skip_reason is not None:
                raise unittest.SkipTest(skip_reason)
            input_ids = make_input_ids()
            inputs_embeds = make_inputs_embeds()
            cache_position = torch.arange(0, 8, dtype=torch.int64)
            eager1, eager2 = GenerationMixin()._cache_dependant_input_preparation(
                input_ids, inputs_embeds, cache_position
            )
            export1, export2 = GenerationMixin()._cache_dependant_input_preparation_exporting(
                input_ids, inputs_embeds, cache_position
            )
            torch.testing.assert_close(eager1, export1)
            torch.testing.assert_close(eager2, export2)
| 950 | + |
@requires_transformers("4.57")
def test_prepare_inputs_for_generation_decoder_llm(self):
    """
    Exercises ``prepare_inputs_for_generation`` of an untrained
    ``hf-internal-testing/tiny-random-LlamaForCausalLM`` model while
    ``torch_export_patches(patch_transformers=True)`` is active.

    The scenarios are chained: later subtests reuse ``input_ids``,
    ``attention_mask``, ``cache_position``, ``model_inputs`` and the caches
    produced by earlier ones, so their order matters.
    """
    data = get_untrained_model_with_inputs(
        "hf-internal-testing/tiny-random-LlamaForCausalLM"
    )
    model = data["model"]
    config = model.config
    torch_device = "cpu"

    with torch_export_patches(patch_transformers=True):
        with self.subTest(case="case1"):
            # the method bound to the model mentions GenerationMixin in its repr
            self.assertIn("GenerationMixin", str(model.prepare_inputs_for_generation))

        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
        cache_position = torch.arange(input_ids.shape[1], device=input_ids.device)

        with self.subTest(case="case2"):
            # only cache_position is given: input_ids come back unchanged
            input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]).to(torch_device)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, cache_position=cache_position
            )
            self.assertTrue(torch.all(model_inputs["input_ids"] == input_ids))

        with self.subTest(case="case3"):
            # the attention mask is forwarded and position_ids match input_ids in shape
            attention_mask = torch.tensor([[1, 1, 1], [1, 1, 1]]).to(torch_device)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, attention_mask=attention_mask, cache_position=cache_position
            )
            self.assertTrue(torch.all(model_inputs["attention_mask"] == attention_mask))
            self.assertEqual(model_inputs["position_ids"].shape, input_ids.shape)

        with self.subTest(case="case4"):
            # model_inputs still holds the result of case3 (no use_cache requested)
            self.assertNotIn("use_cache", model_inputs)
            # use_cache and unknown keyword arguments are forwarded as is
            model_inputs = model.prepare_inputs_for_generation(
                input_ids, use_cache=True, foo="bar", cache_position=cache_position
            )
            self.assertIs(model_inputs["use_cache"], True)
            self.assertEqual(model_inputs["foo"], "bar")

        # fills a dynamic cache with the first two tokens of every sequence
        init_input_ids = input_ids[:, :2]
        dynamic_cache = transformers.cache_utils.DynamicCache(config=config)
        dynamic_cache = model(
            init_input_ids, past_key_values=dynamic_cache
        ).past_key_values

        with self.subTest(case="case5"):
            if not has_transformers("4.57"):
                raise unittest.SkipTest("transformers 4.57+.")
            # a cache without cache_position is expected to fail
            with self.assertRaises((AttributeError, TypeError)):
                model.prepare_inputs_for_generation(
                    input_ids, past_key_values=dynamic_cache
                )

        with self.subTest(case="case6"):
            cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
                torch_device
            )
            # keeps only the positions which are not in the cache yet
            cache_position = cache_position[dynamic_cache.get_seq_length() :]
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=dynamic_cache,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
            self.assertIn("past_key_values", model_inputs)
            self.assertTrue(torch.all(model_inputs["cache_position"] == cache_position))
            # 1 = 3 fed tokens - 2 tokens in the cache
            self.assertEqual(model_inputs["input_ids"].shape[-1], 1)
            self.assertEqual(model_inputs["position_ids"].shape[-1], 1)
            # we still need the full attention mask!
            self.assertEqual(model_inputs["attention_mask"].shape[-1], 3)

        with self.subTest(case="case6.2"):
            # same scenario with a static cache, cache_position comes from case6
            max_cache_len = 10
            batch_size = 2
            query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
            static_cache = transformers.cache_utils.StaticCache(
                config=config, max_cache_len=max_cache_len
            )
            static_cache = model(
                init_input_ids, past_key_values=static_cache
            ).past_key_values
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=static_cache,
                cache_position=cache_position,
                attention_mask=attention_mask,
            )
            self.assertIn("past_key_values", model_inputs)
            # the returned mask is 4D: batch, 1, query length, cache length
            self.assertEqual(
                list(model_inputs["attention_mask"].shape),
                [batch_size, 1, query_length, max_cache_len],
            )

        with self.subTest(case="case7"):
            if not has_transformers("4.57"):
                raise unittest.SkipTest("transformers 4.57+.")
            # embeddings of the tokens already stored in the dynamic cache
            init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
            model_inputs = model.prepare_inputs_for_generation(
                input_ids,
                past_key_values=dynamic_cache,
                inputs_embeds=init_inputs_embeds,
                cache_position=cache_position,
            )
            # input_ids are expected to be used, inputs_embeds to be dropped
            self.assertIsNotNone(model_inputs["input_ids"])
            self.assertIsNone(model_inputs["inputs_embeds"])
| 1059 | + |
877 | 1060 |
|
# Runs every test of this module with verbose output when the file is executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
0 commit comments