diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index eb763b45d4..51e2a7cc66 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -99,9 +99,22 @@ def infer_task(
         except KeyError as e:
             try:
                 config = AutoConfig.from_pretrained(model_name_or_path)
-                with_past_arch_list = ["MistralForCausalLM", "Zamba2ForCausalLM"]
+                with_past_arch_list = [
+                    "MistralForCausalLM",
+                    "Zamba2ForCausalLM",
+                    "LlamaForCausalLMEagle3",
+                    "Eagle3LlamaForCausalLM",
+                ]
                 if any(arch in config.architectures for arch in with_past_arch_list):
-                    task = "text-generation-with-past"
+                    # VLM Eagle3 models (targeting VLM architectures like Qwen3-VL)
+                    # should use image-text-to-text task for proper inputs_embeds/3D position_ids export.
+                    if "Eagle3LlamaForCausalLM" in config.architectures and (
+                        getattr(config, "modal_type", "") == "VLM"
+                        or getattr(config, "target_model_type", "") in {"qwen2_vl", "qwen3_vl"}
+                    ):
+                        task = "image-text-to-text"
+                    else:
+                        task = "text-generation-with-past"
             except Exception:
                 raise KeyError(
                     f"The task could not be automatically inferred. Please provide the argument --task with the relevant task from {', '.join(TasksManager.get_all_tasks())}. Detailed error: {e}"
@@ -144,9 +157,20 @@ def update_config_for_eagle3(config):
     spec = importlib.util.find_spec(moduler_name)
     if spec and spec.origin:
         moduler_path = os.path.dirname(spec.origin)
+        # Use VLM-aware Eagle3 classes for models targeting VLM architectures
+        # (e.g. AngelSlim/Qwen3-VL-4B-Instruct_eagle3 with Eagle3LlamaForCausalLM).
+        is_vlm_eagle3 = getattr(config, "modal_type", "") == "VLM" or getattr(
+            config, "target_model_type", ""
+        ) in {"qwen2_vl", "qwen3_vl"}
+        if is_vlm_eagle3:
+            model_cls = "QwenVLEagle3Model"
+            causal_lm_cls = "QwenVLEagle3ForCausalLM"
+        else:
+            model_cls = "LlamaEagle3Model"
+            causal_lm_cls = "LlamaEagle3ForCausalLM"
         config.auto_map = {
-            "AutoModel": moduler_path + "--model_patcher.LlamaEagle3Model",
-            "AutoModelForCausalLM": moduler_path + "--model_patcher.LlamaEagle3ForCausalLM",
+            "AutoModel": moduler_path + f"--model_patcher.{model_cls}",
+            "AutoModelForCausalLM": moduler_path + f"--model_patcher.{causal_lm_cls}",
         }
     config.tie_word_embeddings = False
     return config
@@ -316,9 +340,10 @@ def main_export(
         quantization_config = getattr(config, "quantization_config", None)
         quant_method = quantization_config.get("quant_method", None) if quantization_config else None
 
-        # update config to load eagle3 models
+        # update config to load eagle3 models (both text-only and VLM variants)
        archs = getattr(config, "architectures", None)
-        if isinstance(archs, list) and len(archs) > 0 and archs[0] == "LlamaForCausalLMEagle3":
+        _eagle3_archs = {"LlamaForCausalLMEagle3", "Eagle3LlamaForCausalLM"}
+        if isinstance(archs, list) and len(archs) > 0 and archs[0] in _eagle3_archs:
             loading_kwargs["config"] = update_config_for_eagle3(config)
 
         # mxfp4 quantized model will be dequantized to bf16
diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py
index 60d90f53e0..3cd51114ff 100644
--- a/optimum/exporters/openvino/convert.py
+++ b/optimum/exporters/openvino/convert.py
@@ -1006,6 +1006,12 @@ def _get_submodels_and_export_configs(
             exporter,
         )
         stateful_per_model = [stateful] * len(models_for_export)
+
+        # VLM Eagle3 models need stateful KV cache despite model_type being "llama"
+        # (not in MULTI_MODAL_TEXT_GENERATION_MODELS) and task being "image-text-to-text".
+        if not stateful and getattr(export_config, "eagle3_vlm", False):
+            stateful_per_model = [True] * len(models_for_export)
+
         return export_config, models_for_export, stateful_per_model
 
 
diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py
index 3a5d3d08fa..952f6d5806 100644
--- a/optimum/exporters/openvino/model_configs.py
+++ b/optimum/exporters/openvino/model_configs.py
@@ -759,6 +759,37 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int
         return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
 
 
+class Eagle3VLMDummyGenerator(DummyInputGenerator):
+    """
+    Dummy input generator for VLM Eagle-3 speculative decoding.
+
+    Produces `inputs_embeds` (float) and 3D `position_ids` (MRoPE)
+    required by VLM Eagle-3 draft models targeting Qwen3-VL.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("inputs_embeds", "position_ids")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config: NormalizedTextConfig,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.hidden_size = normalized_config.hidden_size
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "inputs_embeds":
+            shape = (self.batch_size, self.sequence_length, self.hidden_size)
+            return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
+        if input_name == "position_ids":
+            shape = (3, self.batch_size, self.sequence_length)
+            return self.random_int_tensor(shape, max_value=self.sequence_length, framework=framework, dtype=int_dtype)
+
+
 @register_in_tasks_manager(
     "llama",
     *[
@@ -767,6 +798,7 @@
         "text-generation",
         "text-generation-with-past",
         "text-classification",
+        "image-text-to-text",
     ],
     library_name="transformers",
 )
@@ -794,10 +826,26 @@ def __init__(
         )
         archs = getattr(config, "architectures", None)
         self.eagle3 = False
+        self.eagle3_vlm = False
         if isinstance(archs, list) and len(archs) > 0 and "eagle3" in archs[0].lower():
-            self.DUMMY_INPUT_GENERATOR_CLASSES += (Eagle3DummyGenerator,)
             self.MIN_TRANSFORMERS_VERSION = "4.54.0"
             self.eagle3 = True
+            # VLM Eagle3 targets a VLM model (e.g. Qwen3-VL) and requires
+            # inputs_embeds instead of input_ids and 3D MRoPE position_ids.
+            target_model_type = getattr(config, "target_model_type", "")
+            modal_type = getattr(config, "modal_type", "")
+            if modal_type == "VLM" or target_model_type in {"qwen2_vl", "qwen3_vl"}:
+                self.eagle3_vlm = True
+                # VLM Eagle3 always needs KV cache for speculative decoding,
+                # regardless of whether the task includes "-with-past".
+                self.use_past = True
+                self.use_past_in_inputs = True
+                # Eagle3VLMDummyGenerator must precede DummyTextInputGenerator
+                # so it wins for inputs_embeds and position_ids generation.
+                self.DUMMY_INPUT_GENERATOR_CLASSES = (Eagle3VLMDummyGenerator,) + self.DUMMY_INPUT_GENERATOR_CLASSES + (Eagle3DummyGenerator,)
+                self.MIN_TRANSFORMERS_VERSION = "4.57.0"
+            else:
+                self.DUMMY_INPUT_GENERATOR_CLASSES += (Eagle3DummyGenerator,)
 
     @property
     def inputs(self) -> Dict[str, Dict[int, str]]:
@@ -805,8 +853,55 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
         # Eagle3 model has additional conditional input
         if self.eagle3:
             common_inputs["hidden_states"] = {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}
+            # VLM Eagle3 uses inputs_embeds (not input_ids) and 3D MRoPE position_ids
+            if self.eagle3_vlm:
+                common_inputs.pop("input_ids", None)
+                common_inputs["inputs_embeds"] = {0: "batch_size", 1: "sequence_length", 2: "hidden_size"}
+                common_inputs["position_ids"] = {0: "num_dims", 1: "batch_size", 2: "sequence_length"}
         return common_inputs
 
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_key_values")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.append("past_key_values")
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        if (
+            self.use_past_in_inputs
+            and self.PAD_ATTENTION_MASK_TO_PAST
+            and self.use_cache_branch is not False
+            and "attention_mask" in dummy_inputs
+            and self.task in ("text-generation", "image-text-to-text")
+        ):
+            # VLM Eagle3 uses inputs_embeds instead of input_ids
+            main_input = dummy_inputs.get("input_ids", dummy_inputs.get("inputs_embeds"))
+            seq_len = main_input.shape[1]
+            past_seq_len = dummy_inputs["past_key_values"][0][1].shape[-2]
+            dummy_inputs["attention_mask"] = DummyInputGenerator.pad_input_on_dim(
+                dummy_inputs["attention_mask"], desired_length=past_seq_len + seq_len, dim=1
+            )
+
+        return dummy_inputs
+
     @property
     def outputs(self) -> Dict[str, Dict[int, str]]:
         common_outputs = super().outputs
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index df0d350e52..89b2956ce6 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -7901,6 +7901,148 @@ def forward(
         )
 
 
+if is_transformers_version(">=", "4.57"):
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextRotaryEmbedding
+
+    class QwenVLEagle3Model(LlamaEagle3Model):
+        """
+        Eagle-3 draft model with Qwen3-VL MRoPE for VLM speculative decoding.
+
+        Extends LlamaEagle3Model by replacing the standard rotary embedding with
+        Qwen3VLTextRotaryEmbedding, which supports interleaved multimodal RoPE
+        (MRoPE). This allows the draft model to handle position IDs compatible
+        with Qwen3-VL target models.
+
+        The forward signature is redefined to accept ``inputs_embeds`` as the
+        primary input (instead of ``input_ids``). This is critical because the
+        TorchScript tracer uses ``inspect.signature(model.forward)`` to determine
+        input parameter names and ordering. By removing ``input_ids`` from the
+        signature entirely, the traced/converted OpenVINO model will have
+        ``inputs_embeds`` as an explicit input (float32 3D tensor).
+
+        ``position_ids`` is kept as 3D ``[3, batch, seq]`` for MRoPE without
+        flattening.
+        """
+
+        def __init__(self, config: LlamaConfig):
+            super().__init__(config)
+            # Replace standard rotary embedding with VLM-aware MRoPE embedding.
+            self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config)
+
+        def forward(
+            self,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            hidden_states: Optional[torch.FloatTensor] = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[Cache] = None,
+            use_cache: Optional[bool] = None,
+            cache_position: Optional[torch.LongTensor] = None,
+            **kwargs,
+        ) -> BaseModelOutputWithPast:
+            batch_size, seq_length, _ = hidden_states.shape
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+            if use_cache and past_key_values is None:
+                past_key_values = DynamicCache(config=self.config)
+
+            if cache_position is None:
+                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+                cache_position = torch.arange(
+                    past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                )
+
+            if position_ids is None:
+                position_ids = cache_position.unsqueeze(0)
+
+            causal_mask = create_causal_mask(
+                config=self.config,
+                input_embeds=inputs_embeds,
+                attention_mask=attention_mask,
+                cache_position=cache_position,
+                past_key_values=past_key_values,
+                position_ids=position_ids,
+            )
+
+            if hidden_states is None:
+                hidden_states = torch.zeros(
+                    [batch_size, seq_length, self.hidden_size],
+                    dtype=inputs_embeds.dtype,
+                    device=inputs_embeds.device,
+                )
+
+            inputs_embeds = inputs_embeds.to(hidden_states.dtype)
+            if hidden_states.shape[-1] != inputs_embeds.shape[-1]:
+                hidden_states = self.fc(hidden_states)
+
+            position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
+
+            hidden_states = self.midlayer(
+                input_emb=inputs_embeds,
+                hidden_states=hidden_states,
+                attention_mask=causal_mask,
+                position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+
+            hidden_states = self.norm(hidden_states)
+            return BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=past_key_values,
+            )
+
+    class QwenVLEagle3ForCausalLM(LlamaEagle3ForCausalLM):
+        """
+        Eagle-3 causal LM with Qwen3-VL MRoPE for VLM speculative decoding.
+
+        Uses QwenVLEagle3Model as the underlying model. The forward signature
+        is redefined to accept ``inputs_embeds`` instead of ``input_ids``,
+        ensuring the TorchScript tracer produces an OpenVINO model with
+        ``inputs_embeds`` as a named input parameter.
+ """ + + def __init__(self, config): + super().__init__(config) + self.model = QwenVLEagle3Model(config) + + def forward( + self, + inputs_embeds: Optional[torch.FloatTensor] = None, + hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, + ) -> Eagle3Output: + outputs: BaseModelOutputWithPast = self.model( + inputs_embeds=inputs_embeds, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.model.lm_head(hidden_states[:, slice_indices, :]) + + d2t_out = self.identity(self.model.d2t) + return Eagle3Output( + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + d2t=d2t_out, + ) + + # Patched implementation of the gated delta rule in recurrent form. # Adapted from: # https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L522 diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index a318874d3c..3fadd7b487 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -12,6 +12,7 @@ from transformers.testing_utils import slow from utils_tests import ( EAGLE3_MODELS, + EAGLE3_VLM_MODELS, F32_CONFIG, MODEL_NAMES, OPENVINO_DEVICE, @@ -872,3 +873,27 @@ def test_load_and_infer_with_eagle3_model(self, model_arch, model_pair): del ov_model gc.collect() + + @parameterized.expand(EAGLE3_VLM_MODELS.items()) + @pytest.mark.skipif(is_transformers_version("<", "4.57"), reason="Qwen3-VL Eagle3 requires transformers >= 4.57") + def test_load_and_infer_with_vlm_eagle3_model(self, model_arch, model_pair): + draft_model_id, target_model_id = model_pair + + ov_model = OVModelForCausalLM.from_pretrained(draft_model_id, export=True, trust_remote_code=True) + self.assertIsInstance(ov_model.config, PretrainedConfig) + self.assertTrue(ov_model.use_cache) + + tokenizer = AutoTokenizer.from_pretrained(target_model_id) + tokens = tokenizer("This is a sample output", return_tensors="pt") + + ov_outputs = ov_model(**tokens) + self.assertTrue("logits" in ov_outputs) + self.assertIsInstance(ov_outputs.logits, torch.Tensor) + + self.assertTrue("past_key_values" in ov_outputs) + self.assertIsInstance(ov_outputs.past_key_values, tuple) + self.assertEqual(ov_model.stateful, True) + self.assertTrue(len(ov_outputs.past_key_values) == 1 and len(ov_outputs.past_key_values[0]) == 0) + + del ov_model + gc.collect() diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index fe6d584d2f..055d21f5ad 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -232,6 +232,13 @@ EAGLE3_MODELS = {"qwen3_eagle3": ("AngelSlim/Qwen3-1.7B_eagle3", "Qwen/Qwen3-1.7B")} +# VLM-based Eagle3 draft models (AngelSlim Eagle3LlamaForCausalLM architecture). +# These use Qwen3-VL MRoPE and target VLM models for speculative decoding. +# Only used in the decoder test (not genai, since the VLM target needs image-text-to-text export). 
+EAGLE3_VLM_MODELS = {
+    "qwen3_vl_eagle3": ("AngelSlim/Qwen3-VL-4B-Instruct_eagle3", "Qwen/Qwen3-VL-4B-Instruct"),
+}
+
 _ARCHITECTURES_TO_EXPECTED_INT8 = {
     "afmoe": {"model": 16},
     "bert": {"model": 68},