37 changes: 31 additions & 6 deletions optimum/exporters/openvino/__main__.py
@@ -99,9 +99,22 @@ def infer_task(
except KeyError as e:
try:
config = AutoConfig.from_pretrained(model_name_or_path)
with_past_arch_list = ["MistralForCausalLM", "Zamba2ForCausalLM"]
with_past_arch_list = [
"MistralForCausalLM",
"Zamba2ForCausalLM",
"LlamaForCausalLMEagle3",
"Eagle3LlamaForCausalLM",
]
if any(arch in config.architectures for arch in with_past_arch_list):
task = "text-generation-with-past"
# VLM Eagle3 models (targeting VLM architectures like Qwen3-VL)
# should use image-text-to-text task for proper inputs_embeds/3D position_ids export.
if "Eagle3LlamaForCausalLM" in config.architectures and (
getattr(config, "modal_type", "") == "VLM"
or getattr(config, "target_model_type", "") in {"qwen2_vl", "qwen3_vl"}
):
task = "image-text-to-text"
else:
task = "text-generation-with-past"
except Exception:
raise KeyError(
f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
@@ -144,9 +157,20 @@ def update_config_for_eagle3(config):
spec = importlib.util.find_spec(moduler_name)
if spec and spec.origin:
moduler_path = os.path.dirname(spec.origin)
# Use VLM-aware Eagle3 classes for models targeting VLM architectures
# (e.g. AngelSlim/Qwen3-VL-4B-Instruct_eagle3 with Eagle3LlamaForCausalLM).
is_vlm_eagle3 = getattr(config, "modal_type", "") == "VLM" or getattr(
config, "target_model_type", ""
) in {"qwen2_vl", "qwen3_vl"}
if is_vlm_eagle3:
model_cls = "QwenVLEagle3Model"
causal_lm_cls = "QwenVLEagle3ForCausalLM"
else:
model_cls = "LlamaEagle3Model"
causal_lm_cls = "LlamaEagle3ForCausalLM"
config.auto_map = {
"AutoModel": moduler_path + "--model_patcher.LlamaEagle3Model",
"AutoModelForCausalLM": moduler_path + "--model_patcher.LlamaEagle3ForCausalLM",
"AutoModel": moduler_path + f"--model_patcher.{model_cls}",
"AutoModelForCausalLM": moduler_path + f"--model_patcher.{causal_lm_cls}",
}
config.tie_word_embeddings = False
return config
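
The auto_map entries follow the <location>--<module>.<Class> form that transformers' remote-code loader parses, so from_pretrained(..., trust_remote_code=True) should resolve the patched Eagle3 classes from model_patcher.py rather than from the checkpoint repo. A minimal sketch of the resulting mapping, with a hypothetical install path:

moduler_path = "/opt/venv/lib/python3.11/site-packages/optimum/exporters/openvino"  # hypothetical
is_vlm_eagle3 = True  # e.g. config.modal_type == "VLM"

model_cls = "QwenVLEagle3Model" if is_vlm_eagle3 else "LlamaEagle3Model"
causal_lm_cls = "QwenVLEagle3ForCausalLM" if is_vlm_eagle3 else "LlamaEagle3ForCausalLM"

auto_map = {
    "AutoModel": f"{moduler_path}--model_patcher.{model_cls}",
    "AutoModelForCausalLM": f"{moduler_path}--model_patcher.{causal_lm_cls}",
}
assert auto_map["AutoModelForCausalLM"].endswith("model_patcher.QwenVLEagle3ForCausalLM")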
@@ -316,9 +340,10 @@ def main_export(
quantization_config = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", None) if quantization_config else None

# update config to load eagle3 models
# update config to load eagle3 models (both text-only and VLM variants)
archs = getattr(config, "architectures", None)
if isinstance(archs, list) and len(archs) > 0 and archs[0] == "LlamaForCausalLMEagle3":
_eagle3_archs = {"LlamaForCausalLMEagle3", "Eagle3LlamaForCausalLM"}
if isinstance(archs, list) and len(archs) > 0 and archs[0] in _eagle3_archs:
loading_kwargs["config"] = update_config_for_eagle3(config)

# mxfp4 quantized model will be dequantized to bf16
6 changes: 6 additions & 0 deletions optimum/exporters/openvino/convert.py
@@ -1006,6 +1006,12 @@ def _get_submodels_and_export_configs(
exporter,
)
stateful_per_model = [stateful] * len(models_for_export)

# VLM Eagle3 models need stateful KV cache despite model_type being "llama"
# (not in MULTI_MODAL_TEXT_GENERATION_MODELS) and task being "image-text-to-text".
if not stateful and getattr(export_config, "eagle3_vlm", False):
stateful_per_model = [True] * len(models_for_export)

return export_config, models_for_export, stateful_per_model


97 changes: 96 additions & 1 deletion optimum/exporters/openvino/model_configs.py
@@ -759,6 +759,37 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)


class Eagle3VLMDummyGenerator(DummyInputGenerator):
"""
Dummy input generator for VLM Eagle-3 speculative decoding.

Produces `inputs_embeds` (float) and 3D `position_ids` (MRoPE)
required by VLM Eagle-3 draft models targeting Qwen3-VL.
"""

SUPPORTED_INPUT_NAMES = ("inputs_embeds", "position_ids")

def __init__(
self,
task: str,
normalized_config: NormalizedTextConfig,
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
**kwargs,
):
self.batch_size = batch_size
self.sequence_length = sequence_length
self.hidden_size = normalized_config.hidden_size

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if input_name == "inputs_embeds":
shape = (self.batch_size, self.sequence_length, self.hidden_size)
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
if input_name == "position_ids":
shape = (3, self.batch_size, self.sequence_length)
return self.random_int_tensor(shape, max_value=self.sequence_length, framework=framework, dtype=int_dtype)
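
A quick shape check for this generator, as a standalone sketch in plain torch rather than through optimum's dummy-input plumbing (hidden_size=64 is an arbitrary illustrative value):

import torch

batch_size, sequence_length, hidden_size = 2, 8, 64

# inputs_embeds: float tensor of shape [batch, seq, hidden]
inputs_embeds = torch.rand(batch_size, sequence_length, hidden_size)
# position_ids: int64 tensor of shape [3, batch, seq], one row per MRoPE axis
position_ids = torch.randint(0, sequence_length, (3, batch_size, sequence_length), dtype=torch.int64)

assert inputs_embeds.shape == (2, 8, 64)
assert position_ids.shape == (3, 2, 8)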


@register_in_tasks_manager(
"llama",
*[
@@ -767,6 +798,7 @@
"text-generation",
"text-generation-with-past",
"text-classification",
"image-text-to-text",
],
library_name="transformers",
)
@@ -794,19 +826,82 @@ def __init__(
)
archs = getattr(config, "architectures", None)
self.eagle3 = False
self.eagle3_vlm = False
if isinstance(archs, list) and len(archs) > 0 and "eagle3" in archs[0].lower():
self.DUMMY_INPUT_GENERATOR_CLASSES += (Eagle3DummyGenerator,)
self.MIN_TRANSFORMERS_VERSION = "4.54.0"
self.eagle3 = True
# VLM Eagle3 targets a VLM model (e.g. Qwen3-VL) and requires
# inputs_embeds instead of input_ids and 3D MRoPE position_ids.
target_model_type = getattr(config, "target_model_type", "")
modal_type = getattr(config, "modal_type", "")
if modal_type == "VLM" or target_model_type in {"qwen2_vl", "qwen3_vl"}:
self.eagle3_vlm = True
# VLM Eagle3 always needs KV cache for speculative decoding,
# regardless of whether the task includes "-with-past".
self.use_past = True
self.use_past_in_inputs = True
# Eagle3VLMDummyGenerator must precede DummyTextInputGenerator
# so it wins for inputs_embeds and position_ids generation.
self.DUMMY_INPUT_GENERATOR_CLASSES = (Eagle3VLMDummyGenerator,) + self.DUMMY_INPUT_GENERATOR_CLASSES + (Eagle3DummyGenerator,)
self.MIN_TRANSFORMERS_VERSION = "4.57.0"
else:
self.DUMMY_INPUT_GENERATOR_CLASSES += (Eagle3DummyGenerator,)
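
The ordering comment above relies on first-match-wins generation (see generate_dummy_inputs below): generators are consulted in tuple order, and the first one whose supports_input accepts a name produces that tensor. A toy illustration with stand-in classes, not optimum's:

class VLMGen:
    def supports_input(self, name):
        return name in ("inputs_embeds", "position_ids")


class TextGen:
    def supports_input(self, name):
        return name in ("input_ids", "attention_mask", "position_ids")


generators = (VLMGen(), TextGen())
winner = next(g for g in generators if g.supports_input("position_ids"))
assert isinstance(winner, VLMGen)  # the VLM generator shadows the text one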

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = super().inputs
# Eagle3 models take an additional hidden_states input
if self.eagle3:
common_inputs["hidden_states"] = {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}
# VLM Eagle3 uses inputs_embeds (not input_ids) and 3D MRoPE position_ids
if self.eagle3_vlm:
common_inputs.pop("input_ids", None)
common_inputs["inputs_embeds"] = {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}
common_inputs["position_ids"] = {0: "num_dims", 1: "batch_size", 2: "sequence_length"}
return common_inputs

def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

dummy_inputs = {}
input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
if self.use_past_in_inputs and self.use_cache_branch is not False:
input_names.append("past_key_values")

for input_name in input_names:
input_was_inserted = False
for dummy_input_gen in dummy_inputs_generators:
if dummy_input_gen.supports_input(input_name):
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
dummy_input_gen,
input_name,
framework,
input_shapes=kwargs,
)
input_was_inserted = True
break
if not input_was_inserted:
raise RuntimeError(
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
)

if (
self.use_past_in_inputs
and self.PAD_ATTENTION_MASK_TO_PAST
and self.use_cache_branch is not False
and "attention_mask" in dummy_inputs
and self.task in ("text-generation", "image-text-to-text")
):
# VLM Eagle3 uses inputs_embeds instead of input_ids
main_input = dummy_inputs.get("input_ids", dummy_inputs.get("inputs_embeds"))
seq_len = main_input.shape[1]
past_seq_len = dummy_inputs["past_key_values"][0][1].shape[-2]
dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
dummy_inputs["attention_mask"], desired_length=past_seq_len + seq_len, dim=1
)

return dummy_inputs
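
The padding branch above only enforces a length invariant on the mask; a standalone sketch of that invariant in plain torch (not optimum's pad_input_on_dim helper):

import torch

past_seq_len, seq_len = 16, 8  # cached positions vs. newly fed tokens
attention_mask = torch.ones(1, seq_len, dtype=torch.int64)

# Extend the mask so it covers both the cached and the new positions.
pad = torch.ones(1, past_seq_len, dtype=torch.int64)
attention_mask = torch.cat([pad, attention_mask], dim=1)

assert attention_mask.shape[1] == past_seq_len + seq_len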

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs
142 changes: 142 additions & 0 deletions optimum/exporters/openvino/model_patcher.py
@@ -7901,6 +7901,148 @@ def forward(
)


if is_transformers_version(">=", "4.57"):
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextRotaryEmbedding

class QwenVLEagle3Model(LlamaEagle3Model):
"""
Eagle-3 draft model with Qwen3-VL MRoPE for VLM speculative decoding.

Extends LlamaEagle3Model by replacing the standard rotary embedding with
Qwen3VLTextRotaryEmbedding, which supports interleaved multimodal RoPE
(MRoPE). This allows the draft model to handle position IDs compatible
with Qwen3-VL target models.

The forward signature is redefined to accept ``inputs_embeds`` as the
primary input (instead of ``input_ids``). This is critical because the
TorchScript tracer uses ``inspect.signature(model.forward)`` to determine
input parameter names and ordering. By removing ``input_ids`` from the
signature entirely, the traced/converted OpenVINO model will have
``inputs_embeds`` as an explicit input (float32 3D tensor).

``position_ids`` is kept as 3D ``[3, batch, seq]`` for MRoPE without
flattening.
"""

def __init__(self, config: LlamaConfig):
super().__init__(config)
# Replace standard rotary embedding with VLM-aware MRoPE embedding.
self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)

def forward(
Contributor review comment: Can you add a comment linking to the original forward method and mentioning the differences please?
self,
inputs_embeds: Optional[torch.FloatTensor] = None,
hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> BaseModelOutputWithPast:
batch_size, seq_length, _ = inputs_embeds.shape  # hidden_states may still be None here (see below)
use_cache = use_cache if use_cache is not None else self.config.use_cache

if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = cache_position.unsqueeze(0)

causal_mask = create_causal_mask(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=position_ids,
)

if hidden_states is None:
hidden_states = torch.zeros(
[batch_size, seq_length, self.hidden_size],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
)

inputs_embeds = inputs_embeds.to(hidden_states.dtype)
if hidden_states.shape[-1] != inputs_embeds.shape[-1]:
hidden_states = self.fc(hidden_states)

position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)

hidden_states = self.midlayer(
input_emb=inputs_embeds,
hidden_states=hidden_states,
attention_mask=causal_mask,
position_embeddings=position_embeddings,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=True,
)

hidden_states = self.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
)
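
Why the redefined signature matters, in one self-contained check: tracing-based export derives input names from inspect.signature(forward), so a forward whose first tensor argument is inputs_embeds (with no input_ids parameter at all) yields an exported model with inputs_embeds as an explicit input. A minimal sketch:

import inspect

import torch


class Draft(torch.nn.Module):
    def forward(self, inputs_embeds, position_ids):
        return inputs_embeds.sum() + position_ids.sum()


params = list(inspect.signature(Draft.forward).parameters)[1:]  # drop self
assert params == ["inputs_embeds", "position_ids"]  # no input_ids anywhere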

class QwenVLEagle3ForCausalLM(LlamaEagle3ForCausalLM):
"""
Eagle-3 causal LM with Qwen3-VL MRoPE for VLM speculative decoding.

Uses QwenVLEagle3Model as the underlying model. The forward signature
is redefined to accept ``inputs_embeds`` instead of ``input_ids``,
ensuring the TorchScript tracer produces an OpenVINO model with
``inputs_embeds`` as a named input parameter.
"""

def __init__(self, config):
super().__init__(config)
self.model = QwenVLEagle3Model(config)

def forward(
Contributor review comment: Same (i.e. link to the original forward method and note the differences).
self,
inputs_embeds: Optional[torch.FloatTensor] = None,
hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs,
) -> Eagle3Output:
outputs: BaseModelOutputWithPast = self.model(
inputs_embeds=inputs_embeds,
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.model.lm_head(hidden_states[:, slice_indices, :])

d2t_out = self.identity(self.model.d2t)
return Eagle3Output(
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
d2t=d2t_out,
)


# Patched implementation of the gated delta rule in recurrent form.
# Adapted from:
# https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L522
25 changes: 25 additions & 0 deletions tests/openvino/test_decoder.py
@@ -12,6 +12,7 @@
from transformers.testing_utils import slow
from utils_tests import (
EAGLE3_MODELS,
EAGLE3_VLM_MODELS,
F32_CONFIG,
MODEL_NAMES,
OPENVINO_DEVICE,
@@ -872,3 +873,27 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair):

del ov_model
gc.collect()

@parameterized.expand(EAGLE3_VLM_MODELS.items())
@pytest.mark.skipif(is_transformers_version("<", "4.57"), reason="Qwen3-VL Eagle3 requires transformers >= 4.57")
def test_load_and_infer_with_vlm_eagle3_model(self, model_arch, model_pair):
draft_model_id, target_model_id = model_pair

ov_model = OVModelForCausalLM.from_pretrained(draft_model_id, export=True, trust_remote_code=True)
self.assertIsInstance(ov_model.config, PretrainedConfig)
self.assertTrue(ov_model.use_cache)

tokenizer = AutoTokenizer.from_pretrained(target_model_id)
tokens = tokenizer("This is a sample output", return_tensors="pt")

ov_outputs = ov_model(**tokens)
self.assertTrue("logits" in ov_outputs)
self.assertIsInstance(ov_outputs.logits, torch.Tensor)

self.assertTrue("past_key_values" in ov_outputs)
self.assertIsInstance(ov_outputs.past_key_values, tuple)
self.assertEqual(ov_model.stateful, True)
self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0)

del ov_model
gc.collect()