Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _unittests/ut_ci_models/test_ci_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self):
pretrained=False,
part="",
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"),
opset=24,
)
self.clean_dump()

Expand Down
4 changes: 3 additions & 1 deletion _unittests/ut_tasks/test_tasks_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def test_image_text_to_text_tiny_gemma3(self):
def test_image_text_to_text_gemma3_4b_it(self):
make_hybrid_cache = get_make_hybrid_cache()
if make_hybrid_cache is None:
raise unittest.SkipTest("not implemented yet for transformers>=5")
raise unittest.SkipTest(
"not implemented yet for transformers>=5 (make_hybrid_cache is None)"
)
mid = "google/gemma-3-4b-it"
data = get_untrained_model_with_inputs(
mid,
Expand Down
16 changes: 11 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def forward(
for exporter in ("custom", "onnx-dynamo"):
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset') (transformers>=5.0?)
if exporter == "onnx-dynamo" and not has_onnxscript("0.5.7"):
raise unittest.SkipTest("needs onnxscript>=0.5.7")
self.skipTest("needs onnxscript>=0.5.7")
filename = self.get_dump_file(
f"test_patched_qwen2_5_vl_vision_attention_forward.{exporter}.onnx"
)
Expand Down Expand Up @@ -640,7 +640,7 @@ def test_qwen2_5_vl_vision_attention_iteration(self):
)
for exporter in ("custom", "onnx-dynamo"):
if exporter == "onnx-dynamo" and aten_sym_storage_offset is None:
raise unittest.SkipTest("update onnxscript to make this test run")
self.skipTest("update onnxscript to make this test run")
# onnx-dynamo needs OpOverload(op='aten.sym_storage_offset') (transformers>=5.0?)
filename = self.get_dump_file(
f"test_qwen2_5_vl_vision_attention_iteration.{exporter}.onnx"
Expand Down Expand Up @@ -909,7 +909,7 @@ def test_cache_dependant_input_preparation_exporting(self):
torch.testing.assert_close(eager2, export2)

with self.subTest(case="case2"):
raise unittest.SkipTest("torch 2.10+ has probably a bug here.")
self.skipTest("torch 2.10+ has probably a bug here.")
input_ids = torch.randint(0, 16, (2, 8), dtype=torch.int64)
inputs_embeds = torch.rand((2, 8), dtype=torch.float32)
cache_position = torch.arange(0, 8, dtype=torch.int64)
Expand Down Expand Up @@ -995,13 +995,17 @@ def test_prepare_inputs_for_generation_decoder_llm(self):

with self.subTest(case="case5"):
if not has_transformers("4.57"):
raise unittest.SkipTest("transformers 4.57+.")
self.skipTest("This test only works with transformers>=4.57, <5.3.")
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
with self.assertRaises((AttributeError, TypeError)):
model_inputs = model.prepare_inputs_for_generation(
input_ids, past_key_values=dynamic_cache
)

with self.subTest(case="case6"):
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
cache_position = torch.arange(input_ids.shape[-1], dtype=torch.long).to(
torch_device
)
Expand All @@ -1023,6 +1027,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
) # we still need the full attention mask!

with self.subTest(case="case6.2"):
if has_transformers("5.2.99"):
self.skipTest("This test is no longer valid with transformers>=5.3.")
max_cache_len = 10
batch_size = 2
query_length = input_ids.shape[-1] - init_input_ids.shape[-1]
Expand All @@ -1046,7 +1052,7 @@ def test_prepare_inputs_for_generation_decoder_llm(self):

with self.subTest(case="case7"):
if not has_transformers("4.57"):
raise unittest.SkipTest("transformers 4.57+.")
self.skipTest("This test only works with transformers>=4.57.")
init_inputs_embeds = model.get_input_embeddings()(init_input_ids)
model_inputs = model.prepare_inputs_for_generation(
input_ids,
Expand Down
6 changes: 6 additions & 0 deletions onnx_diagnostic/ci_models/ci_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ def get_parser(name: str, epilog: str = "") -> ArgumentParser:
help="Profiles the exporter and outputs an html document from pyinstrument",
action=BooleanOptionalAction,
)
parser.add_argument(
"--opset",
type=int,
default=0,
help="default opsets, 0 to let the exporter choose",
)
return parser


Expand Down
5 changes: 4 additions & 1 deletion onnx_diagnostic/ci_models/export_phi4_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,7 @@ def main(
atol: float = 2,
mismatch01: float = 0.01,
profile_exporter: bool = False,
opset: Optional[int] = None,
):
"""
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
Expand All @@ -733,6 +734,7 @@ def main(
:param atol: raises an exception if tolerance is above that threshold
:param mismatch01: raises an exception if the ratio of mismatches
is above that threshold
:param opset: target opset for the exported model, 22 is used if not specified
:param profile_exporter: profiles the exporter
"""
prefix = simplify_model_id_for_a_filename(model_id)
Expand Down Expand Up @@ -947,7 +949,7 @@ def forward(

begin = time.perf_counter()

target_opset = 22
target_opset = opset or 22

details = PatchDetails()
with torch_export_patches(
Expand Down Expand Up @@ -1062,4 +1064,5 @@ def forward(
atol=args.atol,
mismatch01=args.mismatch01,
profile_exporter=args.profile_exporter,
opset=args.opset if args.opset > 0 else None,
)
11 changes: 8 additions & 3 deletions onnx_diagnostic/ci_models/export_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from .ci_helpers import (
check_for_discrepancies_and_log_everything_into_a_json_file,
compute_expected_outputs,
Expand Down Expand Up @@ -199,6 +199,7 @@ def main(
atol: float = 0.01,
mismatch01: float = 0.1,
profile_exporter: bool = False,
opset: Optional[int] = None,
):
Comment thread
xadupre marked this conversation as resolved.
"""
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
Expand All @@ -221,6 +222,8 @@ def main(
:param atol: raises an exception if tolerance is above that threshold
:param mismatch01: raises an exception if the ratio of mismatches
is above that threshold
:param opset: target opset; if not specified, 22 is used, or 23 when
the environment variable ``QWEN25ATTENTION`` is set to ``LOOPA23``
:param profile_exporter: profiles the exporter
"""
prefix = simplify_model_id_for_a_filename(model_id)
Expand All @@ -243,6 +246,7 @@ def main(
print(f"-- make_zip={make_zip}")
print(f"-- output_folder={output_folder}")
print(f"-- atol={atol}")
print(f"-- opset={opset}")
print(f"-- mismatch01={mismatch01}")
print(f"-- profile_exporter={profile_exporter}")
print("------------------------------------------------------------------")
Expand Down Expand Up @@ -473,15 +477,15 @@ def process_image(inputs_embeds, image_features):

begin = time.perf_counter()

target_opset = 22
target_opset = opset or 22
if (
exporter == "onnx-dynamo"
and device == "cuda"
and "QWEN25ATTENTION" not in os.environ
):
os.environ["QWEN25ATTENTION"] = "PACKED"
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
target_opset = 23
target_opset = opset or 23

with torch_export_patches(
patch_torch=False,
Expand Down Expand Up @@ -565,4 +569,5 @@ def process_image(inputs_embeds, image_features):
atol=args.atol,
mismatch01=args.mismatch01,
profile_exporter=args.profile_exporter,
opset=args.opset if args.opset > 0 else None,
)
21 changes: 15 additions & 6 deletions onnx_diagnostic/helpers/cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ def make_hybrid_cache(
max_cache_len: Optional[int] = None,
max_batch_size: Optional[int] = None,
sliding_window: Optional[int] = None,
cls_layers: Optional[List[type]] = None,
) -> transformers.cache_utils.HybridCache:
"""
Creates an instance of :class:`transformers.cache_utils.HybridCache`.
Expand All @@ -662,6 +663,8 @@ def make_hybrid_cache(
:param key_value_pairs: list of pairs of (key, values)
:return: :class:`transformers.cache_utils.HybridCache`

`cls_layers` is unused.

Example:

.. runpython::
Expand Down Expand Up @@ -742,16 +745,22 @@ def make_hybrid_cache(
not max_batch_size and not max_cache_len
), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
max_batch_size = key_value_pairs[0][0].shape[0]
assert max_cache_len is not None or all(
isinstance(kv[0].shape[2], int) for kv in key_value_pairs
), (
f"Cannot determine max_cache_len with "
f"shapes={[kv[0].shape for kv in key_value_pairs]}"
)
sets_of_dim = set(kv[0].shape[2] for kv in key_value_pairs)
if len(sets_of_dim) == 1:
max_cache_len = sets_of_dim.pop()
sliding_window = max_cache_len
if max_cache_len is None:
max_cache_len = sets_of_dim.pop()
else:
assert (
len(sets_of_dim) == 2
), f"Not implemented for more than 2 dimensions {sets_of_dim}"
max_cache_len = max(sets_of_dim)
sliding_window = min(sets_of_dim)
if max_cache_len is None:
max_cache_len = max(sets_of_dim)
layer_types = [
"full_attention" if i == max_cache_len else "sliding_attention"
for i in [kv[0].shape[2] for kv in key_value_pairs]
Expand All @@ -760,8 +769,8 @@ def make_hybrid_cache(
assert (
max_batch_size and max_cache_len
), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
if sliding_window is None:
sliding_window = max_cache_len
if sliding_window is None:
sliding_window = max_cache_len
_max_cache_len = max_cache_len
_sliding_window = sliding_window

Expand Down
2 changes: 1 addition & 1 deletion onnx_diagnostic/tasks/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def get_inputs_default(
"past_key_values": list(
itertools.chain.from_iterable(
zip(
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ def patched_sdpa_attention_forward(
if is_causal is None and attention_mask is not None:
is_causal = False
if is_causal is not None:
torch._check(query.shape[0] > 0)
torch._check(query.shape[1] > 0)
torch._check(query.shape[2] > 0)
torch._check(query.shape[3] > 0)
torch._check(key.shape[0] > 0)
torch._check(key.shape[1] > 0)
torch._check(key.shape[2] > 0)
torch._check(key.shape[3] > 0)
torch._check(value.shape[0] > 0)
torch._check(value.shape[1] > 0)
torch._check(value.shape[2] > 0)
torch._check(value.shape[3] > 0)
Comment thread
sdpython marked this conversation as resolved.
return (
torch.nn.functional.scaled_dot_product_attention(
query,
Expand Down
Loading
Loading