|
| 1 | +// Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +// Licensed under the MIT License. |
| 3 | + |
| 4 | +#include "../generators.h" |
| 5 | +#include "model.h" |
| 6 | +#include "mistral3_image_processor.h" |
| 7 | + |
| 8 | +namespace Generators { |
| 9 | + |
| 10 | +namespace { |
| 11 | + |
// Pixtral special-token literals. The corresponding numeric IDs are
// tokenizer-dependent, so they are resolved at runtime via
// Tokenizer::TokenToTokenId (see ProcessPixtralPrompt below).
constexpr char kImgToken[] = "[IMG]";              // one patch-feature placeholder
constexpr char kImgBreakToken[] = "[IMG_BREAK]";   // separator after each non-final patch row
constexpr char kImgEndToken[] = "[IMG_END]";       // terminator after the final patch row
constexpr char kInstToken[] = "[INST]";            // instruction marker; insertion anchor for images without placeholders
| 17 | + |
| 18 | +// Build input_ids for the image portion of the prompt. |
| 19 | +// Returns the token IDs including [IMG], [IMG_BREAK], and [IMG_END]. |
| 20 | +std::vector<int32_t> BuildImageTokenSequence(int patch_rows, int patch_cols, |
| 21 | + int32_t img_id, int32_t break_id, int32_t end_id) { |
| 22 | + std::vector<int32_t> tokens; |
| 23 | + tokens.reserve(patch_rows * patch_cols + patch_rows); |
| 24 | + |
| 25 | + for (int r = 0; r < patch_rows; ++r) { |
| 26 | + for (int c = 0; c < patch_cols; ++c) { |
| 27 | + tokens.push_back(img_id); |
| 28 | + } |
| 29 | + if (r < patch_rows - 1) { |
| 30 | + tokens.push_back(break_id); |
| 31 | + } else { |
| 32 | + tokens.push_back(end_id); |
| 33 | + } |
| 34 | + } |
| 35 | + return tokens; |
| 36 | +} |
| 37 | + |
// Per-image dimensions: each image may have a different resolution after
// smart_resize. When image_sizes is available (from PixtralImageSizes),
// per-image H/W is used; otherwise the (padded) pixel_values shape is the
// fallback.
//
// Members are value-initialized so a default-constructed instance holds
// well-defined zeros instead of indeterminate values.
struct PerImageInfo {
  int patch_rows{};                     // image height in merged-patch units
  int patch_cols{};                     // image width in merged-patch units
  int64_t num_img_tokens{};             // [IMG] count only (excludes [IMG_BREAK]/[IMG_END])
  std::vector<int32_t> token_sequence;  // full [IMG]/[IMG_BREAK]/[IMG_END] sequence for this image
};
| 47 | + |
// Tokenizes the text prompt and splices in per-image Pixtral token sequences.
//
// Placeholder handling: each [IMG] token (or run of consecutive [IMG] tokens)
// already present in the tokenized prompt is replaced by the next image's full
// sequence from BuildImageTokenSequence. Images left over once all
// placeholders are consumed are inserted immediately after the first [INST]
// token, or prepended to the prompt when no [INST] token is found.
//
// Args:
//   tokenizer            resolves special-token strings to IDs and encodes text.
//   prompt               user prompt text; may be empty.
//   pixel_values         optional preprocessed image batch, 4D [N,C,H,W];
//                        nullptr for text-only prompts.
//   image_sizes_tensor   optional [N,2] per-image (H,W) sizes (post-resize,
//                        pre-padding); when nullptr the padded pixel_values
//                        H/W is used for every image.
//   patch_size, spatial_merge_size
//                        vision patching parameters; every image dimension
//                        must be divisible by patch_size * spatial_merge_size.
//   allocator            allocator used for the returned OrtValue.
//
// Returns {input_ids tensor of shape [1, seq_len], total [IMG] token count
// across all images (structural [IMG_BREAK]/[IMG_END] tokens excluded)}.
//
// Throws std::runtime_error on malformed tensor shapes or non-divisible
// image dimensions.
std::tuple<std::unique_ptr<OrtValue>, int64_t>
ProcessPixtralPrompt(const Tokenizer& tokenizer, const std::string& prompt,
                     OrtxTensor* pixel_values, OrtxTensor* image_sizes_tensor,
                     int patch_size, int spatial_merge_size,
                     Ort::Allocator& allocator) {
  // Special-token IDs vary per tokenizer vocabulary, so resolve at runtime.
  const int32_t img_token_id = tokenizer.TokenToTokenId(kImgToken);
  const int32_t img_break_id = tokenizer.TokenToTokenId(kImgBreakToken);
  const int32_t img_end_id = tokenizer.TokenToTokenId(kImgEndToken);
  const int32_t inst_token_id = tokenizer.TokenToTokenId(kInstToken);

  int64_t num_images = 0;
  std::vector<PerImageInfo> image_infos;

  if (pixel_values) {
    const float* data{};
    const int64_t* shape{};
    size_t num_dims{};
    CheckResult(OrtxGetTensorData(pixel_values, reinterpret_cast<const void**>(&data), &shape, &num_dims));
    if (num_dims != 4) {
      throw std::runtime_error(
          "Mistral3ImageProcessor: expected 4D pixel_values [N,C,H,W], "
          "got " +
          std::to_string(num_dims) + "D tensor.");
    }
    num_images = shape[0];
    int64_t padded_h = shape[2];
    int64_t padded_w = shape[3];

    // Read per-image sizes if available, otherwise use padded dimensions
    const int64_t* sizes_data = nullptr;
    if (image_sizes_tensor) {
      const void* raw{};
      const int64_t* sizes_shape{};
      size_t sizes_dims{};
      CheckResult(OrtxGetTensorData(image_sizes_tensor, &raw, &sizes_shape, &sizes_dims));

      if (sizes_dims != 2) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: expected 2D image_sizes tensor [N,2], "
            "got " +
            std::to_string(sizes_dims) + "D tensor.");
      }
      if (sizes_shape[1] != 2) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: expected image_sizes tensor shape [N,2], "
            "got second dimension " +
            std::to_string(sizes_shape[1]) + ".");
      }
      if (sizes_shape[0] != num_images) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: image_sizes tensor first dimension (" +
            std::to_string(sizes_shape[0]) + ") must match pixel_values batch size (" +
            std::to_string(num_images) + ").");
      }
      sizes_data = static_cast<const int64_t*>(raw);
    }

    // One merged patch covers patch_size * spatial_merge_size pixels per axis.
    int64_t effective_patch = static_cast<int64_t>(patch_size) * spatial_merge_size;
    for (int64_t i = 0; i < num_images; ++i) {
      // Each image_sizes row is read as a (height, width) pair.
      int64_t h = sizes_data ? sizes_data[i * 2] : padded_h;
      int64_t w = sizes_data ? sizes_data[i * 2 + 1] : padded_w;

      if (h % effective_patch != 0 || w % effective_patch != 0) {
        throw std::runtime_error(
            "Mistral3ImageProcessor: image " + std::to_string(i) + " dimensions (" +
            std::to_string(h) + "x" + std::to_string(w) +
            ") must be divisible by patch_size*merge_size (" +
            std::to_string(effective_patch) + "). Check smart_resize configuration.");
      }

      PerImageInfo info;
      info.patch_rows = static_cast<int>(h / effective_patch);
      info.patch_cols = static_cast<int>(w / effective_patch);
      info.token_sequence = BuildImageTokenSequence(info.patch_rows, info.patch_cols,
                                                    img_token_id, img_break_id, img_end_id);
      // Count only [IMG] tokens — this equals the vision model's feature output count
      // (patch_rows * patch_cols), excluding structural [IMG_BREAK]/[IMG_END] tokens.
      info.num_img_tokens = static_cast<int64_t>(
          std::count(info.token_sequence.begin(), info.token_sequence.end(), img_token_id));
      image_infos.push_back(std::move(info));
    }
  }

  int64_t total_img_tokens = 0;
  for (const auto& info : image_infos) {
    total_img_tokens += info.num_img_tokens;
  }

  // Tokenize the text prompt
  std::vector<int32_t> input_ids;
  if (!prompt.empty()) {
    input_ids = tokenizer.Encode(prompt.c_str());
  }

  // Expand [IMG] placeholders for each image.
  // Each [IMG] (or group of consecutive [IMG] tokens) in the prompt corresponds
  // to one image, expanded with its per-image token sequence.
  if (!image_infos.empty()) {
    std::vector<int32_t> expanded_ids;
    size_t total_expansion = input_ids.size();
    for (const auto& info : image_infos) {
      total_expansion += info.token_sequence.size();
    }
    expanded_ids.reserve(total_expansion);

    size_t next_image = 0;
    for (size_t i = 0; i < input_ids.size(); ++i) {
      if (input_ids[i] == img_token_id && next_image < image_infos.size()) {
        // Replace this [IMG] (and consecutive [IMG] tokens) with the image's token sequence
        expanded_ids.insert(expanded_ids.end(),
                            image_infos[next_image].token_sequence.begin(),
                            image_infos[next_image].token_sequence.end());
        ++next_image;
        // Skip consecutive [IMG] tokens from the original prompt
        while (i + 1 < input_ids.size() && input_ids[i + 1] == img_token_id) {
          ++i;
        }
      } else {
        // Non-placeholder token (or an [IMG] with no image left): keep as-is.
        expanded_ids.push_back(input_ids[i]);
      }
    }

    // If not all images had placeholders, insert remaining after [INST]
    if (next_image < image_infos.size()) {
      std::vector<int32_t> remaining_tokens;
      for (size_t img = next_image; img < image_infos.size(); ++img) {
        remaining_tokens.insert(remaining_tokens.end(),
                                image_infos[img].token_sequence.begin(),
                                image_infos[img].token_sequence.end());
      }

      std::vector<int32_t> final_ids;
      final_ids.reserve(expanded_ids.size() + remaining_tokens.size());
      bool inserted = false;
      for (size_t i = 0; i < expanded_ids.size(); ++i) {
        final_ids.push_back(expanded_ids[i]);
        // Insert once, immediately after the first [INST] occurrence.
        if (expanded_ids[i] == inst_token_id && !inserted) {
          final_ids.insert(final_ids.end(), remaining_tokens.begin(), remaining_tokens.end());
          inserted = true;
        }
      }
      if (!inserted) {
        // No [INST] found — prepend remaining image tokens
        final_ids.clear();
        final_ids.insert(final_ids.end(), remaining_tokens.begin(), remaining_tokens.end());
        final_ids.insert(final_ids.end(), expanded_ids.begin(), expanded_ids.end());
      }
      expanded_ids = std::move(final_ids);
    }

    input_ids = std::move(expanded_ids);
  }

  // Copy the final ID sequence into a [1, seq_len] int32 tensor.
  auto input_ids_value = OrtValue::CreateTensor<int32_t>(
      allocator, std::vector<int64_t>{1, static_cast<int64_t>(input_ids.size())});
  std::copy(input_ids.begin(), input_ids.end(),
            input_ids_value->GetTensorMutableData<int32_t>());

  return {std::move(input_ids_value), total_img_tokens};
}
| 208 | +} // namespace |
| 209 | + |
// Creates the ort-extensions image pre-processor from the vision config file
// and registers name mappings so the generic pipeline routes the default
// "input_ids"/"pixel_values" names to this model's session inputs. Caches
// the vision session's pixel_values dtype and the patch/merge parameters
// used by ProcessPixtralPrompt.
Mistral3ImageProcessor::Mistral3ImageProcessor(Config& config, const SessionInfo& session_info)
    : pixel_values_type_{session_info.GetInputDataType(config.model.vision.inputs.pixel_values)},
      patch_size_{config.model.vision.patch_size},
      spatial_merge_size_{config.model.vision.spatial_merge_size} {
  // Full path to the processor config (named by config.model.vision.config_filename,
  // typically processor_config.json) describing the preprocessing pipeline.
  const auto processor_config =
      (config.config_path / fs::path(config.model.vision.config_filename)).string();
  CheckResult(OrtxCreateProcessor(processor_.ToBeAssigned(), processor_config.c_str()));

  config.AddMapping(std::string(Config::Defaults::InputIdsName),
                    config.model.embedding.inputs.input_ids);
  config.AddMapping(std::string(Config::Defaults::PixelValuesName),
                    config.model.vision.inputs.pixel_values);
}
| 223 | + |
| 224 | +std::unique_ptr<NamedTensors> Mistral3ImageProcessor::Process( |
| 225 | + const Tokenizer& tokenizer, const Payload& payload) const { |
| 226 | + std::string prompt = std::string(payload.prompt); |
| 227 | + const Images* images = payload.images; |
| 228 | + Ort::Allocator& allocator{Ort::Allocator::GetWithDefaultOptions()}; |
| 229 | + auto named_tensors = std::make_unique<NamedTensors>(); |
| 230 | + |
| 231 | + if (!images) { |
| 232 | + // Text-only: tokenize prompt without image processing |
| 233 | + auto [input_ids, num_img_tokens] = |
| 234 | + ProcessPixtralPrompt(tokenizer, prompt, nullptr, nullptr, patch_size_, |
| 235 | + spatial_merge_size_, allocator); |
| 236 | + named_tensors->emplace(Config::Defaults::InputIdsName, |
| 237 | + std::make_shared<Tensor>(std::move(input_ids))); |
| 238 | + |
| 239 | + // Explicitly set num_image_tokens=0 for text-only inputs so downstream |
| 240 | + // pipeline components know there are no vision features to process. |
| 241 | + auto zero_tokens = OrtValue::CreateTensor<int64_t>(allocator, std::vector<int64_t>{1}); |
| 242 | + zero_tokens->GetTensorMutableData<int64_t>()[0] = 0; |
| 243 | + named_tensors->emplace(std::string(Config::Defaults::NumImageTokens), |
| 244 | + std::make_shared<Tensor>(std::move(zero_tokens))); |
| 245 | + return named_tensors; |
| 246 | + } |
| 247 | + |
| 248 | + // Process images through the ort-extensions processor (normalization, resizing) |
| 249 | + ort_extensions::OrtxObjectPtr<OrtxTensorResult> result; |
| 250 | + CheckResult(OrtxImagePreProcess(processor_.get(), images->images_.get(), |
| 251 | + result.ToBeAssigned())); |
| 252 | + |
| 253 | + OrtxTensor* pixel_values = nullptr; |
| 254 | + CheckResult(OrtxTensorResultGetAt(result.get(), 0, &pixel_values)); |
| 255 | + |
| 256 | + // Tensor 1: image_sizes[N, 2] from PixtralImageSizes (post-resize, pre-padding). |
| 257 | + // Models must be exported with PixtralImageSizes in processor_config.json. |
| 258 | + OrtxTensor* image_sizes = nullptr; |
| 259 | + CheckResult(OrtxTensorResultGetAt(result.get(), 1, &image_sizes)); |
| 260 | + |
| 261 | + auto [input_ids, num_img_tokens] = |
| 262 | + ProcessPixtralPrompt(tokenizer, prompt, pixel_values, image_sizes, patch_size_, |
| 263 | + spatial_merge_size_, allocator); |
| 264 | + |
| 265 | + named_tensors->emplace(std::string(Config::Defaults::InputIdsName), |
| 266 | + std::make_shared<Tensor>(std::move(input_ids))); |
| 267 | + |
| 268 | + // Convert pixel_values to the vision model's expected dtype (NCHW layout |
| 269 | + // is already handled by the Permute3D step in processor_config.json). |
| 270 | + { |
| 271 | + std::unique_ptr<OrtValue> pv_ortvalue; |
| 272 | + if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { |
| 273 | + pv_ortvalue = ProcessTensor<float>(pixel_values, allocator); |
| 274 | + } else if (pixel_values_type_ == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) { |
| 275 | + pv_ortvalue = ProcessTensor<Ort::BFloat16_t>(pixel_values, allocator); |
| 276 | + } else { |
| 277 | + pv_ortvalue = ProcessTensor<Ort::Float16_t>(pixel_values, allocator); |
| 278 | + } |
| 279 | + named_tensors->emplace(std::string(Config::Defaults::PixelValuesName), |
| 280 | + std::make_shared<Tensor>(std::move(pv_ortvalue))); |
| 281 | + } |
| 282 | + |
| 283 | + // Add image_sizes[N, 2] for PixtralVisionState to slice per-image dimensions |
| 284 | + if (image_sizes) { |
| 285 | + named_tensors->emplace(std::string(Config::Defaults::ImageSizesName), |
| 286 | + std::make_shared<Tensor>(ProcessTensor<int64_t>(image_sizes, allocator))); |
| 287 | + } |
| 288 | + |
| 289 | + // Add num_image_tokens (total across all images) for the embedding model |
| 290 | + auto num_img_tokens_value = OrtValue::CreateTensor<int64_t>( |
| 291 | + allocator, std::vector<int64_t>{1}); |
| 292 | + num_img_tokens_value->GetTensorMutableData<int64_t>()[0] = num_img_tokens; |
| 293 | + named_tensors->emplace(std::string(Config::Defaults::NumImageTokens), |
| 294 | + std::make_shared<Tensor>(std::move(num_img_tokens_value))); |
| 295 | + |
| 296 | + return named_tensors; |
| 297 | +} |
| 298 | + |
| 299 | +} // namespace Generators |
0 commit comments