Skip to content

Commit 53e8b87

Browse files
authored
fix cache dynamic dimension in image-text-to-text (#421)
* fix cache dynamic dimension in image-text-to-txt * disable patches * improves serialization * fix ache * a few fixes * cache
1 parent 9af6f4f commit 53e8b87

12 files changed

Lines changed: 83 additions & 212 deletions

File tree

_unittests/ut_ci_models/test_ci_export.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self):
2020
pretrained=False,
2121
part="",
2222
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"),
23+
opset=24,
2324
)
2425
self.clean_dump()
2526

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def test_image_text_to_text_tiny_gemma3(self):
6161
def test_image_text_to_text_gemma3_4b_it(self):
6262
make_hybrid_cache = get_make_hybrid_cache()
6363
if make_hybrid_cache is None:
64-
raise unittest.SkipTest("not implemented yet for transformers>=5")
64+
raise unittest.SkipTest(
65+
"not implemented yet for transformers>=5 (make_hybrid_cache is None)"
66+
)
6567
mid = "google/gemma-3-4b-it"
6668
data = get_untrained_model_with_inputs(
6769
mid,

_unittests/ut_torch_export_patches/test_patch_transformers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ def forward(
586586
for exporter in ("custom", "onnx-dynamo"):
587587
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
588588
if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"):
589-
raise unittest.SkipTest("needs onnxscript>=0.5.7")
589+
self.skipTest("needs onnxscript>=0.5.7")
590590
filename = self.get_dump_file(
591591
f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx"
592592
)
@@ -640,7 +640,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
640640
)
641641
for exporter in ("custom", "onnx-dynamo"):
642642
if exporter == "onnx-dynamo" and aten_sym_storage_offset is None:
643-
raise unittest.SkipTest("update onnxscript to make this test run")
643+
self.skipTest("update onnxscript to make this test run")
644644
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset' (transformers>=5.0?)
645645
filename = self.get_dump_file(
646646
f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx"
@@ -909,7 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self):
909909
torch.testing.assert_close(eager2, export2)
910910

911911
with self.subTest(case="case2"):
912-
raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
912+
self.skipTest("torch 2.10+ has probably a bug here.")
913913
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
914914
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
915915
cache_position = torch.arange(0, 8, dtype=torch.int64)
@@ -995,13 +995,17 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
995995

996996
with self.subTest(case="case5"):
997997
if not has_transformers("4.57"):
998-
raise unittest.SkipTest("transformers 4.57+.")
998+
self.skipTest("This test only works with transformers>=4.57, <5.3.")
999+
if has_transformers("5.2.99"):
1000+
self.skipTest("This test is no longer valid with transformers>=5.3.")
9991001
with self.assertRaises((AttributeError, TypeError)):
10001002
model_inputs = model.prepare_inputs_for_generation(
10011003
input_ids, past_key_values=dynamic_cache
10021004
)
10031005

10041006
with self.subTest(case="case6"):
1007+
if has_transformers("5.2.99"):
1008+
self.skipTest("This test is no longer valid with transformers>=5.3.")
10051009
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
10061010
torch_device
10071011
)
@@ -1023,6 +1027,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
10231027
) # we still need the full attention mask!
10241028

10251029
with self.subTest(case="case6.2"):
1030+
if has_transformers("5.2.99"):
1031+
self.skipTest("This test is no longer valid with transformers>=5.3.")
10261032
max_cache_len = 10
10271033
batch_size = 2
10281034
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
@@ -1046,7 +1052,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
10461052

10471053
with self.subTest(case="case7"):
10481054
if not has_transformers("4.57"):
1049-
raise unittest.SkipTest("transformers 4.57+.")
1055+
self.skipTest("This test only works with transformers>=4.57.")
10501056
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
10511057
model_inputs = model.prepare_inputs_for_generation(
10521058
input_ids,

onnx_diagnostic/ci_models/ci_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,12 @@ def get_parser(name: str, epilog: str = "") -> ArgumentParser:
128128
help="Profiles the exporter and outputs an html document from pyinstrument",
129129
action=BooleanOptionalAction,
130130
)
131+
parser.add_argument(
132+
"--opset",
133+
type=int,
134+
default=0,
135+
help="default opsets, 0 to let the exporter choose",
136+
)
131137
return parser
132138

