|
| 1 | +# Copyright 2025 HuggingFace Inc. team. All rights reserved. |
| 2 | +# Licensed under the Apache License, Version 2.0. |
| 3 | +# |
| 4 | +# Adapted from transformers/models/mistral3/modeling_mistral3.py |
| 5 | +# |
| 6 | +# REFERENCE ONLY: This module is NOT used by optimize.py (which uses mobius |
| 7 | +# for vision/embedding export). It is kept as a reference implementation |
| 8 | +# showing how to build an ONNX-export-friendly Ministral3 vision + embedding |
| 9 | +# model for potential future Olive-based export. |
| 10 | + |
| 11 | +from typing import Optional |
| 12 | + |
| 13 | +import torch |
| 14 | +import torch.nn as nn |
| 15 | + |
| 16 | +from transformers import AutoModel |
| 17 | +from transformers.models.mistral3.configuration_mistral3 import Mistral3Config |
| 18 | + |
| 19 | + |
| 20 | +class Mistral3PatchMerger(nn.Module): |
| 21 | + """ONNX-export-friendly Mistral3PatchMerger. |
| 22 | +
|
| 23 | + Uses pure tensor operations during export instead of Python for-loops. |
| 24 | + """ |
| 25 | + |
| 26 | + def __init__(self, config): |
| 27 | + super().__init__() |
| 28 | + self.config = config |
| 29 | + hidden_size = config.vision_config.hidden_size |
| 30 | + self.spatial_merge_size = config.spatial_merge_size |
| 31 | + self.patch_size = config.vision_config.patch_size |
| 32 | + self.merging_layer = nn.Linear( |
| 33 | + hidden_size * self.spatial_merge_size**2, hidden_size, bias=False |
| 34 | + ) |
| 35 | + |
| 36 | + def forward( |
| 37 | + self, image_features: torch.Tensor, image_sizes: torch.Tensor |
| 38 | + ) -> torch.Tensor: |
| 39 | + if torch.compiler.is_exporting(): |
| 40 | + return self._forward_export(image_features, image_sizes) |
| 41 | + return self._forward_eager(image_features, image_sizes) |
| 42 | + |
| 43 | + def _forward_export(self, image_features, image_sizes): |
| 44 | + patch_h = image_sizes[0, 0] // self.patch_size |
| 45 | + patch_w = image_sizes[0, 1] // self.patch_size |
| 46 | + d = image_features.shape[-1] |
| 47 | + |
| 48 | + image_grid = ( |
| 49 | + image_features.view(patch_h, patch_w, d).permute(2, 0, 1).unsqueeze(0) |
| 50 | + ) |
| 51 | + |
| 52 | + torch._check(image_grid.shape[2] != 0) |
| 53 | + torch._check(image_grid.shape[3] != 0) |
| 54 | + torch._check(image_grid.shape[2] // self.spatial_merge_size > 0) |
| 55 | + torch._check(image_grid.shape[3] // self.spatial_merge_size > 0) |
| 56 | + |
| 57 | + grid = torch.nn.functional.unfold( |
| 58 | + image_grid, |
| 59 | + kernel_size=self.spatial_merge_size, |
| 60 | + stride=self.spatial_merge_size, |
| 61 | + ) |
| 62 | + image_features = grid.view(d * self.spatial_merge_size**2, -1).t() |
| 63 | + return self.merging_layer(image_features) |
| 64 | + |
| 65 | + def _forward_eager(self, image_features, image_sizes): |
| 66 | + image_sizes_list = [ |
| 67 | + (sz[0] // self.patch_size, sz[1] // self.patch_size) for sz in image_sizes |
| 68 | + ] |
| 69 | + tokens_per_image = [h * w for h, w in image_sizes_list] |
| 70 | + d = image_features.shape[-1] |
| 71 | + |
| 72 | + permuted = [] |
| 73 | + for idx, image_tokens in enumerate(image_features.split(tokens_per_image)): |
| 74 | + h, w = image_sizes_list[idx] |
| 75 | + image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0) |
| 76 | + grid = torch.nn.functional.unfold( |
| 77 | + image_grid, |
| 78 | + kernel_size=self.spatial_merge_size, |
| 79 | + stride=self.spatial_merge_size, |
| 80 | + ) |
| 81 | + permuted.append(grid.view(d * self.spatial_merge_size**2, -1).t()) |
| 82 | + |
| 83 | + return self.merging_layer(torch.cat(permuted, dim=0)) |
| 84 | + |
| 85 | + |
| 86 | +def pixtral_vision_forward_export(self, pixel_values, **kwargs): |
| 87 | + """ONNX-export-friendly forward for PixtralVisionModel (batch=1). |
| 88 | +
|
| 89 | + Skips generate_block_attention_mask and computes position_ids inline. |
| 90 | + """ |
| 91 | + torch._check(pixel_values.shape[0] == 1) |
| 92 | + |
| 93 | + target_dtype = self.patch_conv.weight.dtype |
| 94 | + patch_embeds = self.patch_conv(pixel_values.to(dtype=target_dtype)) |
| 95 | + |
| 96 | + grid_h = patch_embeds.shape[2] |
| 97 | + grid_w = patch_embeds.shape[3] |
| 98 | + |
| 99 | + patch_embeds = patch_embeds[0].flatten(1).T.unsqueeze(0) |
| 100 | + patch_embeds = self.ln_pre(patch_embeds) |
| 101 | + |
| 102 | + max_width = self.config.image_size // self.config.patch_size |
| 103 | + h_indices = torch.arange(grid_h, device=pixel_values.device) |
| 104 | + w_indices = torch.arange(grid_w, device=pixel_values.device) |
| 105 | + mesh_h, mesh_w = torch.meshgrid(h_indices, w_indices, indexing="ij") |
| 106 | + position_ids = (mesh_h * max_width + mesh_w).reshape(-1) |
| 107 | + kwargs["position_ids"] = position_ids.unsqueeze(0) |
| 108 | + |
| 109 | + position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids) |
| 110 | + |
| 111 | + return self.transformer( |
| 112 | + patch_embeds, |
| 113 | + attention_mask=None, |
| 114 | + position_embeddings=position_embeddings, |
| 115 | + **kwargs, |
| 116 | + ) |
| 117 | + |
| 118 | + |
| 119 | +def _pixtral_vision_forward_dispatch(self, pixel_values, **kwargs): |
| 120 | + if torch.compiler.is_exporting(): |
| 121 | + return pixtral_vision_forward_export(self, pixel_values, **kwargs) |
| 122 | + return self._original_forward(pixel_values, **kwargs) |
| 123 | + |
| 124 | + |
def patch_model_for_onnx_export(model):
    """Apply ONNX-export-friendly patches to a Mistral 3 model, in place.

    Two things are patched:

    * the projector's ``patch_merger`` has its class swapped to
      ``Mistral3PatchMerger``;
    * the vision tower's ``forward`` is replaced by a dispatcher, and the
      previous ``forward`` is kept on the instance as ``_original_forward``.

    The function is idempotent: calling it again on an already patched model
    leaves the saved ``_original_forward`` untouched. Without that guard a
    second call would store the dispatcher as the "original" forward and the
    non-export path would recurse forever.

    Args:
        model: Either an object exposing ``multi_modal_projector`` and
            ``vision_tower`` directly, or a wrapper exposing them under
            ``model.model``.

    Returns:
        The same ``model`` object, patched.

    Raises:
        ValueError: If no ``multi_modal_projector`` can be located.
    """
    import types

    # Prefer the nested ``model.model`` layout, then fall back to the object itself.
    holder = None
    for candidate in (getattr(model, "model", None), model):
        if candidate is not None and hasattr(candidate, "multi_modal_projector"):
            holder = candidate
            break
    if holder is None:
        raise ValueError("Cannot find multi_modal_projector.patch_merger on the model.")

    patch_merger = holder.multi_modal_projector.patch_merger
    vision_tower = holder.vision_tower

    patch_merger.__class__ = Mistral3PatchMerger

    # Only wrap once; re-wrapping would make the dispatcher call itself.
    current_forward = vision_tower.forward
    already_patched = (
        getattr(current_forward, "__func__", None) is _pixtral_vision_forward_dispatch
    )
    if not already_patched:
        vision_tower._original_forward = current_forward
        vision_tower.forward = types.MethodType(
            _pixtral_vision_forward_dispatch, vision_tower
        )

    return model
| 146 | + |
| 147 | + |
class Ministral3Model(nn.Module):
    """Ministral3 composite model for vision + embedding ONNX export.

    Builds a Hugging Face model from ``config`` via ``AutoModel.from_config``,
    patches it with ``patch_model_for_onnx_export`` and exposes two entry
    points meant to be exported separately:

    - ``get_image_features()``: vision encoder export
    - ``get_fused_input_embeddings()``: embedding fusion export

    ``forward`` is intentionally not implemented.
    """

    def __init__(self, config: Mistral3Config):
        super().__init__()
        self.config = config

        # Build the full HF model, then patch for export
        self.hf_model = AutoModel.from_config(
            config, attn_implementation="sdpa", trust_remote_code=True
        )
        patch_model_for_onnx_export(self.hf_model)

        # Expose sub-components for weight loading. These are the same module
        # objects as the ones held by ``hf_model``, not copies.
        self.vision_tower = self.hf_model.vision_tower
        self.multi_modal_projector = self.hf_model.multi_modal_projector
        self.embed_tokens = self.hf_model.language_model.embed_tokens

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Vision encoder: pixel_values -> image_features.

        The image size handed to the projector is taken from the last two
        dimensions of ``pixel_values``, so a single image size is assumed.
        """
        image_outputs = self.vision_tower(pixel_values, return_dict=True)
        selected_image_feature = image_outputs.last_hidden_state

        height, width = pixel_values.shape[-2], pixel_values.shape[-1]
        image_sizes = torch.tensor(
            [[height, width]],
            dtype=torch.int64,
            device=pixel_values.device,
        )
        # The projector receives the features with their first axis squeezed.
        return self.multi_modal_projector(
            selected_image_feature.squeeze(0), image_sizes
        )

    def get_fused_input_embeddings(
        self, input_ids: torch.LongTensor, image_features: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embedding fusion: input_ids + image_features -> inputs_embeds.

        Positions where ``input_ids`` equals ``config.image_token_index`` are
        overwritten, in order, with values from ``image_features``. When
        ``image_features`` is ``None`` the plain token embeddings are returned.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        if image_features is None:
            return inputs_embeds

        # masked_scatter needs the source on the same device and dtype as the
        # destination, so move both at once instead of casting the dtype only.
        image_features = image_features.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        special_image_mask = input_ids == self.config.image_token_index
        expanded_mask = (
            special_image_mask.unsqueeze(-1)
            .expand_as(inputs_embeds)
            .to(inputs_embeds.device)
        )
        return inputs_embeds.masked_scatter(expanded_mask, image_features)

    def forward(self, *args, **kwargs):
        """Always raise; callers are expected to swap in one of the two entry points."""
        raise NotImplementedError(
            "Use get_image_features() or get_fused_input_embeddings() via method swap."
        )
| 209 | + |
| 210 | + |
# Names exported by ``from <module> import *``.
__all__ = [
    "Ministral3Model",
    "patch_model_for_onnx_export",
]
0 commit comments