diff --git a/.github/workflows/test_openvino_preview_models.yml b/.github/workflows/test_openvino_preview_models.yml new file mode 100644 index 0000000000..0e162e1f42 --- /dev/null +++ b/.github/workflows/test_openvino_preview_models.yml @@ -0,0 +1,66 @@ +name: Preview Models Support Validation + +on: + workflow_dispatch: + push: + branches: + - main + - v*-release + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +permissions: + contents: read + +env: + UV_TORCH_BACKEND: cpu + UV_SYSTEM_PYTHON: true + TRANSFORMERS_IS_CI: true + HF_TOKEN: ${{ secrets.HF_HUB_READ_TOKEN }} + +jobs: + build: + strategy: + fail-fast: false + matrix: + test-pattern: + [ + "*export*", + "*seq2seq*", + "*quantization*", + ] + + runs-on: ubuntu-22.04 + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + pip install --upgrade pip uv + uv pip install .[tests] librosa diffusers + + - name: Login with fork PRs CI token + if: ${{ env.HF_TOKEN == '' }} + run: | + python tests/scripts/login_with_ci_token.py + + - name: Install latest openvino nightly + run: | + uv pip install --pre -U openvino openvino-tokenizers openvino-genai --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + + - name: Gemma 4 Validation + run: | + uv pip install transformers==5.5.0 + pytest tests/openvino/${{ matrix.test-pattern }} -m gemma4 --durations=0 diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx index acd58cade3..2c79b83398 100644 --- a/docs/source/openvino/models.mdx +++ b/docs/source/openvino/models.mdx @@ -73,6 +73,7 @@ Here is the list of the supported architectures : - Gemma - Gemma 2 - Gemma 3 +- Gemma 4 - GOT-OCR 2.0 - Granite - Granite 4.0 diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 8592257cb3..36a83adfa7 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -14,9 +14,11 @@ import enum import logging +import math from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +import torch from transformers import AutoConfig, PretrainedConfig, PreTrainedModel from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig @@ -153,6 +155,8 @@ FluxTransfromerModelPatcher, Gemma2ModelPatcher, Gemma3LMModelPatcher, + Gemma4ImageEmbeddingsModelPatcher, + Gemma4LMModelPatcher, GptJModelPatcher, GptNeoModelPatcher, GptNeoxModelPatcher, @@ -1511,6 +1515,106 @@ class Gemma3TextOpenVINOConfig(Gemma2OpenVINOConfig): MIN_TRANSFORMERS_VERSION = "4.50.0" +class Gemma4DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + def __init__( + self, + task: str, + normalized_config: NormalizedTextConfig, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + random_batch_size_range: Optional[Tuple[int, int]] = None, + random_sequence_length_range: Optional[Tuple[int, int]] = None, + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + random_batch_size_range=random_batch_size_range, + random_sequence_length_range=random_sequence_length_range, + ) + self.num_key_value_heads = 
normalized_config.num_key_value_heads + self.head_dim = normalized_config.head_dim + self.global_head_dim = getattr(normalized_config.config, "global_head_dim", self.head_dim) + self.layer_types = normalized_config.config.layer_types + self.num_kv_shared_layers = normalized_config.config.num_kv_shared_layers + self.sliding_window = normalized_config.config.sliding_window + # Full-attention layers use fewer KV heads than sliding-attention layers (e.g. 2 vs 8 for 26B-A4B) + self.num_global_key_value_heads = ( + getattr(normalized_config.config, "num_global_key_value_heads", None) or self.num_key_value_heads + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + # some layers do not produce their own KV-cache, they use the shared KV-cache + if self.num_kv_shared_layers > 0: + layer_types = self.layer_types[: -self.num_kv_shared_layers] + else: + layer_types = self.layer_types + past_kv_values = [] + for layer_type in layer_types: + if layer_type == "sliding_attention": + shape = ( + self.batch_size, + self.num_key_value_heads, + self.sliding_window, + self.head_dim, + ) + else: + shape = ( + self.batch_size, + self.num_global_key_value_heads, + self.sequence_length, + self.global_head_dim, + ) + past_kv_value = ( + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + self.random_float_tensor(shape, framework=framework, dtype=float_dtype), + ) + past_kv_values.append(past_kv_value) + + return past_kv_values + + +@register_in_tasks_manager( + "gemma4_text", + *[ + "feature-extraction", + "feature-extraction-with-past", + "text-generation", + "text-generation-with-past", + "text-classification", + ], + library_name="transformers", +) +class Gemma4TextOpenVINOConfig(Gemma3TextOpenVINOConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Gemma4DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = Gemma4DummyPastKeyValuesGenerator + MIN_TRANSFORMERS_VERSION = "5.5" + + def add_past_key_values(self, inputs_or_outputs: dict[str, dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + name = "past_key_values" + else: + decoder_sequence_name = "past_sequence_length + sequence_length" + name = "present" + + num_kv_shared_layers = self._normalized_config.config.num_kv_shared_layers + if num_kv_shared_layers > 0: + layer_types = self._normalized_config.config.layer_types[:-num_kv_shared_layers] + else: + layer_types = self._normalized_config.config.layer_types + + for i, layer_type in enumerate(layer_types): + inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch_size", 2: decoder_sequence_name} + + class DeciDummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): def __init__( self, @@ -1753,6 +1857,16 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): dummy_inputs["token_type_ids"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[ 0 ].random_int_tensor(token_type_ids_shape, min_value=0, max_value=2) + if "per_layer_inputs" in self.inputs: + per_layer_inputs_shape = ( + input_ids.shape[0], + input_ids.shape[1], + self._normalized_config.config.num_hidden_layers, + self._normalized_config.config.hidden_size_per_layer_input, + ) + 
dummy_inputs["per_layer_inputs"] = self.orig_export_config.DUMMY_INPUT_GENERATOR_CLASSES[ + 0 + ].random_float_tensor(per_layer_inputs_shape) return dummy_inputs @@ -4227,6 +4341,248 @@ def with_behavior( return super().with_behavior(behavior) +class Gemma4ConfigBehavior(str, enum.Enum): + VISION_EMBEDDINGS = "vision_embeddings" + TEXT_EMBEDDINGS = "text_embeddings" + LANGUAGE = "language" + TEXT_EMBEDDINGS_PER_LAYER = "text_embeddings_per_layer" + + +class DummyGemma4VisionInputGenerator(DummyVisionInputGenerator): + SUPPORTED_INPUT_NAMES = ("pixel_values", "image_position_ids") + + def __init__(self, task, normalized_config, batch_size=DEFAULT_DUMMY_SHAPES["batch_size"], **kwargs): + super().__init__(task, normalized_config, batch_size, **kwargs) + self.patch_size = getattr(normalized_config, "patch_size", 16) + self.pooling_kernel_size = getattr(normalized_config, "pooling_kernel_size", 3) + # Gemma4 processor always pads pixel_values to max_soft_tokens * pooling_kernel_size^2 patches. + # The vision model's pooling uses shape-dependent Python operations that get baked in during tracing, + # so the dummy input must match the actual inference shapes. + max_soft_tokens = getattr(normalized_config, "image_seq_length", None) + if max_soft_tokens is None: + max_soft_tokens = getattr(normalized_config, "max_soft_tokens", 280) + self.num_patches = max_soft_tokens * self.pooling_kernel_size**2 + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + if input_name == "pixel_values": + # Gemma4 expects pre-patchified pixel_values: [batch, num_patches, 3 * patch_size^2] + return self.random_float_tensor( + shape=[self.batch_size, self.num_patches, 3 * self.patch_size**2], + framework=framework, + dtype=float_dtype, + ) + if input_name == "image_position_ids": + # Create position ids as a grid. The patch count = h_patches * w_patches + # where both are divisible by pooling_kernel_size for correct pooling. 
+ k = self.pooling_kernel_size + total_pooled = self.num_patches // (k * k) + # Find roughly square grid for pooled side + pooled_side = int(math.sqrt(total_pooled)) + if pooled_side * pooled_side < total_pooled: + pooled_h = pooled_side + pooled_w = total_pooled // pooled_h + else: + pooled_h = pooled_w = pooled_side + h_patches = pooled_h * k + w_patches = pooled_w * k + pos_ids = torch.stack( + torch.meshgrid(torch.arange(h_patches), torch.arange(w_patches), indexing="ij"), dim=-1 + ).reshape(1, -1, 2) + # Pad to num_patches with -1 (padding position) + if pos_ids.shape[1] < self.num_patches: + pad = torch.full((1, self.num_patches - pos_ids.shape[1], 2), -1, dtype=pos_ids.dtype) + pos_ids = torch.cat([pos_ids, pad], dim=1) + return pos_ids.expand(self.batch_size, -1, -1).clone() + return super().generate(input_name, framework, int_dtype, float_dtype) + + +@register_in_tasks_manager("gemma4", *["image-text-to-text"], library_name="transformers") +class Gemma4OpenVINOConfig(Gemma3OpenVINOConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in Gemma4ConfigBehavior] + DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyTextInputGenerator) + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: Gemma4ConfigBehavior = Gemma4ConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + behavior=behavior, + ) + self._behavior = behavior + if self._behavior == Gemma4ConfigBehavior.VISION_EMBEDDINGS: + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyGemma4VisionInputGenerator,) + # Attach image_seq_length from preprocessor to normalized config so + # the dummy input generator can compute the correct number of patches. + # The vision model's pooling uses shape-dependent Python operations baked in + # during tracing, so the dummy input must match actual inference shapes. 
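+            # The nested image_processor is probed first (image_seq_length, then max_soft_tokens),
+            # falling back to the same attributes on the preprocessor itself.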
+ image_seq_length = None + if preprocessors is not None: + for p in preprocessors: + if hasattr(p, "image_processor") and hasattr(p.image_processor, "image_seq_length"): + image_seq_length = p.image_processor.image_seq_length + break + if hasattr(p, "image_processor") and hasattr(p.image_processor, "max_soft_tokens"): + image_seq_length = p.image_processor.max_soft_tokens + break + if image_seq_length is None: + for p in preprocessors: + if hasattr(p, "max_soft_tokens"): + image_seq_length = p.max_soft_tokens + break + if hasattr(p, "image_seq_length"): + image_seq_length = p.image_seq_length + break + if image_seq_length is not None: + self._normalized_config.image_seq_length = image_seq_length + elif self._behavior in ( + Gemma4ConfigBehavior.TEXT_EMBEDDINGS, + Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER, + ): + self.DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,) + self._config = config.text_config + self._normalized_config = NormalizedTextConfig(self._config) + + def with_behavior(self, behavior: Union[str, Gemma4ConfigBehavior]): + if isinstance(behavior, str) and not isinstance(behavior, Gemma4ConfigBehavior): + behavior = Gemma4ConfigBehavior(behavior) + + if behavior == Gemma4ConfigBehavior.LANGUAGE: + model_type = "gemma4_text" + inputs_update = { + "per_layer_inputs": {0: "batch_size", 1: "sequence_length", 2: "num_hidden_layers"}, + } + if getattr(self._orig_config.get_text_config(), "use_bidirectional_attention", None) == "vision": + inputs_update["token_type_ids"] = {0: "batch_size", 1: "sequence_length"} + return get_vlm_text_generation_config( + model_type, + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Gemma4LMModelPatcher, + inputs_update=inputs_update, + ) + if behavior == Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + config = self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + return config + return super().with_behavior(behavior) + + def get_model_for_behavior(self, model, behavior: Union[str, VLMConfigBehavior]): + if behavior == Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + import torch + + class PerLayerInputsModule(torch.nn.Module): + def __init__(self, language_model, vocab_size_per_layer_input: int, config): + super().__init__() + self.language_model = language_model + self.vocab_size_per_layer_input = vocab_size_per_layer_input + self.config = config + + def forward(self, input_ids: torch.Tensor): + # 26B-A4B has hidden_size_per_layer_input=0 (PLE disabled) + if self.language_model.config.hidden_size_per_layer_input <= 0: + return torch.zeros( + input_ids.shape[0], + input_ids.shape[1], + self.language_model.config.num_hidden_layers, + 0, + dtype=torch.float32, + ) + # Replace multimodal token IDs with pad_token_id to match + # HF Gemma4Model.forward which uses llm_input_ids where + # image/video/audio positions are set to pad_token_id + pad_token_id = self.config.text_config.pad_token_id + per_layer_inputs_tokens = input_ids.clone() + for token_id_attr in ("image_token_id", "video_token_id", "audio_token_id"): + token_id = getattr(self.config, token_id_attr, None) + if token_id is not None: + per_layer_inputs_tokens = torch.where( + per_layer_inputs_tokens == token_id, + torch.full_like(per_layer_inputs_tokens, pad_token_id), + per_layer_inputs_tokens, + ) + per_layer_inputs_mask = torch.logical_and( + per_layer_inputs_tokens >= 0, + per_layer_inputs_tokens < 
self.vocab_size_per_layer_input, + ) + per_layer_inputs_tokens = torch.where( + per_layer_inputs_mask, + per_layer_inputs_tokens, + torch.zeros_like(per_layer_inputs_tokens), + ) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens, None) + return per_layer_inputs + + model = PerLayerInputsModule( + model.model.language_model, model.config.text_config.vocab_size_per_layer_input, model.config + ) + return model + if behavior == VLMConfigBehavior.VISION_EMBEDDINGS: + return model + if behavior == VLMConfigBehavior.TEXT_EMBEDDINGS: + import torch + + class TextEmbeddingsModule(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, input_ids: torch.Tensor): + inputs_embeds = self.model.get_input_embeddings()(input_ids) + return inputs_embeds + + text_embedding = TextEmbeddingsModule(model) + text_embedding.config = model.model.language_model.config + return text_embedding + + return super().get_model_for_behavior(model, behavior) + + def patch_model_for_export(self, model, model_kwargs=None): + model_kwargs = model_kwargs or {} + if self._behavior == Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return ModelPatcher(self, model, model_kwargs) + if self._behavior == VLMConfigBehavior.VISION_EMBEDDINGS: + return Gemma4ImageEmbeddingsModelPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Gemma4ConfigBehavior.LANGUAGE: + return super().inputs + if self._behavior == Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + } + if self._behavior == Gemma4ConfigBehavior.VISION_EMBEDDINGS: + return { + "pixel_values": {0: "batch_size", 1: "num_patches"}, + "image_position_ids": {0: "batch_size", 1: "num_patches"}, + } + return super().inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == Gemma4ConfigBehavior.TEXT_EMBEDDINGS_PER_LAYER: + return {"text_embeds_per_layer": {}} + return super().outputs + + class DummyVisionPositionIdsInputGenerator(DummyVisionInputGenerator): SUPPORTED_INPUT_NAMES = ("patch_attention_mask", "patch_position_ids") diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index 63d8ffc9db..59e40f0e8a 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -52,7 +52,9 @@ ModelPatcher, gpt_oss_forward, override_arguments, - sdpa_mask_without_vmap, +) +from optimum.exporters.onnx.model_patcher import ( + sdpa_mask_without_vmap as _orig_sdpa_mask_without_vmap, ) from optimum.intel.utils.import_utils import ( is_diffusers_version, @@ -111,6 +113,28 @@ def _get_model_attribute(model, name): return getattr(target, name) +# Compatibility wrapper for sdpa_mask_without_vmap from optimum. +# The installed optimum version expects (batch_size, cache_position: Tensor, kv_length, ...), +# but transformers >= 5.5 passes (batch_size, q_length: int, kv_length: int, q_offset: int, ...). +def sdpa_mask_without_vmap(batch_size, q_length=None, kv_length=None, q_offset=0, kv_offset=0, **kwargs): + import inspect + + sig = inspect.signature(_orig_sdpa_mask_without_vmap) + if is_transformers_version(">=", "5.5") and "cache_position" in sig.parameters and q_length is not None: + # Old optimum signature: (batch_size, cache_position, kv_length, kv_offset, ...) 
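+        # Rebuild cache_position from the new-style arguments, e.g. q_length=3, q_offset=5 -> tensor([5, 6, 7]);
+        # kwargs introduced by transformers >= 5.5 that the old helper does not accept are dropped below.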
+ cache_position = torch.arange(q_length, dtype=torch.long) + q_offset + kwargs.pop("q_offset", None) + kwargs.pop("allow_is_bidirectional_skip", None) + kwargs.pop("allow_torch_fix", None) + kwargs.pop("use_vmap", None) + kwargs.pop("device", None) + return _orig_sdpa_mask_without_vmap(batch_size, cache_position, kv_length, kv_offset=kv_offset, **kwargs) + else: + return _orig_sdpa_mask_without_vmap( + batch_size, q_length=q_length, kv_length=kv_length, q_offset=q_offset, kv_offset=kv_offset, **kwargs + ) + + for idx, spec in enumerate(UNSUPPORTED_OPS_PATCHING_SPEC): if spec.name in { # onnx-exporter-specific fixes @@ -4900,6 +4924,360 @@ def __exit__(self, exc_type, exc_value, traceback): del self._model.model._orig_update_causual_mask +# Creates a dict of causal masks with bidirectional attention for vision tokens +# on sliding_attention layers, matching the behavior of transformers +# create_causal_mask_mapping when use_bidirectional_attention == "vision". +# Needs to be patched to pass proper 'sliding_mask' for prefill stage. +# Original code: https://github.com/huggingface/transformers/blob/v5.5.0/src/transformers/models/gemma4/modeling_gemma4.py#L1986 +def _create_gemma4_bidirectional_mask_dict(attention_mask_2d, mm_token_type_ids, inputs_embeds, sliding_window): + dtype = inputs_embeds.dtype + device = inputs_embeds.device + min_dtype = torch.finfo(dtype).min + + batch_size = inputs_embeds.shape[0] + seq_len = inputs_embeds.shape[1] + target_len = attention_mask_2d.shape[-1] + past_len = target_len - seq_len + + # Standard causal mask [seq_len, target_len] + causal_mask = torch.full((seq_len, target_len), min_dtype, dtype=dtype, device=device) + causal_mask = torch.triu(causal_mask, diagonal=past_len + 1) + + # Apply padding from attention_mask_2d + padding_mask = (1.0 - attention_mask_2d[:, None, None, :].to(dtype=dtype, device=device)) * min_dtype + full_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + padding_mask + mm_token_type_ids = torch.nn.functional.pad( + mm_token_type_ids, (0, target_len - mm_token_type_ids.shape[-1]), value=0 + ) + + # Sliding window causal mask + sliding_mask = full_mask.clone() + row_pos = torch.arange(seq_len, device=device).unsqueeze(1) + past_len + col_pos = torch.arange(target_len, device=device).unsqueeze(0) + beyond_window = (row_pos - col_pos) >= sliding_window + sliding_mask = sliding_mask.masked_fill(beyond_window[None, None, :, :], min_dtype) + + # Apply bidirectional masking for vision tokens (only on sliding_attention mask) + # mm_token_type_ids: [batch, total_len] - 0=text, 1=image, 2=video/audio + is_vision = (mm_token_type_ids == 1) | (mm_token_type_ids == 2) + + # Group contiguous vision tokens (trace-friendly, no in-place ops) + # Shift is_vision right by 1 position, padding with False on the left + is_prev_vision = torch.nn.functional.pad(is_vision[:, :-1].to(dtype=torch.int32), (1, 0), value=0).bool() + new_vision_starts = is_vision & ~is_prev_vision + vision_group_ids = torch.cumsum(new_vision_starts.to(dtype=torch.int32), dim=1) - 1 + vision_group_ids = torch.where(is_vision, vision_group_ids, torch.tensor(-1, dtype=torch.int32, device=device)) + + # Query group IDs correspond to positions [past_len : past_len + seq_len] + query_groups = vision_group_ids[:, past_len : past_len + seq_len] # [batch, seq_len] + key_groups = vision_group_ids # [batch, total_len] + + # same_group[b, q, k] = True iff query and key are in the same non-text vision group + same_group = (query_groups.unsqueeze(2) == 
key_groups.unsqueeze(1)) & (key_groups.unsqueeze(1) >= 0)
+    same_group = same_group.unsqueeze(1)  # [batch, 1, seq_len, total_len]
+
+    # Undo masking for same-group vision tokens in sliding mask
+    sliding_mask = sliding_mask.masked_fill(same_group, 0.0)
+
+    return {
+        "full_attention": full_mask,
+        "sliding_attention": sliding_mask,
+    }
+
+
+# Forward method of the Gemma4 language model. It needs to be patched to pass 'per_layer_inputs',
+# as the original code cannot create per_layer_inputs unless input_ids is provided,
+# while the OV language model receives only inputs_embeds, without input_ids.
+# Original code: https://github.com/huggingface/transformers/blob/v5.5.0/src/transformers/models/gemma4/modeling_gemma4.py#L2152
+def gemma4_language_model_forward(
+    self,
+    input_ids: Optional[torch.LongTensor] = None,
+    pixel_values: Optional[torch.FloatTensor] = None,
+    pixel_values_videos: Optional[torch.FloatTensor] = None,
+    input_features: Optional[torch.FloatTensor] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    input_features_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[Cache] = None,
+    mm_token_type_ids: Optional[torch.LongTensor] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    per_layer_inputs=None,
+    **lm_kwargs,
+):
+    from transformers.models.gemma4.modeling_gemma4 import Gemma4ModelOutputWithPast
+
+    if (input_ids is None) ^ (inputs_embeds is not None):
+        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+    )
+
+    # Merge text and images
+    if pixel_values is not None:
+        image_features = self.get_image_features(pixel_values)
+        if hasattr(image_features, "pooler_output"):
+            image_features = image_features.pooler_output
+        image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+        _, special_image_mask, _, _ = self.model.get_placeholder_mask(mm_token_type_ids, input_ids, inputs_embeds)
+        special_image_mask_expanded = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask_expanded, image_features)
+
+    # Create bidirectional causal mask mapping when use_bidirectional_attention == "vision"
+    use_bidirectional = getattr(self.config.get_text_config(), "use_bidirectional_attention", None) == "vision"
+    if use_bidirectional and mm_token_type_ids is not None:
+        attention_mask = _create_gemma4_bidirectional_mask_dict(
+            attention_mask,
+            mm_token_type_ids,
+            inputs_embeds,
+            self.model.language_model.config.sliding_window,
+        )
+
+    outputs = self.model.language_model(
+        input_ids=None,
+        per_layer_inputs=per_layer_inputs,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        cache_position=cache_position,
+        **lm_kwargs,
+    )
+
+    return Gemma4ModelOutputWithPast(
+        last_hidden_state=outputs.last_hidden_state,
+        past_key_values=outputs.past_key_values if 
use_cache else None, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features if pixel_values is not None else None, + ) + + +# Gemma4 model forward, needs to be patched to pass 'per_layer_inputs', +# Original code: https://github.com/huggingface/transformers/blob/v5.5.0/src/transformers/models/gemma4/modeling_gemma4.py#L2396 +def gemma4_lm_forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + per_layer_inputs=None, + token_type_ids: Optional[torch.LongTensor] = None, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + input_features: Optional[torch.FloatTensor] = None, + input_features_mask: Optional[torch.Tensor] = None, + cache_position: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **lm_kwargs, +): + from optimum.exporters.onnx.model_patcher import preprocess_past_key_values + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = False + + if past_key_values is not None: + use_cache = True + past_key_values = preprocess_past_key_values(past_key_values) + + outputs = self.model( + input_ids=input_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + input_features=input_features, + attention_mask=attention_mask, + input_features_mask=input_features_mask, + position_ids=position_ids, + past_key_values=past_key_values, + mm_token_type_ids=token_type_ids, + cache_position=cache_position, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + per_layer_inputs=per_layer_inputs, + **lm_kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + tmp_logits = self.lm_head(hidden_states[:, slice_indices, :]) + if (final_logit_softcapping := self.config.get_text_config().final_logit_softcapping) is not None: + tmp_logits = tmp_logits / final_logit_softcapping + tmp_logits = torch.tanh(tmp_logits) + tmp_logits = tmp_logits * final_logit_softcapping + + outputs_dict = { + "logits": tmp_logits, + } + + if use_cache: + key_values = outputs.past_key_values + present_key_values = postprocess_past_key_values(key_values) + outputs_dict["past_key_values"] = present_key_values + return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs_dict.values()]) + + +# Needs to be patched to reshape 'attention_mask' to match attention weights +# Original code: https://github.com/huggingface/transformers/blob/v5.5.0/src/transformers/models/gemma4/modeling_gemma4.py#L768 +def gemma4_eager_attention_forward_patched( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout: float = 0.0, + scaling: Optional[float] = None, + softcap: Optional[float] = 
None, + **kwargs, +) -> tuple: + if scaling is None: + scaling = module.head_dim**-0.5 + + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + + if softcap is not None: + attn_weights = attn_weights / softcap + attn_weights = torch.tanh(attn_weights) + attn_weights = attn_weights * softcap + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +# Needs to be patched to run methods 'gemma4_eager_attention_forward_patched' instead of original one +# Original code: https://github.com/huggingface/transformers/blob/v5.5.0/src/transformers/models/gemma4/modeling_gemma4.py#L1179 +def gemma4_text_attention_forward( + self, + hidden_states: torch.Tensor, + position_embeddings: torch.Tensor, + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> tuple: + from transformers.models.gemma4.modeling_gemma4 import apply_rotary_pos_emb as apply_rotary_pos_emb_gemma4 + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + cos, sin = position_embeddings + + query_states = self.q_proj(hidden_states).view(hidden_shape) + query_states = self.q_norm(query_states) + query_states = apply_rotary_pos_emb_gemma4(query_states, cos, sin, unsqueeze_dim=2) + query_states = query_states.transpose(1, 2) + + if self.is_kv_shared_layer and past_key_values is not None: + key_states, value_states = past_key_values.shared_layers[self.kv_shared_layer_index] + key_states = key_states.to(query_states.device) + value_states = value_states.to(query_states.device) + else: + key_states = self.k_proj(hidden_states).view(hidden_shape) + value_states = self.v_proj(hidden_states).view(hidden_shape) if self.v_proj is not None else key_states + + key_states = self.k_norm(key_states) + key_states = apply_rotary_pos_emb_gemma4(key_states, cos, sin, unsqueeze_dim=2) + key_states = key_states.transpose(1, 2) + + value_states = self.v_norm(value_states) + value_states = value_states.transpose(1, 2) + + if past_key_values is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "cache_position": cache_position, + "sliding_window": self.sliding_window, + } + if not self.is_kv_shared_layer: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + if self.store_full_length_kv: + if not hasattr(past_key_values, "shared_layers"): + past_key_values.shared_layers = {} + past_key_values.shared_layers[self.layer_idx] = key_states, value_states + + attention_interface = gemma4_eager_attention_forward_patched + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=self.attention_dropout if self.training else 0.0, + scaling=self.scaling, + sliding_window=self.sliding_window, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return 
attn_output, attn_weights + + +class Gemma4LMModelPatcher(Gemma3LMModelPatcher): + def __init__(self, config, model, model_kwargs): + super().__init__(config, model, model_kwargs) + + self.patched_forward = gemma4_lm_forward + self.model_orig_forward = self.orig_forward + self.orig_forward = gemma4_lm_forward + + self.model_orig_language_model_forward = self._model.model.forward + + def __enter__(self): + super().__enter__() + + setattr(self._model, self.orig_forward_name, types.MethodType(gemma4_lm_forward, self._model)) + setattr(self._model.model, "forward", types.MethodType(gemma4_language_model_forward, self._model)) + for decoder_layer in self._model.model.language_model.layers: + decoder_layer.self_attn.orig_forward = decoder_layer.self_attn.forward + decoder_layer.self_attn.forward = types.MethodType(gemma4_text_attention_forward, decoder_layer.self_attn) + if hasattr(decoder_layer, "experts"): + decoder_layer.experts._orig_forward = decoder_layer.experts.forward + decoder_layer.experts.forward = types.MethodType(lfm2_moe_experts_forward, decoder_layer.experts) + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + + for decoder_layer in self._model.model.language_model.layers: + decoder_layer.self_attn.forward = decoder_layer.self_attn.orig_forward + if hasattr(decoder_layer, "experts") and hasattr(decoder_layer.experts, "_orig_forward"): + decoder_layer.experts.forward = decoder_layer.experts._orig_forward + + setattr(self._model, self.orig_forward_name, self.model_orig_forward) + setattr(self._model.model, "forward", self.model_orig_language_model_forward) + + class Idefics3ImageEmbeddingsModelPatcher(ModelPatcher): def __init__( self, @@ -6610,7 +6988,10 @@ def __init__( model: "PreTrainedModel", model_kwargs: Optional[Dict[str, Any]] = None, ): - from transformers.models.mamba.modeling_mamba import MambaCache + try: + from transformers.models.mamba.modeling_mamba import MambaCache + except ImportError: + MambaCache = object super().__init__(config, model, model_kwargs) @@ -8498,6 +8879,117 @@ def __exit__(self, exc_type, exc_value, traceback): del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs +class Gemma4PerLayerInputsGetterModelPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel"], + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + def per_layer_inputs_forward(self, input_ids: torch.Tensor) -> torch.Tensor: + per_layer_inputs_mask = torch.logical_and(input_ids >= 0, input_ids < self.vocab_size_per_layer_input) + per_layer_inputs_tokens = torch.where(per_layer_inputs_mask, input_ids, torch.zeros_like(input_ids)) + per_layer_inputs = self.language_model.get_per_layer_inputs(per_layer_inputs_tokens, None) + return per_layer_inputs + + model.forward = types.MethodType(per_layer_inputs_forward, model) + super().__init__(config, model, model_kwargs) + + def __enter__(self): + super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + + +# OpenVINO has a bug due to which Clamp(-inf, inf) doesn't work correctly: CVS-185473. +# When min == -inf and max == inf, Clamp is equivalent to an identity operation and +# can be removed from the model, which serves as a workaround for the issue. 
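+# The patched forward below therefore applies only self.linear and skips the clamping entirely.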
+def patched_gemma4_clippable_linear_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.linear(hidden_states) + return hidden_states + + +class Gemma4ImageEmbeddingsModelPatcher(CommonImageEmbeddingsModelPatcher): + def __init__(self, config, model, model_kwargs): + super().__init__(config, model, model_kwargs) + from transformers.models.gemma4.modeling_gemma4 import Gemma4ClippableLinear + + # Get the vision encoder - it's at model.model.vision_tower.encoder + vision_model = model.model.vision_tower if is_transformers_version(">=", "5") else model.vision_tower + self._vision_encoder = vision_model.encoder + + # Patch the vision encoder forward to bypass create_bidirectional_mask, + # which is not compatible with torch.jit.trace due to dynamic masking logic. + # Instead, we construct a simple 4D bidirectional attention mask from the + # 2D padding mask to properly mask out padding patches. + orig_encoder_forward = self._vision_encoder.forward + + def patched_encoder_forward(inputs_embeds, attention_mask=None, pixel_position_ids=None, **kwargs): + hidden_states = inputs_embeds + position_embeddings = self._vision_encoder.rotary_emb(hidden_states, pixel_position_ids) + + # Build a 4D bidirectional attention mask from the 2D boolean mask. + # attention_mask is [batch, seq_len] with True=valid, False=padding. + # Decoder layers expect a 4D mask [batch, 1, seq_len, seq_len] where + # 0 = attend and large negative = masked. + attn_mask_4d = None + if attention_mask is not None: + min_dtype = torch.finfo(hidden_states.dtype).min + # [batch, 1, 1, seq_len] key mask + key_mask = attention_mask[:, None, None, :].to(hidden_states.dtype) + # Convert: 1.0 for valid tokens, min_dtype for padding + attn_mask_4d = (1.0 - key_mask) * min_dtype + + for decoder_layer in self._vision_encoder.layers[: self._vision_encoder.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=attn_mask_4d, + position_embeddings=position_embeddings, + position_ids=pixel_position_ids, + **kwargs, + ) + + from transformers.modeling_outputs import BaseModelOutputWithPast + + return BaseModelOutputWithPast(last_hidden_state=hidden_states) + + self._orig_encoder_forward = orig_encoder_forward + self._vision_encoder.forward = patched_encoder_forward + + for layer in self._vision_encoder.layers: + for module in layer.modules(): + if isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears: + if ( + module.input_min == -float("inf") + and module.input_max == float("inf") + and module.output_min == -float("inf") + and module.output_max == float("inf") + ): + module.orig_forward = module.forward + module.forward = types.MethodType(patched_gemma4_clippable_linear_forward, module) + + def __exit__(self, exc_type, exc_value, traceback): + from transformers.models.gemma4.modeling_gemma4 import Gemma4ClippableLinear + + self._vision_encoder.forward = self._orig_encoder_forward + super().__exit__(exc_type, exc_value, traceback) + + for layer in self._vision_encoder.layers: + for module in layer.modules(): + if isinstance(module, Gemma4ClippableLinear) and module.use_clipped_linears: + if ( + module.input_min == -float("inf") + and module.input_max == float("inf") + and module.output_min == -float("inf") + and module.output_max == float("inf") + ): + module.forward = module.orig_forward + + # Patches the MoE block with a vectorized implementation. # The vectorized form is required to ensure correct torch.jit tracing for this component. 
# Original implementation: https://github.com/huggingface/transformers/blob/v5.0.0/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py#L167 diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index 272384fb48..97de7176e2 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -297,6 +297,7 @@ def get_submodels(model): "qwen3_vl", "got_ocr2", "gemma3", + "gemma4", "idefics3", "smolvlm", "phi4mm", diff --git a/optimum/intel/openvino/configuration.py b/optimum/intel/openvino/configuration.py index f1a1044ebf..1531d6e4aa 100644 --- a/optimum/intel/openvino/configuration.py +++ b/optimum/intel/openvino/configuration.py @@ -436,6 +436,18 @@ class OVQuantizationMethod(str, Enum): "dataset": "contextual", "scale_estimation": True, }, + "google/gemma-4-26B-A4B-it": { + "bits": 4, + "sym": False, + "group_size": 64, + "group_size_fallback": "adjust", + }, + "google/gemma-4-26B-A4B": { + "bits": 4, + "sym": False, + "group_size": 64, + "group_size_fallback": "adjust", + }, } _DEFAULT_8BIT_WQ_CONFIGS = { @@ -567,6 +579,16 @@ class OVQuantizationMethod(str, Enum): ], }, }, + "google/gemma-4-26B-A4B-it": { + "lm_model": { + "patterns": [".*router.*"], + }, + }, + "google/gemma-4-26B-A4B": { + "lm_model": { + "patterns": [".*router.*"], + }, + }, } diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index 5530e135fe..a74582b2c1 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -31,15 +31,21 @@ from transformers.generation.stopping_criteria import StoppingCriteriaList from transformers.generation.utils import GenerateOutput, GenerationMode from transformers.modeling_outputs import CausalLMOutputWithPast, ModelOutput -from transformers.models.mamba.modeling_mamba import MambaCache from transformers.utils.hub import PushToHubMixin +from ..utils.import_utils import compare_versions, is_transformers_version + + +if is_transformers_version("<", "5.5"): + from transformers.models.mamba.modeling_mamba import MambaCache +else: + MambaCache = object + from optimum.utils.normalized_config import NormalizedConfigManager from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful from ...exporters.openvino.stateful import model_has_state from ...exporters.openvino.utils import SSM_MODELS -from ..utils.import_utils import compare_versions from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS from .configuration import ( OVConfig, diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index beb7b974eb..b376a58735 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -218,6 +218,13 @@ def prepare_inputs( inputs["beam_idx"] = ( self.next_beam_idx if self.next_beam_idx is not None else np.arange(batch_size, dtype=int) ) + + if "per_layer_inputs" in self.input_names: + per_layer_inputs = kwargs.pop("per_layer_inputs", None) + if per_layer_inputs is None: + raise ValueError("Expected 'per_layer_inputs', but it was not passed") + inputs["per_layer_inputs"] = torch.Tensor(per_layer_inputs) + return inputs def forward( @@ -347,6 +354,7 @@ def forward(self, audio_feature, audio_mask): MODEL_PARTS_CLS_MAPPING = { "resampler": OVResampler, "language_model": OVModelWithEmbedForCausalLM, + "text_embeddings_per_layer": OVVisionProjection, "vision_embeddings": OVVisionEmbedding, 
"vision_projection": OVVisionProjection, "vision_resampler": OVVisionResampler, @@ -785,6 +793,9 @@ def forward( additional_kwargs["visual_pos_masks"] = extra_outputs[0] additional_kwargs["deepstack_visual_embeds"] = extra_outputs[1] + if self.config.model_type in ("gemma4",) and extra_outputs: + additional_kwargs["per_layer_inputs"] = extra_outputs[0] + return self.language_model.forward( input_ids=None, inputs_embeds=inputs_embeds, @@ -3937,6 +3948,110 @@ def _update_model_kwargs_for_generation( return model_kwargs +class _OVGemma4ForCausalLM(_OVGemma3ForCausalLM): + additional_parts = ["text_embeddings_per_layer"] + + def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs): + if input_ids is not None and input_ids.shape[1] == 1: + return None + return self.vision_embeddings(pixel_values, **kwargs).last_hidden_state + + def get_multimodal_embeddings( + self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, **kwargs + ): + embeds_from_args = kwargs.pop("inputs_embeds", None) + inputs_embeds = ( + embeds_from_args if embeds_from_args is not None else self.get_text_embeddings(input_ids, **kwargs) + ) + per_layer_inputs = self.text_embeddings_per_layer(input_ids) + if pixel_values is not None: + vision_embeds = self.get_vision_embeddings(pixel_values, input_ids=input_ids, **kwargs) + + if vision_embeds is not None: + inputs_embeds, attention_mask, position_ids = self.merge_vision_text_embeddings( + vision_embeds, + inputs_embeds, + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + **kwargs, + ) + return inputs_embeds, attention_mask, position_ids, per_layer_inputs + + def merge_vision_text_embeddings( + self, vision_embeds, inputs_embeds, input_ids=None, attention_mask=None, position_ids=None, **kwargs + ): + image_features = torch.from_numpy(vision_embeds) if isinstance(vision_embeds, np.ndarray) else vision_embeds + inputs_embeds = torch.from_numpy(inputs_embeds) if isinstance(inputs_embeds, np.ndarray) else inputs_embeds + if input_ids is None: + special_image_mask = inputs_embeds == torch.from_numpy( + self.get_text_embeddings(torch.tensor([[self.config.image_token_id]], dtype=torch.long))[0] + ) + else: + special_image_mask = (input_ids == self.config.image_token_id).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(inputs_embeds) + + image_features = image_features.to(inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) + + return inputs_embeds, attention_mask, position_ids + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + image_sizes=None, + attention_mask=None, + mm_token_type_ids=None, + image_position_ids=None, + **kwargs, + ): + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + pixel_values=pixel_values, + image_sizes=image_sizes, + attention_mask=attention_mask, + **kwargs, + ) + # Map mm_token_type_ids to token_type_ids for the OV language model input + model_inputs["token_type_ids"] = mm_token_type_ids + model_inputs["image_position_ids"] = image_position_ids + return model_inputs + + def forward(self, input_ids, pixel_values=None, token_type_ids=None, **kwargs): + # Map mm_token_type_ids (from Gemma4 processor) to token_type_ids (OV language model input) + mm_token_type_ids = kwargs.pop("mm_token_type_ids", None) + if token_type_ids is None and mm_token_type_ids is not None: + 
token_type_ids = mm_token_type_ids + return super().forward( + input_ids=input_ids, + pixel_values=pixel_values, + token_type_ids=token_type_ids, + **kwargs, + ) + + def _update_model_kwargs_for_generation( + self, + outputs, + model_kwargs, + is_encoder_decoder=False, + num_new_tokens=1, + ): + model_kwargs = super()._update_model_kwargs_for_generation( + outputs=outputs, + model_kwargs=model_kwargs, + is_encoder_decoder=is_encoder_decoder, + num_new_tokens=num_new_tokens, + ) + model_kwargs.pop("mm_token_type_ids", None) + model_kwargs.pop("image_position_ids", None) + return model_kwargs + + class _OVGotOCR2ForCausalLM(OVModelForVisualCausalLM): def get_vision_embeddings(self, pixel_values, input_ids, **kwargs): if input_ids is not None and input_ids.shape[1] == 1 and kwargs.get("past_key_values") is not None: @@ -4817,6 +4932,7 @@ def preprocess_inputs( "qwen2_5_vl_text": _OVQwen2_5_VLForCausalLM, "got_ocr2": _OVGotOCR2ForCausalLM, "gemma3": _OVGemma3ForCausalLM, + "gemma4": _OVGemma4ForCausalLM, "idefics3": _OVIdefics3ForCausalLM, "smolvlm": _OVSmolVLForCasualLM, "phi4mm": _OVPhi4MMForCausalLM, diff --git a/pyproject.toml b/pyproject.toml index bc066641fd..a201ccf730 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,4 +36,5 @@ known-first-party = ["optimum"] [tool.pytest.ini_options] markers = [ "run_slow", + "gemma4: tests for gemma4 and gemma4_moe architectures (require transformers>=5.5)", ] \ No newline at end of file diff --git a/tests/openvino/conftest.py b/tests/openvino/conftest.py new file mode 100644 index 0000000000..bf929db0c5 --- /dev/null +++ b/tests/openvino/conftest.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.hookimpl(tryfirst=True) +def pytest_collection_modifyitems(config, items): + """Dynamically add the 'gemma4' marker to every parameterized test whose + name contains 'gemma4' (this also covers 'gemma4_moe').""" + gemma4_marker = pytest.mark.gemma4 + for item in items: + if "gemma4" in item.nodeid: + item.add_marker(gemma4_marker) diff --git a/tests/openvino/test_export.py b/tests/openvino/test_export.py index 72785e5a14..3f871d57d9 100644 --- a/tests/openvino/test_export.py +++ b/tests/openvino/test_export.py @@ -110,6 +110,10 @@ class ExportModelTest(unittest.TestCase): if is_transformers_version(">=", "4.55.0") and is_transformers_version("<", "4.58.0"): SUPPORTED_ARCHITECTURES.update({"afmoe": OVModelForCausalLM}) + if is_transformers_version(">=", "5.5.0"): + SUPPORTED_ARCHITECTURES.update({"gemma4": OVModelForVisualCausalLM}) + SUPPORTED_ARCHITECTURES.update({"gemma4_moe": OVModelForVisualCausalLM}) + if is_transformers_version(">=", "4.57.0"): SUPPORTED_ARCHITECTURES.update({"hunyuan_v1_dense": OVModelForCausalLM}) diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 63c13de956..cdda7f00f0 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -1089,6 +1089,10 @@ class OVWeightCompressionTest(unittest.TestCase): ] ) + if is_transformers_version(">=", "5.5.0"): + SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION.append((OVModelForVisualCausalLM, "gemma4", True)) + SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION.append((OVModelForVisualCausalLM, "gemma4_moe", True)) + SUPPORTED_ARCHITECTURES_WITH_HYBRID_QUANTIZATION = [ (OVStableDiffusionPipeline, "stable-diffusion", 72, 195), (OVStableDiffusionXLPipeline, "stable-diffusion-xl", 84, 331), @@ -1324,7 +1328,10 @@ def test_ovmodel_8bit_weight_compression_stateful(self, model_cls, model_name, e 
self.assertEqual(OVWeightQuantizationConfig().to_dict(), loaded_config.quantization_config.to_dict()) self.assertFalse(model.model.has_rt_info(["runtime_options", "KV_CACHE_PRECISION"])) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + @parameterized.expand( + SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION, + name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{parameterized.to_safe_name(params.args[1])}", + ) def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust_remote_code): model = model_cls.from_pretrained( MODEL_NAMES[model_type], @@ -1546,7 +1553,10 @@ def test_ovmodel_stateful_load_with_compressed_weights(self, model_cls, model_ty expected_int8 = {k: {"int8": v} for k, v in expected_int8.items()} check_compression_state_per_model(self, model.ov_models, expected_int8) - @parameterized.expand(SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION) + @parameterized.expand( + SUPPORTED_ARCHITECTURES_WITH_AUTO_COMPRESSION, + name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{parameterized.to_safe_name(params.args[1])}", + ) def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type, trust_remote_code): model = model_cls.from_pretrained( MODEL_NAMES[model_type], export=True, load_in_8bit=False, trust_remote_code=trust_remote_code diff --git a/tests/openvino/test_seq2seq.py b/tests/openvino/test_seq2seq.py index 4a4affeec7..37efa86b2d 100644 --- a/tests/openvino/test_seq2seq.py +++ b/tests/openvino/test_seq2seq.py @@ -610,6 +610,9 @@ class OVModelForVisualCausalLMIntegrationTest(OVSeq2SeqTestMixin): # remote code models incompatible after transformers v5 SUPPORTED_ARCHITECTURES += ["internvl_chat", "minicpmv"] + if is_transformers_version(">=", "5.5"): + SUPPORTED_ARCHITECTURES += ["gemma4", "gemma4_moe"] + # TODO: add fix for v5 and update MAX_TRANSFORMERS_VERSION accordingly if is_transformers_version("<", "5"): SUPPORTED_ARCHITECTURES += ("llava_next_video",) @@ -796,6 +799,7 @@ def compare_outputs(inputs, ov_model, transformers_model, generation_config): ov_model.generation_config.do_sample = False # minicpmo diverges after 20 tokens tokens_to_generate = 20 if model_arch == "minicpmo" else 30 + gen_config = GenerationConfig( max_new_tokens=tokens_to_generate, min_new_tokens=tokens_to_generate, diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index c8c1fed661..3ee3a2035b 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -87,6 +87,8 @@ "got_ocr2": "optimum-intel-internal-testing/tiny-random-got-ocr2-hf", "gemma3_text": "optimum-intel-internal-testing/tiny-random-gemma3-text", "gemma3": "optimum-intel-internal-testing/tiny-random-gemma3", + "gemma4": "optimum-intel-internal-testing/tiny-random-gemma4", + "gemma4_moe": "optimum-intel-internal-testing/tiny-random-gemma4-moe", "falcon": "optimum-intel-internal-testing/really-tiny-falcon-testing", "falcon-40b": "optimum-intel-internal-testing/tiny-random-falcon-40b", "falcon_mamba": "optimum-intel-internal-testing/tiny-falcon-mamba", @@ -374,6 +376,18 @@ "hunyuan_v1_dense": {"model": 32}, "qwen3_eagle3": {"model": 20}, "qwen3_next": {"model": 100}, + "gemma4": { + "lm_model": 54, + "text_embeddings_model": 1, + "vision_embeddings_model": 10, + "text_embeddings_per_layer_model": 1, + }, + "gemma4_moe": { + "lm_model": 48, + "text_embeddings_model": 1, + "vision_embeddings_model": 10, + "text_embeddings_per_layer_model": 0, + }, } TEST_IMAGE_URL = 
"http://images.cocodataset.org/val2017/000000039769.jpg"