Commit 5969770

titaiwangms and Copilot committed
Add Ministral-3-3B VLM recipe: hybrid Olive + Mobius export
- Olive for text decoder, Mobius for vision + embedding
- Lazy config loading via __getattr__ (PEP 562)
- Fail-fast validation for missing components
- Transforms-based processor_config.json
- image_token_id=10 from HF config
- Pin mobius dependency
- Add eval.py (AI2D benchmark, follows Qwen VLM pattern)
- Fix: dtype string handling, model save paths, genai_config filenames
- Fix: conditional position_ids, vision output squeeze, embedding zero-padding

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 88954d7 commit 5969770
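
The commit message mentions lazy config loading via a module-level `__getattr__` (PEP 562). The module that does this is not visible in this view, so the following is only a minimal sketch of the general technique. `make_lazy_module`, `load_config`, and the `CONFIG` attribute are hypothetical names chosen for illustration.

```python
import json
import types


def make_lazy_module(name, loader):
    """Return a module whose CONFIG attribute is built on first access.

    Hypothetical helper: a real recipe would define __getattr__ at the top
    level of a .py file. A module object is built here only to keep the
    sketch self-contained.
    """
    module = types.ModuleType(name)
    cache = {}

    def __getattr__(attr):
        # PEP 562: called only when normal module attribute lookup fails.
        if attr == "CONFIG":
            if "CONFIG" not in cache:
                cache["CONFIG"] = loader()
            return cache["CONFIG"]
        raise AttributeError(f"module {name!r} has no attribute {attr!r}")

    module.__getattr__ = __getattr__
    return module


load_calls = []


def load_config():
    """Stand-in for an expensive load, such as fetching the HF config."""
    load_calls.append("loaded")
    return json.loads('{"image_token_id": 10}')


lazy = make_lazy_module("lazy_config_demo", load_config)
```

Nothing is loaded when the module is created. The first read of `lazy.CONFIG` triggers `load_config()`, and later reads reuse the cached value.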

12 files changed

Lines changed: 1423 additions & 0 deletions

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Generated model artifacts
models/

# Python bytecode
__pycache__/
*.pyc

# Olive cache
.olive-cache/
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# Ministral-3-3B ONNX Runtime GenAI Example

This example demonstrates how to convert the [Ministral-3-3B-Instruct-2512](https://huggingface.co/mistralai/Ministral-3-3B-Instruct-2512) vision-language model to ONNX format using Olive and run inference with ONNX Runtime GenAI.

Ministral-3-3B is a vision-language model (VLM) that combines a Pixtral vision encoder with a Mistral text decoder, which uses YaRN RoPE for extended context. The pipeline exports three sub-models:
- **Vision encoder** and **embedding** via [mobius](https://github.com/onnxruntime/mobius) (declarative ONNX graph construction)
- **Text decoder** via Olive/ModelBuilder (GQA attention, INT4 or FP16 precision)

## Prerequisites

```bash
pip install -r requirements.txt
```

Install ONNX Runtime GenAI:

| Device | Install Command |
|--------|-----------------|
| CPU | `pip install onnxruntime-genai --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple` |
| GPU (CUDA) | `pip install onnxruntime-genai-cuda --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple` |

## Steps

### 1. Export & Optimize Models

**CPU (INT4 text decoder, FP16 vision/embedding):**

```bash
python optimize.py --config-dir cpu_and_mobile --device cpu
```

**CUDA (FP16):**

```bash
python optimize.py --config-dir cuda --device gpu
```

**With a local dequantized checkpoint (skips FP8 dequantization):**

```bash
python optimize.py --config-dir cpu_and_mobile --device cpu --model-path /path/to/Ministral-3-3B-dequantized
```

This runs:
- **Olive/ModelBuilder** for the text decoder (GQA attention, YaRN RoPE, INT4/FP16)
- **Mobius** for the vision encoder (Pixtral, dynamic H×W, 2D RoPE) and the embedding model (token + image fusion)

The script then generates `genai_config.json` and `processor_config.json` for the ONNX Runtime GenAI runtime.
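
The "token + image fusion" done by the embedding model can be pictured as replacing the embedding rows at image-placeholder positions with vision features, in order. The following is a dependency-free sketch of that idea, not the mobius graph itself. `fuse_embeddings` is a hypothetical helper, the placeholder id 10 comes from the commit message (`image_token_id=10`), and the count check is an added assumption.

```python
IMAGE_TOKEN_ID = 10  # placeholder id, per the commit message


def fuse_embeddings(input_ids, token_embeds, image_features,
                    image_token_id=IMAGE_TOKEN_ID):
    """Swap in image feature rows wherever input_ids holds the placeholder.

    input_ids:      list of token ids, length T
    token_embeds:   list of T embedding rows (one per token)
    image_features: list of rows, one per placeholder token, in order
    """
    n_slots = sum(1 for token in input_ids if token == image_token_id)
    if n_slots != len(image_features):
        # Assumption: treat a count mismatch as an error rather than pad.
        raise ValueError(
            f"{n_slots} image tokens but {len(image_features)} feature rows"
        )
    features = iter(image_features)
    return [
        next(features) if token == image_token_id else embed
        for token, embed in zip(input_ids, token_embeds)
    ]


fused = fuse_embeddings(
    input_ids=[1, 10, 10, 7],
    token_embeds=[[0.1], [0.0], [0.0], [0.7]],
    image_features=[[5.0], [6.0]],
)
```

Text positions keep their token embeddings, while the two placeholder positions receive the two image feature rows in sequence.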

### 2. Output Structure

```
cpu_and_mobile/models/        # or cuda/models/
├── vision.onnx               # Pixtral vision encoder
├── vision.onnx.data
├── embedding.onnx            # Embedding fusion model
├── embedding.onnx.data
├── text.onnx                 # Text decoder (Mistral + YaRN)
├── text.onnx.data
├── genai_config.json         # Runtime configuration
├── processor_config.json     # Pixtral image preprocessing
├── tokenizer.json
└── tokenizer_config.json
```
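
Before running inference, it can help to confirm that every file in the layout above exists. The following is a minimal sketch of such a check. `find_missing_artifacts` is a hypothetical helper, not part of `optimize.py`, and the file list is copied from the tree above.

```python
from pathlib import Path

# File names taken from the output structure listed in this README.
EXPECTED_FILES = [
    "vision.onnx",
    "vision.onnx.data",
    "embedding.onnx",
    "embedding.onnx.data",
    "text.onnx",
    "text.onnx.data",
    "genai_config.json",
    "processor_config.json",
    "tokenizer.json",
    "tokenizer_config.json",
]


def find_missing_artifacts(model_dir):
    """Return the expected file names that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = find_missing_artifacts("cpu_and_mobile/models")
    if missing:
        raise SystemExit(f"Missing model artifacts: {', '.join(missing)}")
    print("All expected artifacts are present.")
```

Failing early with the list of missing names is usually easier to diagnose than a runtime error raised while the model is loading.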

### 3. Run Inference

```bash
# Text-only
python inference.py --prompt "What is the capital of France?"

# Image + text
python inference.py --image photo.jpg --prompt "Describe this image"

# Interactive mode
python inference.py --interactive

# CUDA model
python inference.py --model_path cuda/models --prompt "Hello"
```

Alternatively, use the built-in GenAI multimodal demo:

```bash
python -m onnxruntime_genai.models.model_mm -m cpu_and_mobile/models --max_length 4096
```

## Notes

- The Hugging Face checkpoint stores FP8-quantized weights. The export pipeline dequantizes them automatically (`weight * weight_scale_inv`).
- The tokenizer config declares the `TokenizersBackend` class, which ONNX Runtime GenAI does not support. The optimize script rewrites it to `LlamaTokenizer`.
- The Pixtral vision encoder supports dynamic image sizes (multiples of 28, up to 1540×1540).
- The text decoder includes `llama_4_attn_scale` for long-context attention (>16K tokens).
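
Two of the notes above reduce to simple arithmetic. The following sketch illustrates them in plain Python under stated assumptions: the dequantization uses a single per-tensor scale (the real checkpoint may use a different scale layout), and the size check applies the "multiple of 28, at most 1540" rule to height and width independently. Both function names are hypothetical.

```python
def dequantize(weight, weight_scale_inv):
    """Apply weight * weight_scale_inv element-wise to a 2D list.

    Assumption: one scalar scale for the whole tensor.
    """
    return [[value * weight_scale_inv for value in row] for row in weight]


def is_supported_image_size(height, width, multiple=28, max_side=1540):
    """Check the size rule from the notes: multiples of 28, up to 1540."""
    return all(
        0 < side <= max_side and side % multiple == 0
        for side in (height, width)
    )


restored = dequantize([[1.0, -2.0], [0.5, 4.0]], weight_scale_inv=0.25)
ok = is_supported_image_size(1540, 28)
too_big = is_supported_image_size(1568, 28)
unaligned = is_supported_image_size(100, 28)
```

With these inputs, `restored` holds the scaled values, `ok` is true, and the oversized and unaligned cases are rejected.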
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Reference-only: Ministral3Model is not used by optimize.py (see modeling_ministral3.py)
from .modeling_ministral3 import Ministral3Model as Ministral3Model
Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
# Copyright 2025 HuggingFace Inc. team. All rights reserved.
# Licensed under the Apache License, Version 2.0.
#
# Adapted from transformers/models/mistral3/modeling_mistral3.py
#
# REFERENCE ONLY: This module is NOT used by optimize.py (which uses mobius
# for vision/embedding export). It is kept as a reference implementation
# showing how to build an ONNX-export-friendly Ministral3 vision + embedding
# model for potential future Olive-based export.

from typing import Optional

import torch
import torch.nn as nn

from transformers import AutoModel
from transformers.models.mistral3.configuration_mistral3 import Mistral3Config


class Mistral3PatchMerger(nn.Module):
    """ONNX-export-friendly Mistral3PatchMerger.

    Uses pure tensor operations during export instead of Python for-loops.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.vision_config.hidden_size
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.vision_config.patch_size
        self.merging_layer = nn.Linear(
            hidden_size * self.spatial_merge_size**2, hidden_size, bias=False
        )

    def forward(
        self, image_features: torch.Tensor, image_sizes: torch.Tensor
    ) -> torch.Tensor:
        if torch.compiler.is_exporting():
            return self._forward_export(image_features, image_sizes)
        return self._forward_eager(image_features, image_sizes)

    def _forward_export(self, image_features, image_sizes):
        patch_h = image_sizes[0, 0] // self.patch_size
        patch_w = image_sizes[0, 1] // self.patch_size
        d = image_features.shape[-1]

        image_grid = (
            image_features.view(patch_h, patch_w, d).permute(2, 0, 1).unsqueeze(0)
        )

        torch._check(image_grid.shape[2] != 0)
        torch._check(image_grid.shape[3] != 0)
        torch._check(image_grid.shape[2] // self.spatial_merge_size > 0)
        torch._check(image_grid.shape[3] // self.spatial_merge_size > 0)

        grid = torch.nn.functional.unfold(
            image_grid,
            kernel_size=self.spatial_merge_size,
            stride=self.spatial_merge_size,
        )
        image_features = grid.view(d * self.spatial_merge_size**2, -1).t()
        return self.merging_layer(image_features)

    def _forward_eager(self, image_features, image_sizes):
        image_sizes_list = [
            (sz[0] // self.patch_size, sz[1] // self.patch_size) for sz in image_sizes
        ]
        tokens_per_image = [h * w for h, w in image_sizes_list]
        d = image_features.shape[-1]

        permuted = []
        for idx, image_tokens in enumerate(image_features.split(tokens_per_image)):
            h, w = image_sizes_list[idx]
            image_grid = image_tokens.view(h, w, d).permute(2, 0, 1).unsqueeze(0)
            grid = torch.nn.functional.unfold(
                image_grid,
                kernel_size=self.spatial_merge_size,
                stride=self.spatial_merge_size,
            )
            permuted.append(grid.view(d * self.spatial_merge_size**2, -1).t())

        return self.merging_layer(torch.cat(permuted, dim=0))


def pixtral_vision_forward_export(self, pixel_values, **kwargs):
    """ONNX-export-friendly forward for PixtralVisionModel (batch=1).

    Skips generate_block_attention_mask and computes position_ids inline.
    """
    torch._check(pixel_values.shape[0] == 1)

    target_dtype = self.patch_conv.weight.dtype
    patch_embeds = self.patch_conv(pixel_values.to(dtype=target_dtype))

    grid_h = patch_embeds.shape[2]
    grid_w = patch_embeds.shape[3]

    patch_embeds = patch_embeds[0].flatten(1).T.unsqueeze(0)
    patch_embeds = self.ln_pre(patch_embeds)

    max_width = self.config.image_size // self.config.patch_size
    h_indices = torch.arange(grid_h, device=pixel_values.device)
    w_indices = torch.arange(grid_w, device=pixel_values.device)
    mesh_h, mesh_w = torch.meshgrid(h_indices, w_indices, indexing="ij")
    position_ids = (mesh_h * max_width + mesh_w).reshape(-1)
    kwargs["position_ids"] = position_ids.unsqueeze(0)

    position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

    return self.transformer(
        patch_embeds,
        attention_mask=None,
        position_embeddings=position_embeddings,
        **kwargs,
    )


def _pixtral_vision_forward_dispatch(self, pixel_values, **kwargs):
    if torch.compiler.is_exporting():
        return pixtral_vision_forward_export(self, pixel_values, **kwargs)
    return self._original_forward(pixel_values, **kwargs)


def patch_model_for_onnx_export(model):
    """Apply ONNX-export-friendly patches to a Mistral 3 model."""
    import types

    if hasattr(model, "model") and hasattr(model.model, "multi_modal_projector"):
        patch_merger = model.model.multi_modal_projector.patch_merger
        vision_tower = model.model.vision_tower
    elif hasattr(model, "multi_modal_projector"):
        patch_merger = model.multi_modal_projector.patch_merger
        vision_tower = model.vision_tower
    else:
        raise ValueError("Cannot find multi_modal_projector.patch_merger on the model.")

    patch_merger.__class__ = Mistral3PatchMerger

    vision_tower._original_forward = vision_tower.forward
    vision_tower.forward = types.MethodType(
        _pixtral_vision_forward_dispatch, vision_tower
    )

    return model


class Ministral3Model(nn.Module):
    """Ministral3 composite model for vision + embedding ONNX export.

    Wraps HF Mistral3Model and provides:
    - get_image_features(): vision encoder export
    - get_fused_input_embeddings(): embedding fusion export
    """

    def __init__(self, config: Mistral3Config):
        super().__init__()
        self.config = config

        # Build the full HF model, then patch for export
        self.hf_model = AutoModel.from_config(
            config, attn_implementation="sdpa", trust_remote_code=True
        )
        patch_model_for_onnx_export(self.hf_model)

        # Expose sub-components for weight loading
        self.vision_tower = self.hf_model.vision_tower
        self.multi_modal_projector = self.hf_model.multi_modal_projector
        self.embed_tokens = self.hf_model.language_model.embed_tokens

    def get_input_embeddings(self):
        return self.embed_tokens

    def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Vision encoder: pixel_values -> image_features."""
        image_outputs = self.vision_tower(pixel_values, return_dict=True)
        selected_image_feature = image_outputs.last_hidden_state

        image_sizes = torch.tensor(
            [[pixel_values.shape[-2], pixel_values.shape[-1]]],
            dtype=torch.int64,
            device=pixel_values.device,
        )
        image_features = self.multi_modal_projector(
            selected_image_feature.squeeze(0), image_sizes
        )
        return image_features

    def get_fused_input_embeddings(
        self, input_ids: torch.LongTensor, image_features: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embedding fusion: input_ids + image_features -> inputs_embeds."""
        inputs_embeds = self.embed_tokens(input_ids)
        if image_features is not None:
            image_features = image_features.to(inputs_embeds.dtype)
            special_image_mask = input_ids == self.config.image_token_index
            expanded_mask = (
                special_image_mask.unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            inputs_embeds = inputs_embeds.masked_scatter(expanded_mask, image_features)
        return inputs_embeds

    def forward(self, *args, **kwargs):
        raise NotImplementedError(
            "Use get_image_features() or get_fused_input_embeddings() via method swap."
        )


__all__ = ["Ministral3Model", "patch_model_for_onnx_export"]
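
The index arithmetic in the reference module above can be checked without PyTorch. The following sketch reproduces two pieces in plain Python: the row-major position ids computed in `pixtral_vision_forward_export` (`mesh_h * max_width + mesh_w`), and the number of tokens left after `Mistral3PatchMerger` unfolds non-overlapping blocks. The helper names are hypothetical, and the example sizes (patch size 14, merge size 2) are assumed values, not read from the model config.

```python
def pixtral_position_ids(grid_h, grid_w, max_width):
    """Row-major ids with a fixed row stride, like mesh_h * max_width + mesh_w."""
    return [h * max_width + w for h in range(grid_h) for w in range(grid_w)]


def merged_token_count(height, width, patch_size, spatial_merge_size):
    """Tokens left after merging non-overlapping merge x merge patch blocks."""
    patch_h = height // patch_size
    patch_w = width // patch_size
    return (patch_h // spatial_merge_size) * (patch_w // spatial_merge_size)


# Example values (assumed for illustration only).
ids = pixtral_position_ids(grid_h=2, grid_w=3, max_width=10)
tokens = merged_token_count(height=56, width=84, patch_size=14, spatial_merge_size=2)
```

The row stride is `max_width` rather than the actual grid width, so ids for a small image are a sparse subset of the ids for the largest supported grid.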
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
{
    "input_model": {
        "type": "HfModel",
        "model_path": "mistralai/Ministral-3-3B-Instruct-2512"
    },
    "passes": {
        "convert": {
            "type": "ModelBuilder",
            "precision": "int4",
            "int4_accuracy_level": 4,
            "extra_options": {
                "filename": "text.onnx"
            }
        }
    },
    "no_artifacts": true,
    "output_dir": "cpu_and_mobile/models/text.onnx"
}
Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
{
    "input_model": {
        "type": "HfModel",
        "model_path": "mistralai/Ministral-3-3B-Instruct-2512"
    },
    "passes": {
        "m": {
            "type": "ModelBuilder",
            "precision": "fp16",
            "extra_options": {
                "filename": "text.onnx"
            }
        }
    },
    "engine": {
        "target": {
            "type": "LocalSystem",
            "accelerators": [
                {
                    "device": "gpu",
                    "execution_providers": [
                        "CUDAExecutionProvider"
                    ]
                }
            ]
        }
    },
    "no_artifacts": true,
    "output_dir": "cuda/models/text.onnx"
}
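
The two Olive configs above differ mainly in the pass name, the precision, the output directory, and the explicit CUDA target. The following sketch pulls out the fields that matter when comparing them. `summarize_olive_config` is a hypothetical helper, not an Olive API, and the dictionaries are trimmed copies of the configs shown in this commit.

```python
def summarize_olive_config(config):
    """Return (precision, output filename, target device or None)."""
    # Both configs in this commit define exactly one pass.
    (pass_config,) = config["passes"].values()
    accelerators = (
        config.get("engine", {}).get("target", {}).get("accelerators", [])
    )
    # None means the config does not name a device explicitly.
    device = accelerators[0]["device"] if accelerators else None
    return (
        pass_config["precision"],
        pass_config["extra_options"]["filename"],
        device,
    )


cpu_config = {
    "passes": {
        "convert": {
            "type": "ModelBuilder",
            "precision": "int4",
            "int4_accuracy_level": 4,
            "extra_options": {"filename": "text.onnx"},
        }
    },
}

cuda_config = {
    "passes": {
        "m": {
            "type": "ModelBuilder",
            "precision": "fp16",
            "extra_options": {"filename": "text.onnx"},
        }
    },
    "engine": {
        "target": {
            "type": "LocalSystem",
            "accelerators": [
                {"device": "gpu", "execution_providers": ["CUDAExecutionProvider"]}
            ],
        }
    },
}

cpu_summary = summarize_olive_config(cpu_config)
cuda_summary = summarize_olive_config(cuda_config)
```

Both configs write the decoder to `text.onnx`, which matches the file name listed in the README output structure.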
