Commit 19c2f7b

Authored by titaiwangms and Copilot
[Mistral3] Add VLM support with multi-image inference (microsoft#2077)
## Summary

Adds Mistral3/Pixtral VLM support to onnxruntime-genai with multi-image inference. Includes a C++ image processor, `PixtralVisionState` for per-image vision processing, Python export support, and comprehensive tests.

## Changes

### C++ Runtime

- **Mistral3 image processor** — `[IMG]`/`[IMG_BREAK]`/`[IMG_END]` token expansion based on image resolution and patch geometry, with multi-image support
- **PixtralVisionState** — per-image vision processing loop with bounds checks and an overflow guard; slices from the padded batch tensor using `image_sizes` metadata from the ort-extensions `PixtralImageSizes` op
- **Virtual `SetExtraInputs`** — proper polymorphic dispatch for vision state subclasses
- **`IsPixtralFamily()` model type detection** — enables the Pixtral-specific codepath
- **`processor_config.json`** — with a `PixtralImageSizes` preprocessing step
- **`context_length` / `max_length` separation** — `context_length` controls KV cache allocation while `max_length` controls generation stopping, preventing premature EOS with large image token counts
- **INT32 `input_ids`** — token IDs above 32767 (Pixtral `[IMG]` = 128011) require int32

### Python Export Support

- Mistral3 model classes (`Mistral3Config`, `Mistral3ForConditionalGeneration`)
- FP8 dtype promotion for checkpoint loading
- `get_user_content()` handler for Mistral3 prompt formatting

## Multi-Image Architecture

Pixtral uses dynamic image sizes (28×28 to 1540×1540), so images cannot be batched in the vision encoder. `PixtralVisionState` processes each image individually by:

1. Reading the `image_sizes` tensor from the ort-extensions `PixtralImageSizes` op (provides per-image H×W)
2. Slicing the padded `[N, C, max_H, max_W]` batch tensor to extract each image's actual pixels
3. Running the vision encoder on each image separately
4. Concatenating the vision embeddings for the decoder

## Dependencies

- **onnxruntime-extensions PR microsoft#1050** — `PixtralImageSizes` custom op for image size metadata
- **PR microsoft#2076** — YaRN RoPE parity fixes (merged) ✅

## Testing

- 5 multi-image token expansion tests (various image sizes and counts)
- Virtual dispatch verification tests for `SetExtraInputs`
- YaRN RoPE parity tests (from merged microsoft#2076)
- E2E verified: multi-image (fish.jpg + challenge.jpg) correctly describes both images

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
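The per-image loop in the architecture steps above can be sketched in a few lines of Python. This is an illustrative sketch only: `encode_images` and `vision_encoder` are hypothetical names, not the actual `PixtralVisionState` API.

```python
import numpy as np

def encode_images(pixel_values, image_sizes, vision_encoder):
    """Sketch of the PixtralVisionState per-image loop (hypothetical API).

    pixel_values:   padded batch tensor of shape [N, C, max_H, max_W]
    image_sizes:    int array of shape [N, 2] with each image's real (H, W)
    vision_encoder: callable mapping a [1, C, H, W] image to [num_patches, D]
    """
    embeddings = []
    for i, (h, w) in enumerate(image_sizes):
        # Slice the i-th image's actual pixels out of the padded batch.
        img = pixel_values[i : i + 1, :, :h, :w]
        # Run the vision encoder on each image separately.
        embeddings.append(vision_encoder(img))
    # Concatenate per-image features for the decoder.
    return np.concatenate(embeddings, axis=0)
```

Because each image is encoded on its own, images of different resolutions never need to share a batch dimension inside the encoder.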
1 parent 2a2ef8c commit 19c2f7b

17 files changed

Lines changed: 1400 additions & 8 deletions

.gitignore

Lines changed: 1 addition & 0 deletions

```diff
@@ -38,6 +38,7 @@ examples/csharp/ModelChat/models
 !test/test_models/qwen3-vl-vision-preprocessing/*.onnx
 !test/test_models/qwen35-hybrid-preprocessing/
 !test/test_models/qwen35-hybrid-preprocessing/*.onnx
+!test/test_models/mistral3-vision-preprocessing/

 .ipynb_checkpoints/
 /src/java/.gradle
```

examples/python/common.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -278,6 +278,11 @@ def get_user_content(model_type: str, num_images: int, num_audios: int, prompt:
         # Qwen-2.5 VL, Qwen-3 VL, Fara
         image_tags = "".join(["<|vision_start|><|image_pad|><|vision_end|>" for _ in range(num_images)])
         content = image_tags + prompt
+    elif model_type == "mistral3":
+        # Pixtral / Ministral-3 VLM: the C++ image processor expands each
+        # [IMG] into the full token sequence based on image resolution.
+        image_tags = "".join(["[IMG]" for _ in range(num_images)])
+        content = image_tags + prompt
     else:
         # Gemma-3 style: structured content
         image_tags = [{"type": "image"} for _ in range(num_images)]
```
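For intuition, the mistral3 branch above reduces to the following standalone sketch (the helper name is hypothetical; the real code lives inside `get_user_content()`):

```python
def mistral3_user_content(num_images: int, prompt: str) -> str:
    # Mirrors the mistral3 branch of get_user_content(): one [IMG]
    # placeholder per image, prepended to the text prompt. The C++
    # image processor later expands each [IMG] by patch geometry.
    image_tags = "".join(["[IMG]" for _ in range(num_images)])
    return image_tags + prompt

print(mistral3_user_content(2, "Describe both images."))
# → [IMG][IMG]Describe both images.
```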
src/models/mistral3_image_processor.cpp

Lines changed: 299 additions & 0 deletions

New file:

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "../generators.h"
#include "model.h"
#include "mistral3_image_processor.h"

namespace Generators {

namespace {

// Pixtral special tokens — resolved at runtime via tokenizer lookup.
constexpr char kImgToken[] = "[IMG]";
constexpr char kImgBreakToken[] = "[IMG_BREAK]";
constexpr char kImgEndToken[] = "[IMG_END]";
constexpr char kInstToken[] = "[INST]";

// Build input_ids for the image portion of the prompt.
// Returns the token IDs including [IMG], [IMG_BREAK], and [IMG_END].
std::vector<int32_t> BuildImageTokenSequence(int patch_rows, int patch_cols,
                                             int32_t img_id, int32_t break_id, int32_t end_id) {
  std::vector<int32_t> tokens;
  tokens.reserve(patch_rows * patch_cols + patch_rows);

  for (int r = 0; r < patch_rows; ++r) {
    for (int c = 0; c < patch_cols; ++c) {
      tokens.push_back(img_id);
    }
    if (r < patch_rows - 1) {
      tokens.push_back(break_id);
    } else {
      tokens.push_back(end_id);
    }
  }
  return tokens;
}

// Per-image dimensions: each image may have a different resolution after
// smart_resize. When image_sizes is available (from PixtralImageSizes),
// use per-image H/W. Otherwise fall back to the (padded) pixel_values shape.
struct PerImageInfo {
  int patch_rows;
  int patch_cols;
  int64_t num_img_tokens;  // [IMG] count only (excludes [IMG_BREAK]/[IMG_END])
  std::vector<int32_t> token_sequence;
};

std::tuple<std::unique_ptr<OrtValue>, int64_t>
ProcessPixtralPrompt(const Tokenizer& tokenizer, const std::string& prompt,
                     OrtxTensor* pixel_values, OrtxTensor* image_sizes_tensor,
                     int patch_size, int spatial_merge_size,
                     Ort::Allocator& allocator) {
  const int32_t img_token_id = tokenizer.TokenToTokenId(kImgToken);
  const int32_t img_break_id = tokenizer.TokenToTokenId(kImgBreakToken);
  const int32_t img_end_id = tokenizer.TokenToTokenId(kImgEndToken);
  const int32_t inst_token_id = tokenizer.TokenToTokenId(kInstToken);

  int64_t num_images = 0;
  std::vector<PerImageInfo> image_infos;

  if (pixel_values) {
    const float* data{};
    const int64_t* shape{};
    size_t num_dims{};
    CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast<const void**>(&data), &shape, &num_dims));
    if (num_dims != 4) {
      throw std::runtime_error(
          "Mistral3ImageProcessor: expected 4D pixel_values [N,C,H,W], got " +
          std::to_string(num_dims) + "D tensor.");
    }
    num_images = shape[0];
    int64_t padded_h = shape[2];
    int64_t padded_w = shape[3];

    // Read per-image sizes if available, otherwise use padded dimensions
    const int64_t* sizes_data = nullptr;
    if (image_sizes_tensor) {
      const void* raw{};
      const int64_t* sizes_shape{};
      size_t sizes_dims{};
      CheckResult(OrtxGetTensorData(image_sizes_tensor, &raw, &sizes_shape, &sizes_dims));

      if (sizes_dims != 2) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: expected 2D image_sizes tensor [N,2], got " +
            std::to_string(sizes_dims) + "D tensor.");
      }
      if (sizes_shape[1] != 2) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: expected image_sizes tensor shape [N,2], "
            "got second dimension " +
            std::to_string(sizes_shape[1]) + ".");
      }
      if (sizes_shape[0] != num_images) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: image_sizes tensor first dimension (" +
            std::to_string(sizes_shape[0]) + ") must match pixel_values batch size (" +
            std::to_string(num_images) + ").");
      }
      sizes_data = static_cast<const int64_t*>(raw);
    }

    int64_t effective_patch = static_cast<int64_t>(patch_size) * spatial_merge_size;
    for (int64_t i = 0; i < num_images; ++i) {
      int64_t h = sizes_data ? sizes_data[i * 2] : padded_h;
      int64_t w = sizes_data ? sizes_data[i * 2 + 1] : padded_w;

      if (h % effective_patch != 0 || w % effective_patch != 0) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: image " + std::to_string(i) + " dimensions (" +
            std::to_string(h) + "x" + std::to_string(w) +
            ") must be divisible by patch_size*merge_size (" +
            std::to_string(effective_patch) + "). Check smart_resize configuration.");
      }

      PerImageInfo info;
      info.patch_rows = static_cast<int>(h / effective_patch);
      info.patch_cols = static_cast<int>(w / effective_patch);
      info.token_sequence = BuildImageTokenSequence(info.patch_rows, info.patch_cols,
                                                    img_token_id, img_break_id, img_end_id);
      // Count only [IMG] tokens — this equals the vision model's feature output count
      // (patch_rows * patch_cols), excluding structural [IMG_BREAK]/[IMG_END] tokens.
      info.num_img_tokens = static_cast<int64_t>(
          std::count(info.token_sequence.begin(), info.token_sequence.end(), img_token_id));
      image_infos.push_back(std::move(info));
    }
  }

  int64_t total_img_tokens = 0;
  for (const auto& info : image_infos) {
    total_img_tokens += info.num_img_tokens;
  }

  // Tokenize the text prompt
  std::vector<int32_t> input_ids;
  if (!prompt.empty()) {
    input_ids = tokenizer.Encode(prompt.c_str());
  }

  // Expand [IMG] placeholders for each image.
  // Each [IMG] (or group of consecutive [IMG] tokens) in the prompt corresponds
  // to one image, expanded with its per-image token sequence.
  if (!image_infos.empty()) {
    std::vector<int32_t> expanded_ids;
    size_t total_expansion = input_ids.size();
    for (const auto& info : image_infos) {
      total_expansion += info.token_sequence.size();
    }
    expanded_ids.reserve(total_expansion);

    size_t next_image = 0;
    for (size_t i = 0; i < input_ids.size(); ++i) {
      if (input_ids[i] == img_token_id && next_image < image_infos.size()) {
        // Replace this [IMG] (and consecutive [IMG] tokens) with the image's token sequence
        expanded_ids.insert(expanded_ids.end(),
                            image_infos[next_image].token_sequence.begin(),
                            image_infos[next_image].token_sequence.end());
        ++next_image;
        // Skip consecutive [IMG] tokens from the original prompt
        while (i + 1 < input_ids.size() && input_ids[i + 1] == img_token_id) {
          ++i;
        }
      } else {
        expanded_ids.push_back(input_ids[i]);
      }
    }

    // If not all images had placeholders, insert remaining after [INST]
    if (next_image < image_infos.size()) {
      std::vector<int32_t> remaining_tokens;
      for (size_t img = next_image; img < image_infos.size(); ++img) {
        remaining_tokens.insert(remaining_tokens.end(),
                                image_infos[img].token_sequence.begin(),
                                image_infos[img].token_sequence.end());
      }

      std::vector<int32_t> final_ids;
      final_ids.reserve(expanded_ids.size() + remaining_tokens.size());
      bool inserted = false;
      for (size_t i = 0; i < expanded_ids.size(); ++i) {
        final_ids.push_back(expanded_ids[i]);
        if (expanded_ids[i] == inst_token_id && !inserted) {
          final_ids.insert(final_ids.end(), remaining_tokens.begin(), remaining_tokens.end());
          inserted = true;
        }
      }
      if (!inserted) {
        // No [INST] found — prepend remaining image tokens
        final_ids.clear();
        final_ids.insert(final_ids.end(), remaining_tokens.begin(), remaining_tokens.end());
        final_ids.insert(final_ids.end(), expanded_ids.begin(), expanded_ids.end());
      }
      expanded_ids = std::move(final_ids);
    }

    input_ids = std::move(expanded_ids);
  }

  auto input_ids_value = OrtValue::CreateTensor<int32_t>(
      allocator, std::vector<int64_t>{1, static_cast<int64_t>(input_ids.size())});
  std::copy(input_ids.begin(), input_ids.end(),
            input_ids_value->GetTensorMutableData<int32_t>());

  return {std::move(input_ids_value), total_img_tokens};
}

}  // namespace

Mistral3ImageProcessor::Mistral3ImageProcessor(Config& config, const SessionInfo& session_info)
    : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)},
      patch_size_{config.model.vision.patch_size},
      spatial_merge_size_{config.model.vision.spatial_merge_size} {
  const auto processor_config =
      (config.config_path / fs::path(config.model.vision.config_filename)).string();
  CheckResult(OrtxCreateProcessor(processor_.ToBeAssigned(), processor_config.c_str()));

  config.AddMapping(std::string(Config::Defaults::InputIdsName),
                    config.model.embedding.inputs.input_ids);
  config.AddMapping(std::string(Config::Defaults::PixelValuesName),
                    config.model.vision.inputs.pixel_values);
}

std::unique_ptr<NamedTensors> Mistral3ImageProcessor::Process(
    const Tokenizer& tokenizer, const Payload& payload) const {
  std::string prompt = std::string(payload.prompt);
  const Images* images = payload.images;
  Ort::Allocator& allocator{Ort::Allocator::GetWithDefaultOptions()};
  auto named_tensors = std::make_unique<NamedTensors>();

  if (!images) {
    // Text-only: tokenize prompt without image processing
    auto [input_ids, num_img_tokens] =
        ProcessPixtralPrompt(tokenizer, prompt, nullptr, nullptr, patch_size_,
                             spatial_merge_size_, allocator);
    named_tensors->emplace(Config::Defaults::InputIdsName,
                           std::make_shared<Tensor>(std::move(input_ids)));

    // Explicitly set num_image_tokens=0 for text-only inputs so downstream
    // pipeline components know there are no vision features to process.
    auto zero_tokens = OrtValue::CreateTensor<int64_t>(allocator, std::vector<int64_t>{1});
    zero_tokens->GetTensorMutableData<int64_t>()[0] = 0;
    named_tensors->emplace(std::string(Config::Defaults::NumImageTokens),
                           std::make_shared<Tensor>(std::move(zero_tokens)));
    return named_tensors;
  }

  // Process images through the ort-extensions processor (normalization, resizing)
  ort_extensions::OrtxObjectPtr<OrtxTensorResult> result;
  CheckResult(OrtxImagePreProcess(processor_.get(), images->images_.get(),
                                  result.ToBeAssigned()));

  OrtxTensor* pixel_values = nullptr;
  CheckResult(OrtxTensorResultGetAt(result.get(), 0, &pixel_values));

  // Tensor 1: image_sizes[N, 2] from PixtralImageSizes (post-resize, pre-padding).
  // Models must be exported with PixtralImageSizes in processor_config.json.
  OrtxTensor* image_sizes = nullptr;
  CheckResult(OrtxTensorResultGetAt(result.get(), 1, &image_sizes));

  auto [input_ids, num_img_tokens] =
      ProcessPixtralPrompt(tokenizer, prompt, pixel_values, image_sizes, patch_size_,
                           spatial_merge_size_, allocator);

  named_tensors->emplace(std::string(Config::Defaults::InputIdsName),
                         std::make_shared<Tensor>(std::move(input_ids)));

  // Convert pixel_values to the vision model's expected dtype (NCHW layout
  // is already handled by the Permute3D step in processor_config.json).
  {
    std::unique_ptr<OrtValue> pv_ortvalue;
    if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
      pv_ortvalue = ProcessTensor<float>(pixel_values, allocator);
    } else if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) {
      pv_ortvalue = ProcessTensor<Ort::BFloat16_t>(pixel_values, allocator);
    } else {
      pv_ortvalue = ProcessTensor<Ort::Float16_t>(pixel_values, allocator);
    }
    named_tensors->emplace(std::string(Config::Defaults::PixelValuesName),
                           std::make_shared<Tensor>(std::move(pv_ortvalue)));
  }

  // Add image_sizes[N, 2] for PixtralVisionState to slice per-image dimensions
  if (image_sizes) {
    named_tensors->emplace(std::string(Config::Defaults::ImageSizesName),
                           std::make_shared<Tensor>(ProcessTensor<int64_t>(image_sizes, allocator)));
  }

  // Add num_image_tokens (total across all images) for the embedding model
  auto num_img_tokens_value = OrtValue::CreateTensor<int64_t>(
      allocator, std::vector<int64_t>{1});
  num_img_tokens_value->GetTensorMutableData<int64_t>()[0] = num_img_tokens;
  named_tensors->emplace(std::string(Config::Defaults::NumImageTokens),
                         std::make_shared<Tensor>(std::move(num_img_tokens_value)));

  return named_tensors;
}

}  // namespace Generators
```
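The `BuildImageTokenSequence` logic above can be mirrored in a short Python sketch (illustrative only, not part of the PR): each patch row contributes `patch_cols` `[IMG]` tokens, followed by `[IMG_BREAK]` after every row except the last, which ends with `[IMG_END]`.

```python
def build_image_token_sequence(patch_rows, patch_cols,
                               img="[IMG]", brk="[IMG_BREAK]", end="[IMG_END]"):
    # Mirrors BuildImageTokenSequence from the C++ file above:
    # patch_cols [IMG] tokens per row, a row separator after all rows
    # but the last, and a terminator after the final row.
    tokens = []
    for r in range(patch_rows):
        tokens.extend([img] * patch_cols)
        tokens.append(brk if r < patch_rows - 1 else end)
    return tokens

seq = build_image_token_sequence(2, 3)
# 2 rows × 3 cols → 6 [IMG] + 1 [IMG_BREAK] + 1 [IMG_END] = 8 tokens
```

The `[IMG]` count (rows × cols) is exactly the number of features the vision encoder emits, which is why the C++ code counts only `[IMG]` tokens for `num_img_tokens`.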
src/models/mistral3_image_processor.h

Lines changed: 23 additions & 0 deletions

New file:

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "processor.h"

namespace Generators {

struct Mistral3ImageProcessor : Processor {
  Mistral3ImageProcessor(Config& config, const SessionInfo& session_info);

  std::unique_ptr<NamedTensors> Process(const Tokenizer& tokenizer, const Payload& payload) const override;

 private:
  ort_extensions::OrtxObjectPtr<OrtxProcessor> processor_;

  ONNXTensorElementDataType pixel_values_type_;
  int patch_size_;
  int spatial_merge_size_;
};

}  // namespace Generators
```

src/models/model.cpp

Lines changed: 2 additions & 0 deletions

```diff
@@ -22,6 +22,7 @@
 #include "decoder_only_pipeline.h"
 #include "qwen_vl_model.h"
 #include "qwen2_5_vl_image_processor.h"
+#include "mistral3_image_processor.h"
 #include "../dml/interface.h"
 #include "../openvino/interface.h"
 #include "../ryzenai/interface.h"
@@ -918,6 +919,7 @@ MultiModalProcessor::MultiModalProcessor(Config& config, const SessionInfo& sess
     {"whisper", Processor::Create<WhisperProcessor>},
     {"phi4mm", Processor::Create<PhiMultiModalProcessor>},
     {"gemma3", Processor::Create<GemmaImageProcessor>},
+    {"mistral3", Processor::Create<Mistral3ImageProcessor>},
     {"fara", Processor::Create<QwenImageProcessor>},
     {"qwen2_5_vl", Processor::Create<QwenImageProcessor>},
     {"qwen3_vl", Processor::Create<QwenImageProcessor>},
```

src/models/model_type.h

Lines changed: 6 additions & 1 deletion

```diff
@@ -21,7 +21,7 @@ struct ModelType {

   inline static bool IsVLM(const std::string& model_type) {
     // Vision-language model (VLM)
-    static constexpr std::array<std::string_view, 6> VLM = {"fara", "gemma3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5"};
+    static constexpr std::array<std::string_view, 7> VLM = {"fara", "gemma3", "mistral3", "phi3v", "qwen2_5_vl", "qwen3_vl", "qwen3_5"};
     return std::find(VLM.begin(), VLM.end(), model_type) != VLM.end();
   }

@@ -30,6 +30,11 @@
     return model_type == "fara" || model_type == "qwen2_5_vl" || model_type == "qwen3_vl" || model_type == "qwen3_5";
   }

+  inline static bool IsPixtralFamily(const std::string& model_type) {
+    // Pixtral family: per-image vision loop with variable resolution
+    return model_type == "mistral3";
+  }
+
   inline static bool IsALM(const std::string& model_type) {
     // Audio-language model (ALM)
     static constexpr std::array<std::string_view, 1> ALM = {"whisper"};
```
