Skip to content

Commit 20919d4

Browse files
committed
fix
1 parent 194e078 commit 20919d4

2 files changed

Lines changed: 13 additions & 10 deletions

File tree

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class TestTasksImageTextToText(ExtTestCase):
1717
@hide_stdout()
18-
@requires_transformers("4.56")
18+
@requires_transformers("5.0.99")
1919
@requires_torch("2.7.99")
2020
def test_image_text_to_text_idefics(self):
2121
mid = "HuggingFaceM4/tiny-random-idefics"

onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@ def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytr
4444
ca = CacheKeyValue(cache)
4545
flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
4646
unique = set(ca.cls_layers) if ca.cls_layers else None
47-
if unique is None or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer"):
47+
if (
48+
cache.__class__.__name__ != "DynamicCache"
49+
or unique is None
50+
or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
51+
):
4852
keys = list(
4953
itertools.chain.from_iterable(
5054
(f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
@@ -80,14 +84,13 @@ def _unflatten_cache(
8084
)
8185
if expected == context:
8286
res = make_cache(list(zip(values[::2], values[1::2])))
83-
assert output_type is None or isinstance(
84-
res, output_type
85-
), f"Type mismatch between {output_type} (expected) and {type(res)}"
86-
return res
87-
88-
cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
89-
cls_layers = [getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names]
90-
res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)
87+
else:
88+
cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
89+
cls_layers = [
90+
getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
91+
]
92+
res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)
93+
9194
assert output_type is None or isinstance(
9295
res, output_type
9396
), f"Type mismatch between {output_type} (expected) and {type(res)}"

0 commit comments

Comments
 (0)