133139

onnx_diagnostic/ci_models/export_phi4_mm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,7 @@ def main(
711711
atol: float = 2,
712712
mismatch01: float = 0.01,
713713
profile_exporter: bool = False,
714+
opset: Optional[int] = None,
714715
):
715716
"""
716717
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
@@ -733,6 +734,7 @@ def main(
733734
:param atol: raises an exception if tolerance is above that threshold
734735
:param mismatch01: raises an exception if the ratio of mismatches
735736
is above that threshold
737+
:param opset: opset to choose
736738
:param profile_exporter: profiles the exporter
737739
"""
738740
prefix = simplify_model_id_for_a_filename(model_id)
@@ -947,7 +949,7 @@ def forward(
947949

948950
begin = time.perf_counter()
949951

950-
target_opset = 22
952+
target_opset = opset or 22
951953

952954
details = PatchDetails()
953955
with torch_export_patches(
@@ -1062,4 +1064,5 @@ def forward(
10621064
atol=args.atol,
10631065
mismatch01=args.mismatch01,
10641066
profile_exporter=args.profile_exporter,
1067+
opset=args.opset if args.opset > 0 else None,
10651068
)

onnx_diagnostic/ci_models/export_qwen25_vl.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
import sys
6161
import time
6262
import warnings
63-
from typing import Any, Dict, List, Tuple
63+
from typing import Any, Dict, List, Optional, Tuple
6464
from .ci_helpers import (
6565
check_for_discrepancies_and_log_everything_into_a_json_file,
6666
compute_expected_outputs,
@@ -199,6 +199,7 @@ def main(
199199
atol: float = 0.01,
200200
mismatch01: float = 0.1,
201201
profile_exporter: bool = False,
202+
opset: Optional[int] = None,
202203
):
203204
"""
204205
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
@@ -221,6 +222,8 @@ def main(
221222
:param atol: raises an exception if tolerance is above that threshold
222223
:param mismatch01: raises an exception if the ratio of mismatches
223224
is above that threshold
225+
:param opset: opset, if not specified, a value is chosen based on the
226+
proposed rewriting
224227
:param profile_exporter: profiles the exporter
225228
"""
226229
prefix = simplify_model_id_for_a_filename(model_id)
@@ -243,6 +246,7 @@ def main(
243246
print(f"-- make_zip={make_zip}")
244247
print(f"-- output_folder={output_folder}")
245248
print(f"-- atol={atol}")
249+
print(f"-- opset={opset}")
246250
print(f"-- mismatch01={mismatch01}")
247251
print(f"-- profile_exporter={profile_exporter}")
248252
print("------------------------------------------------------------------")
@@ -473,15 +477,15 @@ def process_image(inputs_embeds, image_features):
473477

474478
begin = time.perf_counter()
475479

476-
target_opset = 22
480+
target_opset = opset or 22
477481
if (
478482
exporter == "onnx-dynamo"
479483
and device == "cuda"
480484
and "QWEN25ATTENTION" not in os.environ
481485
):
482486
os.environ["QWEN25ATTENTION"] = "PACKED"
483487
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
484-
target_opset = 23
488+
target_opset = opset or 23
485489

486490
with torch_export_patches(
487491
patch_torch=False,
@@ -565,4 +569,5 @@ def process_image(inputs_embeds, image_features):
565569
atol=args.atol,
566570
mismatch01=args.mismatch01,
567571
profile_exporter=args.profile_exporter,
572+
opset=args.opset if args.opset > 0 else None,
568573
)

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -539,8 +539,10 @@ def make_encoder_decoder_cache(
539539

540540
def make_mamba_cache(
541541
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
542+
cls_layers: Optional[Union[str, List[type]]] = None,
543+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
542544
) -> "MambaCache": # noqa: F821
543-
"Creates a ``MambaCache``."
545+
"""Creates a ``MambaCache``. `cls_layers`, `cls_kwargs` are unused."""
544546
# import is moved here because this part is slow.
545547
try:
546548
from transformers.models.mamba.modeling_mamba import MambaCache
@@ -591,8 +593,13 @@ def get_text_config(self, *args, **kwargs):
591593

592594
def make_sliding_window_cache(
593595
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
596+
cls_layers: Optional[Union[str, List[type]]] = None,
597+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
594598
) -> transformers.cache_utils.SlidingWindowCache:
595-
"Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
599+
"""
600+
Creates a :class:`transformers.cache_utils.SlidingWindowCache`.
601+
`cls_layers`, `cls_kwargs` are unused.
602+
"""
596603
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
597604

598605
class _config:
@@ -654,6 +661,8 @@ def make_hybrid_cache(
654661
max_cache_len: Optional[int] = None,
655662
max_batch_size: Optional[int] = None,
656663
sliding_window: Optional[int] = None,
664+
cls_layers: Optional[Union[str, List[type]]] = None,
665+
cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
657666
) -> transformers.cache_utils.HybridCache:
658667
"""
659668
Creates an instance of :class:`transformers.cache_utils.HybridCache`.
@@ -662,6 +671,8 @@ def make_hybrid_cache(
662671
:param key_value_pairs: list of pairs of (key, values)
663672
:return: :class:`transformers.cache_utils.HybridCache`
664673
674+
`cls_layers`, `cls_kwargs` are unused.
675+
665676
Example:
666677
667678
.. runpython::
@@ -742,16 +753,22 @@ def make_hybrid_cache(
742753
not max_batch_size and not max_cache_len
743754
), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
744755
max_batch_size = key_value_pairs[0][0].shape[0]
756+
assert max_cache_len is not None or all(
757+
isinstance(kv[0].shape[2], int) for kv in key_value_pairs
758+
), (
759+
f"Cannot determine max_cache_len with "
760+
f"shapes={[kv[0].shape for kv in key_value_pairs]}"
761+
)
745762
sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
746763
if len(sets_of_dim) == 1:
747-
max_cache_len = sets_of_dim.pop()
748-
sliding_window = max_cache_len
764+
if max_cache_len is None:
765+
max_cache_len = sets_of_dim.pop()
749766
else:
750767
assert (
751768
len(sets_of_dim) == 2
752769
), f"Not implemented for more than 2 dimensions {sets_of_dim}"
753-
max_cache_len = max(sets_of_dim)
754-
sliding_window = min(sets_of_dim)
770+
if max_cache_len is None:
771+
max_cache_len = max(sets_of_dim)
755772
layer_types = [
756773
"full_attention" if i == max_cache_len else "sliding_attention"
757774
for i in [kv[0].shape[2] for kv in key_value_pairs]
@@ -760,8 +777,8 @@ def make_hybrid_cache(
760777
assert (
761778
max_batch_size and max_cache_len
762779
), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
763-
if sliding_window is None:
764-
sliding_window = max_cache_len
780+
if sliding_window is None:
781+
sliding_window = max_cache_len
765782
_max_cache_len = max_cache_len
766783
_sliding_window = sliding_window
767784

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ def get_inputs_default(
280280
"past_key_values": list(
281281
itertools.chain.from_iterable(
282282
zip(
283-
[{0: batch} for _ in range(num_hidden_layers)],
283+
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
284284
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
285285
)
286286
)

onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,18 @@ def patched_sdpa_attention_forward(
139139
if is_causal is None and attention_mask is not None:
140140
is_causal = False
141141
if is_causal is not None:
142+
torch._check(query.shape[0] > 0)
143+
torch._check(query.shape[1] > 0)
144+
torch._check(query.shape[2] > 0)
145+
torch._check(query.shape[3] > 0)
146+
torch._check(key.shape[0] > 0)
147+
torch._check(key.shape[1] > 0)
148+
torch._check(key.shape[2] > 0)
149+
torch._check(key.shape[3] > 0)
150+
torch._check(value.shape[0] > 0)
151+
torch._check(value.shape[1] > 0)
152+
torch._check(value.shape[2] > 0)
153+
torch._check(value.shape[3] > 0)
142154
return (
143155
torch.nn.functional.scaled_dot_product_attention(
144156
query,

0 commit comments

Comments
 (0)