Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _unittests/ut_ci_models/test_ci_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_main_qwen25_tiny_llm(self):
pretrained=False,
part="",
output_folder=self.get_dump_folder("test_main_qwen25_tiny_llm"),
opset=24,
)
self.clean_dump()

Expand Down
2 changes: 2 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,8 @@ def test_prepare_inputs_for_generation_decoder_llm(self):
with self.subTest(case="case5"):
if not has_transformers("4.57"):
raise unittest.SkipTest("transformers 4.57+.")
if has_transformers("5.2.99"):
raise unittest.SkipTest("transformers 5.2+.")
Comment thread
xadupre marked this conversation as resolved.
Outdated
with self.assertRaises((AttributeError, TypeError)):
model_inputs = model.prepare_inputs_for_generation(
input_ids, past_key_values=dynamic_cache
Expand Down
10 changes: 7 additions & 3 deletions onnx_diagnostic/ci_models/export_qwen25_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
from .ci_helpers import (
check_for_discrepancies_and_log_everything_into_a_json_file,
compute_expected_outputs,
Expand Down Expand Up @@ -199,6 +199,7 @@ def main(
atol: float = 0.01,
mismatch01: float = 0.1,
profile_exporter: bool = False,
opset: Optional[int] = None,
):
Comment thread
xadupre marked this conversation as resolved.
"""
Exports model Qwen/Qwen2.5-VL-7B-Instruct or pieces of it.
Expand All @@ -221,6 +222,8 @@ def main(
:param atol: raises an exception if tolerance is above that threshold
:param mismatch01: raises an exception if the ratio of mismatches
is above that threshold
:param opset: opset, if not specified, a value is chosen based on the
proposed rewriting
:param profile_exporter: profiles the exporter
"""
prefix = simplify_model_id_for_a_filename(model_id)
Expand All @@ -243,6 +246,7 @@ def main(
print(f"-- make_zip={make_zip}")
print(f"-- output_folder={output_folder}")
print(f"-- atol={atol}")
print(f"-- opset={opset}")
print(f"-- mismatch01={mismatch01}")
print(f"-- profile_exporter={profile_exporter}")
print("------------------------------------------------------------------")
Expand Down Expand Up @@ -473,15 +477,15 @@ def process_image(inputs_embeds, image_features):

begin = time.perf_counter()

target_opset = 22
target_opset = opset or 22
if (
exporter == "onnx-dynamo"
and device == "cuda"
and "QWEN25ATTENTION" not in os.environ
):
os.environ["QWEN25ATTENTION"] = "PACKED"
elif "QWEN25ATTENTION" in os.environ and os.environ["QWEN25ATTENTION"] == "LOOPA23":
target_opset = 23
target_opset = opset or 23

with torch_export_patches(
patch_torch=False,
Expand Down
4 changes: 2 additions & 2 deletions onnx_diagnostic/tasks/image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _get_inputs_gemma3(
},
"position_ids": {0: batch, 1: seq_length},
"cache_position": {0: seq_length},
"past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
"past_key_values": [{0: batch, 2: seq_length} for _ in range(num_hidden_layers * 2)],
"pixel_values": {0: batch},
Comment thread
sdpython marked this conversation as resolved.
"use_cache": None,
}
Expand Down Expand Up @@ -280,7 +280,7 @@ def get_inputs_default(
"past_key_values": list(
itertools.chain.from_iterable(
zip(
[{0: batch} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
[{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class patched_GenerationMixin:
(
None
if pv.Version(transformers.__version__) >= pv.Version("4.56")
and pv.Version(transformers.__version__) < pv.Version("5.2.99")
Comment thread
xadupre marked this conversation as resolved.
else "prepare_inputs_for_generation"
),
# (
Expand Down Expand Up @@ -297,192 +298,3 @@ def prepare_inputs_for_generation( # pragma: no cover
# 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
model_inputs.pop("labels", None)
return model_inputs

'''
# This patch is dropped (kept below for reference only): it was written for a very specific version.
def _sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: "LogitsProcessorList",  # noqa: F821
    stopping_criteria: "StoppingCriteriaList",  # noqa: F821
    generation_config: "GenerationConfig",  # noqa: F821
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,  # noqa: F821
    **model_kwargs,
) -> Union["GenerateNonBeamOutput", torch.LongTensor]:  # noqa: F821
    """
    Generates tokens one step at a time until
    ``self._has_unfinished_sequences`` reports that nothing is left to do.

    2025/09/29: updates for Gemma3 models, fix for eager mode as well as the export.

    Every iteration builds the model inputs with
    ``self.prepare_inputs_for_generation``, runs the model, picks the next
    token (multinomial sampling if ``generation_config.do_sample`` else argmax)
    and appends it to ``input_ids``. The only patched part is the
    concatenation of the new tokens: they are used as is when they are
    already 2D, a trailing axis is added otherwise.

    :param input_ids: token ids, ``input_ids.shape[:2]`` is read as
        ``(batch_size, cur_len)``
    :param logits_processor: called as
        ``logits_processor(input_ids, next_token_logits)`` to produce the scores
        the next token is selected from
    :param stopping_criteria: called as ``stopping_criteria(input_ids, scores)``
        after every new token, its items are also scanned for an attribute
        ``eos_token_id``
    :param generation_config: provides ``_pad_token_tensor``, the ``output_*``
        flags, ``return_dict_in_generate``, ``do_sample``, ``compile_config``
        and ``prefill_chunk_size``
    :param synced_gpus: forwarded to ``self._has_unfinished_sequences``,
        when True and this peer is finished, the iteration stops right after
        the forward pass
    :param streamer: if not None, receives every batch of new tokens through
        ``put`` and ``end`` is called once the loop is over
    :param model_kwargs: forwarded to ``self.prepare_inputs_for_generation``,
        refreshed after every forward pass by
        ``self._update_model_kwargs_for_generation``
    :return: ``input_ids`` extended with the generated tokens if
        ``generation_config.return_dict_in_generate`` is False,
        otherwise a ``GenerateEncoderDecoderOutput`` or
        a ``GenerateDecoderOnlyOutput`` depending on
        ``self.config.is_encoder_decoder``
    """
    # init values
    pad_token_id = generation_config._pad_token_tensor
    output_attentions = generation_config.output_attentions
    output_hidden_states = generation_config.output_hidden_states
    output_scores = generation_config.output_scores
    output_logits = generation_config.output_logits
    return_dict_in_generate = generation_config.return_dict_in_generate
    has_eos_stopping_criteria = any(
        hasattr(criteria, "eos_token_id") for criteria in stopping_criteria
    )
    do_sample = generation_config.do_sample

    # init attention / hidden states / scores tuples
    # (None means the corresponding output is not collected)
    scores = () if (return_dict_in_generate and output_scores) else None
    raw_logits = () if (return_dict_in_generate and output_logits) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = (
        () if (return_dict_in_generate and output_hidden_states) else None
    )

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    # (both names are only bound under this condition, the final return
    # reads them under the same condition)
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = (
            model_kwargs["encoder_outputs"].get("attentions")
            if output_attentions
            else None
        )
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states")
            if output_hidden_states
            else None
        )

    # keep track of which sequences are already finished
    batch_size, cur_len = input_ids.shape[:2]
    this_peer_finished = False
    unfinished_sequences = torch.ones(
        batch_size, dtype=torch.long, device=input_ids.device
    )
    model_kwargs = self._get_initial_cache_position(
        cur_len, input_ids.device, model_kwargs
    )

    # model_forward is used for every step but the prefill one,
    # it is replaced by a compiled call when the criteria are met
    model_forward = self.__call__
    compile_forward = self._valid_auto_compile_criteria(model_kwargs, generation_config)
    if compile_forward:
        os.environ["TOKENIZERS_PARALLELISM"] = "0"
        # If we use FA2 and a static cache, we cannot compile with fullgraph
        if self.config._attn_implementation == "flash_attention_2":
            # fullgraph is only switched off if the user passed an explicit
            # compile-config asking for it, no warning is emitted here
            if (
                generation_config.compile_config is not None
                and generation_config.compile_config.fullgraph
            ):
                generation_config.compile_config.fullgraph = False
        model_forward = self.get_compiled_call(generation_config.compile_config)

    if generation_config.prefill_chunk_size is not None:
        # the prefill is delegated to _prefill_chunking,
        # the loop below then starts directly with model_forward
        model_kwargs = self._prefill_chunking(input_ids, generation_config, **model_kwargs)
        is_prefill = False
    else:
        is_prefill = True

    while self._has_unfinished_sequences(
        this_peer_finished, synced_gpus, device=input_ids.device
    ):
        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # the first forward pass goes through self(...), not through model_forward
        if is_prefill:
            outputs = self(**model_inputs, return_dict=True)
            is_prefill = False
        else:
            outputs = model_forward(**model_inputs, return_dict=True)

        model_kwargs = self._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=self.config.is_encoder_decoder,
        )
        # this peer is already finished: skips the token selection below
        if synced_gpus and this_peer_finished:
            continue

        # logits of the last position, copied as float32 on the device of input_ids
        next_token_logits = outputs.logits[:, -1, :].to(
            copy=True, dtype=torch.float32, device=input_ids.device
        )

        # pre-process distribution
        next_token_scores = logits_processor(input_ids, next_token_logits)

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_logits:
                raw_logits += (next_token_logits,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,)
                    if self.config.is_encoder_decoder
                    else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)

            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # token selection
        if do_sample:
            probs = torch.nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
        else:
            next_tokens = torch.argmax(next_token_scores, dim=-1)

        # finished sentences should have their next token be a padding token
        if has_eos_stopping_criteria:
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                1 - unfinished_sequences
            )

        # update generated ids, model inputs, and length for next step
        # PATCHED: the two following lines, next_tokens can already be 2D for this model
        next_tokens_2d = (
            next_tokens if len(next_tokens.shape) == 2 else next_tokens[:, None]
        )
        input_ids = torch.cat([input_ids, next_tokens_2d], dim=-1)
        if streamer is not None:
            # the streamer receives next_tokens as selected, not next_tokens_2d
            streamer.put(next_tokens.cpu())

        unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
        this_peer_finished = unfinished_sequences.max() == 0
        cur_len += 1

        # This is needed to properly delete outputs.logits which may be very large
        # for first iteration
        # Otherwise a reference to outputs is kept which keeps
        # the logits alive in the next iteration
        del outputs

    if streamer is not None:
        streamer.end()

    # structured output if requested, the extended input_ids otherwise
    if return_dict_in_generate:
        if self.config.is_encoder_decoder:
            return transformers.generation.utils.GenerateEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
        else:
            return transformers.generation.utils.GenerateDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                logits=raw_logits,
                attentions=decoder_attentions,
                hidden_states=decoder_hidden_states,
                past_key_values=model_kwargs.get("past_key_values"),
            )
    else:
        return input_ids
'''
Loading