Skip to content

Commit de5958e

Browse files
committed
improves serialization
1 parent 8989eb7 commit de5958e

5 files changed

Lines changed: 38 additions & 14 deletions

File tree

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def test_image_text_to_text_tiny_gemma3(self):
6161
def test_image_text_to_text_gemma3_4b_it(self):
6262
make_hybrid_cache = get_make_hybrid_cache()
6363
if make_hybrid_cache is None:
64-
raise unittest.SkipTest("not implemented yet for transformers>=5")
64+
raise unittest.SkipTest(
65+
"not implemented yet for transformers>=5 (make_hybrid_cache is None)"
66+
)
6567
mid = "google/gemma-3-4b-it"
6668
data = get_untrained_model_with_inputs(
6769
mid,

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def forward(
586586
for exporter in ("custom", "onnx-dynamo"):
587587
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
588588
if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"):
589-
raise unittest.SkipTest("needs onnxscript>=0.5.7")
589+
self.skipTest("needs onnxscript>=0.5.7")
590590
filename = self.get_dump_file(
591591
f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx"
592592
)
@@ -640,7 +640,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
640640
)
641641
for exporter in ("custom", "onnx-dynamo"):
642642
if exporter == "onnx-dynamo" and aten_sym_storage_offset is None:
643-
raise unittest.SkipTest("update onnxscript to make this test run")
643+
self.skipTest("update onnxscript to make this test run")
644644
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
645645
filename = self.get_dump_file(
646646
f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx"
@@ -909,7 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self):
909909
torch.testing.assert_close(eager2, export2)
910910

911911
with self.subTest(case="case2"):
912-
raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
912+
self.skipTest("torch 2.10+ has probably a bug here.")
913913
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
914914
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
915915
cache_position = torch.arange(0, 8, dtype=torch.int64)
@@ -995,15 +995,17 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
995995

996996
with self.subTest(case="case5"):
997997
if not has_transformers("4.57"):
998-
raise unittest.SkipTest("transformers 4.57+.")
998+
self.skipTest("This test only works with transformers>=4.57, <5.3.")
999999
if has_transformers("5.2.99"):
1000-
raise unittest.SkipTest("transformers 5.2+.")
1000+
self.skipTest("This test is no longer valid with transformers>=5.3.")
10011001
with self.assertRaises((AttributeError, TypeError)):
10021002
model_inputs = model.prepare_inputs_for_generation(
10031003
input_ids, past_key_values=dynamic_cache
10041004
)
10051005

10061006
with self.subTest(case="case6"):
1007+
if has_transformers("5.2.99"):
1008+
self.skipTest("This test is no longer valid with transformers>=5.3.")
10071009
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
10081010
torch_device
10091011
)
@@ -1025,6 +1027,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
10251027
) # we still need the full attention mask!
10261028

10271029
with self.subTest(case="case6.2"):
1030+
if has_transformers("5.2.99"):
1031+
self.skipTest("This test is no longer valid with transformers>=5.3.")
10281032
max_cache_len = 10
10291033
batch_size = 2
10301034
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
@@ -1048,7 +1052,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
10481052

10491053
with self.subTest(case="case7"):
10501054
if not has_transformers("4.57"):
1051-
raise unittest.SkipTest("transformers 4.57+.")
1055+
self.skipTest("This test only works with transformers>=4.57.")
10521056
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
10531057
model_inputs = model.prepare_inputs_for_generation(
10541058
input_ids,

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -742,16 +742,22 @@ def make_hybrid_cache(
742742
not max_batch_size and not max_cache_len
743743
), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
744744
max_batch_size = key_value_pairs[0][0].shape[0]
745+
assert max_cache_len is not None or all(
746+
isinstance(kv[0].shape[2], int) for kv in key_value_pairs
747+
), (
748+
f"Cannot determine max_cache_len with "
749+
f"shapes={[kv[0].shape for kv in key_value_pairs]}"
750+
)
745751
sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
746752
if len(sets_of_dim) == 1:
747-
max_cache_len = sets_of_dim.pop()
748-
sliding_window = max_cache_len
753+
if max_cache_len is None:
754+
max_cache_len = sets_of_dim.pop()
749755
else:
750756
assert (
751757
len(sets_of_dim) == 2
752758
), f"Not implemented for more than 2 dimensions {sets_of_dim}"
753-
max_cache_len = max(sets_of_dim)
754-
sliding_window = min(sets_of_dim)
759+
if max_cache_len is None:
760+
max_cache_len = max(sets_of_dim)
755761
layer_types = [
756762
"full_attention" if i == max_cache_len else "sliding_attention"
757763
for i in [kv[0].shape[2] for kv in key_value_pairs]
@@ -760,8 +766,8 @@ def make_hybrid_cache(
760766
assert (
761767
max_batch_size and max_cache_len
762768
), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
763-
if sliding_window is None:
764-
sliding_window = max_cache_len
769+
if sliding_window is None:
770+
sliding_window = max_cache_len
765771
_max_cache_len = max_cache_len
766772
_sliding_window = sliding_window
767773

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def patched_sdpa_attention_forward(
139139
if is_causal is None and attention_mask is not None:
140140
is_causal = False
141141
if is_causal is not None:
142+
torch._check(query.shape[0] > 0)
143+
torch._check(query.shape[1] > 0)
144+
torch._check(query.shape[2] > 0)
145+
torch._check(query.shape[3] > 0)
146+
torch._check(key.shape[0] > 0)
147+
torch._check(key.shape[1] > 0)
148+
torch._check(key.shape[2] > 0)
149+
torch._check(key.shape[3] > 0)
150+
torch._check(value.shape[0] > 0)
151+
torch._check(value.shape[1] > 0)
152+
torch._check(value.shape[2] > 0)
153+
torch._check(value.shape[3] > 0)
142154
return (
143155
torch.nn.functional.scaled_dot_product_attention(
144156
query,

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytr
6161
flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
6262
unique = set(ca.cls_layers) if ca.cls_layers else None
6363
if (
64-
cache.__class__.__name__ != "DynamicCache"
64+
cache.__class__.__name__ not in ("DynamicCache", "HybridCache")
6565
or unique is None
6666
or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
6767
):

0 commit comments

Comments
 (0)