
Commit b71f18b

Add Gemma4 multimodal support (vision + audio) (microsoft#2103)
## Summary

Adds end-to-end support for Google Gemma 4 multimodal models in ORT GenAI, covering text-only (`gemma4_text`), vision-language, and any-to-any (vision + audio + text) variants.

## Changes

### Model registration

- Register `gemma4_text` as LLM and `gemma4` as MMM (multi-modal model)
- MMM auto-detects speech support from `speech.filename` in genai_config — no separate `gemma4_any_to_any` type needed
- Register `Gemma4MultiModalProcessor` in the processor factory

### Gemma4 multimodal processor (`gemma4_multimodal_processor.cpp/h`)

- **Vision**: Preprocesses images via `Gemma4ImageTransform` (onnxruntime-extensions), trims padded patches to the actual count using `num_soft_tokens` from the preprocessor, and produces `pixel_values` + `pixel_position_ids`
- **Audio**: Extracts mel features via `Gemma4LogMel`, computes `audio_sizes` for the pipeline, generates `input_features_mask` (all-True for single-clip inference), and expands `<|audio|>` placeholder tokens in the prompt
- **Prompt handling**: Expands both `<|image|>` and `<|audio|>` tokens from the chat template into the correct number of soft tokens before encoding (a toy sketch follows this summary). Handles template-inserted tokens and auto-inserts the tokens when no template is available

### KV cache — per-layer head_dim (`kv_cache.cpp/h`)

- Auto-detects varying `head_dim` across layers from the ONNX session input shapes (Gemma4 uses 256 for sliding-window layers, 512 for global-attention layers); a rough detection sketch follows the file stats below
- Creates per-layer `empty_pasts_` with the correct head dimensions
- Handles `layer_shapes_[i][2] == 0` (unconstrained) in Update to avoid a zero-size allocation
- Updates the `layer_shapes_` sequence dimension for `past_present_share_buffer` mode

### Position inputs — int64 support (`position_inputs.cpp`)

- `WindowedPositionInputs` now supports both `int32_t` and `int64_t` for `position_ids` and `attention_mask`
- Type-dispatching lambdas cover all data access points (first window, subsequent windows, token generation)

### Multi-modal pipeline (`multi_modal.cpp/h`)

- **DecoderState**: Optional `decoder_input_ids_` for models that require `input_ids` alongside `inputs_embeds`
- **EmbeddingState**: Handles an empty `audio_features` tensor when the embedding model requires it but no speech session exists (`AllocateEmptyFeatures`)
- **SpeechState**: Manages the 3D→2D reshape of the speech output (`ReshapeFeatures`) before passing it to the embedding model
- **Pipeline**: Conditional audio-feature reshape and empty-audio fallback based on `num_audio_tokens_`

### MultiModalFeatures (`multi_modal_features.cpp/h`)

- `AllocateEmptyFeatures()` — pre-allocates an empty tensor for optional inputs
- `ReshapeFeatures()` — in-place reshape with a data copy and state-pointer update
- `batch_size <= 0` support — skips the batch dimension for 3D model outputs

### Config (`config.h/cpp`)

- Added `pixel_position_ids` to the vision inputs
- Added `audio_token_id` and `boa_token_id` to the model config
- Added a `PixelPositionIdsName` default constant

### Example script (`common.py`)

- Added `{"type": "audio"}` entries for Gemma-style structured content in `get_user_content`

## Testing

Tested with a Gemma4 E2B model exported via mobius:

- ✅ Text-only generation
- ✅ Image description (detailed landscape analysis)
- ✅ Audio transcription (Windows SAPI TTS → model correctly identifies speech content)
- ✅ Image-only with an any-to-any config (empty `audio_features` handled)
- ✅ Mixed GQA + standard Attention with `past_present_share_buffer=false`
- ✅ Per-layer head_dim KV cache (256/512)
- ✅ int64 `position_ids` with `WindowedPositionInputs`
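To make the **Prompt handling** step concrete, here is a toy Python sketch of the placeholder-expansion idea. It is illustrative only: the real logic is C++ in `gemma4_multimodal_processor.cpp`, and the function name, tag spellings, and token counts below are assumptions (per the summary, the per-image count comes from the preprocessor's `num_soft_tokens` and the per-clip count from `audio_sizes`).

```python
# Toy sketch of <|image|>/<|audio|> soft-token expansion (not the real implementation).
def expand_placeholders(prompt: str, soft_token_counts: dict[str, int]) -> str:
    """Repeat each media placeholder so the tokenizer emits one soft token
    per image patch / audio frame."""
    for tag, count in soft_token_counts.items():
        if count > 0:
            prompt = prompt.replace(tag, tag * count)
    return prompt


# Made-up counts: one image worth 256 patches, one audio clip worth 188 frames.
expanded = expand_placeholders(
    "<|image|><|audio|>Describe what you see and hear.",
    {"<|image|>": 256, "<|audio|>": 188},
)
```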
1 parent e7301d0 commit b71f18b

19 files changed: 805 additions & 96 deletions
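As a rough illustration of the per-layer `head_dim` auto-detection described in the KV-cache section above, one can read each past-KV input shape off the decoder's ONNX session. The sketch below uses the `onnxruntime` Python API; the `past_key_values.<layer>.key` naming scheme and the `[batch, num_kv_heads, seq_len, head_dim]` layout are assumptions, and the actual change lives in C++ in `kv_cache.cpp`.

```python
# Illustrative only: infer per-layer head_dim from a decoder's past-KV inputs.
import onnxruntime as ort


def per_layer_head_dims(model_path: str) -> list[int]:
    session = ort.InferenceSession(model_path)
    dims = {}
    for inp in session.get_inputs():
        # Assumed naming: "past_key_values.<layer>.key"; static dims are ints,
        # dynamic dims come back as strings.
        if inp.name.startswith("past_key_values.") and inp.name.endswith(".key"):
            layer = int(inp.name.split(".")[1])
            dims[layer] = inp.shape[-1]  # head_dim assumed to be the last axis
    return [dims[i] for i in sorted(dims)]


# For Gemma4 one would expect a mix such as [256, 256, ..., 512]: 256 on
# sliding-window layers and 512 on global-attention layers.
```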

cmake/deps.txt
Lines changed: 1 addition & 1 deletion

```diff
@@ -14,7 +14,7 @@ pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f78029
 googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
-onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;539d380ce9c2fcdfc9fd9f151ef5604425215aa9
+onnxruntime_extensions;https://github.com/microsoft/onnxruntime-extensions.git;e094cc816679d0b2b5fe2b4fd7f73e5b1844b425

 # These two dependencies are for the optional constrained decoding feature (USE_GUIDANCE)
 llguidance;https://github.com/microsoft/llguidance.git;94fa39128ef184ffeda33845f6d333f332a34b4d
```

examples/python/common.py
Lines changed: 112 additions & 27 deletions

```diff
@@ -4,10 +4,11 @@
 import argparse
 import json
 import os
+from dataclasses import asdict, dataclass
+from typing import Any
+
 import onnxruntime_genai as og

-from dataclasses import dataclass, asdict
-from typing import Any

 def set_logger(inputs: bool = True, outputs: bool = True) -> None:
     """
@@ -21,6 +22,7 @@ def set_logger(inputs: bool = True, outputs: bool = True) -> None:
     """
     og.set_log_options(enabled=True, model_input_values=inputs, model_output_values=outputs)

+
 def register_ep(ep: str, ep_path: str, use_winml: bool) -> None:
     """
     Register execution provider if path is provided or via Windows ML
@@ -42,6 +44,7 @@ def register_ep(ep: str, ep_path: str, use_winml: bool) -> None:
         # Modified from here: https://learn.microsoft.com/en-us/windows/ai/new-windows-ml/tutorial?tabs=python#acquiring-the-model-and-preprocessing
         try:
             import winml
+
             print(winml.register_execution_providers(ort=False, ort_genai=True))
         except ImportError:
             print("WinML not available, using default execution providers")
@@ -53,11 +56,14 @@ def register_ep(ep: str, ep_path: str, use_winml: bool) -> None:
         og.register_execution_provider_library("NvTensorRTRTXExecutionProvider", ep_path)
     else:
         print(f"Warning: EP registration not supported for {ep}")
-        print("Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries. Use Windows ML via '--use_winml' to register EPs.")
+        print(
+            "Only 'cuda' and 'NvTensorRtRtx' support plug-in libraries. Use Windows ML via '--use_winml' to register EPs."
+        )
         return

     print(f"Registered {ep} successfully!")

+
 def get_config(path: str, ep: str, ep_options: dict[str, str] = {}, search_options: dict[str, int] = {}) -> og.Config:
     """
     Get og.Config object and set EP-specific and search-specific options inside it
@@ -98,6 +104,7 @@ def get_config(path: str, ep: str, ep_options: dict[str, str] = {}, search_optio
     config.overlay(json.dumps({"search": search_options}))
     return config

+
 def get_search_options(args: argparse.Namespace):
     """
     Get search options for a generator's params during decoding
@@ -128,7 +135,10 @@
     search_options["batch_size"] = search_options.get("batch_size", 1)
     return search_options

-def apply_chat_template(model_path: str, tokenizer: og.Tokenizer, messages: str, add_generation_prompt: bool, tools: str = "") -> str:
+
+def apply_chat_template(
+    model_path: str, tokenizer: og.Tokenizer, messages: str, add_generation_prompt: bool, tools: str = ""
+) -> str:
     """
     Apply the chat template with various fallback options

@@ -151,6 +161,7 @@ def apply_chat_template(model_path: str, tokenizer: og.Tokenizer, messages: str,
     )
     return prompt

+
 def get_user_prompt(prompt: str, non_interactive: bool) -> str:
     """
     Get prompt for 'user' role in chat template
@@ -179,6 +190,7 @@ def get_user_prompt(prompt: str, non_interactive: bool) -> str:

     return text

+
 def get_user_media_paths(media_paths: list[str], non_interactive: bool, media_type: str) -> list[str]:
     """
     Get paths to media for user
@@ -202,7 +214,9 @@ def get_user_media_paths(media_paths: list[str], non_interactive: bool, media_ty
     # If interactive mode is on
     paths = [
         path.strip()
-        for path in input(f"{media_type.capitalize()} Path (comma separated; leave empty if no {media_type}): ").split(",")
+        for path in input(
+            f"{media_type.capitalize()} Path (comma separated; leave empty if no {media_type}): "
+        ).split(",")
     ]

     paths = [path for path in paths if path]
@@ -213,6 +227,7 @@ def get_user_media_paths(media_paths: list[str], non_interactive: bool, media_ty

     return paths

+
 def get_user_images(image_paths: list[str], non_interactive: bool) -> tuple[og.Images, int]:
     """
     Get images for user
@@ -232,6 +247,7 @@ def get_user_images(image_paths: list[str], non_interactive: bool) -> tuple[og.I
     images = og.Images.open(*paths)
     return images, len(paths)

+
 def get_user_audios(audio_paths: list[str], non_interactive: bool) -> tuple[og.Audios, int]:
     """
     Get audios for user
@@ -251,6 +267,7 @@ def get_user_audios(audio_paths: list[str], non_interactive: bool) -> tuple[og.A
     audios = og.Audios.open(*paths)
     return audios, len(paths)

+
 def get_user_content(model_type: str, num_images: int, num_audios: int, prompt: str) -> str | list[dict[str, str]]:
     """
     Get content for 'user' role in chat template
@@ -284,49 +301,59 @@ def get_user_content(model_type: str, num_images: int, num_audios: int, prompt:
         image_tags = "".join(["[IMG]" for _ in range(num_images)])
         content = image_tags + prompt
     else:
-        # Gemma-3 style: structured content
+        # Gemma-3/4 style: structured content with image and audio entries
         image_tags = [{"type": "image"} for _ in range(num_images)]
-        content = image_tags + [{"type": "text", "text": prompt}]
+        audio_tags = [{"type": "audio"} for _ in range(num_audios)]
+        content = image_tags + audio_tags + [{"type": "text", "text": prompt}]
     return content

+
 @dataclass
 class ToolSchema:
     """
     A class for defining a tool in a JSON schema compatible way
     """
+
     description: str
     type: str
     properties: dict[str, Any]
     required: list[str]
     additionalProperties: bool

+
 @dataclass
 class JsonSchema:
     """
     A class for defining a JSON schema for guidance
     """
+
     x_guidance: dict[str, Any]
     type: str
     items: dict[str, list[ToolSchema]]
     minItems: int

+
 @dataclass
 class FunctionDefinition:
     """
     A class for defining a function in an OpenAI-compatible way
     """
+
     name: str
     description: str
     parameters: dict[str, Any]

+
 @dataclass
 class Tool:
     """
     A class for defining a tool in an OpenAI-compatible way
     """
+
     type: str
     function: FunctionDefinition

+
 def tools_to_schemas(tools: list[Tool]) -> list[ToolSchema]:
     """
     Convert a list of tools to a list of tool schemas
@@ -360,6 +387,7 @@ def tools_to_schemas(tools: list[Tool]) -> list[ToolSchema]:
         tool_schemas.append(tool_schema)
     return tool_schemas

+
 def get_json_schema(tools: list[Tool], tool_output: bool) -> str:
     """
     Create a JSON schema from a list of tools
@@ -376,6 +404,7 @@ def get_json_schema(tools: list[Tool], tool_output: bool) -> str:
     d = {k.replace("x_guidance", "x-guidance"): v for k, v in asdict(json_schema).items()}
     return json.dumps(d)

+
 def get_lark_grammar(
     tools: list[Tool],
     text_output: bool,
@@ -423,6 +452,7 @@ def get_lark_grammar(

     return "\n".join(rows)

+
 def to_tool(tool_defs: list[dict[str, Any]]) -> list[Tool]:
     """
     Convert a JSON-deserialized object of tools to a list of Tool objects
@@ -443,6 +473,7 @@ def to_tool(tool_defs: list[dict[str, Any]]) -> list[Tool]:
         tools.append(tool)
     return tools

+
 def get_guidance(
     response_format: str = "",
     filepath: str = "",
@@ -474,7 +505,7 @@
     if tool_output:
         if os.path.exists(filepath):
             # If tools are provided as a file
-            with open(filepath, 'r') as f:
+            with open(filepath) as f:
                 tool_defs = json.load(f)
             tools = to_tool(tool_defs)
         elif tools_str != "":
@@ -488,14 +519,18 @@
         if type(tools[0]) != Tool:
             tools = to_tool(tools)
     else:
-        raise ValueError("Please provide the list of tools through a file, JSON-serialized string, or a list of tools")
+        raise ValueError(
+            "Please provide the list of tools through a file, JSON-serialized string, or a list of tools"
+        )

     assert len(tools) > 0, "Could not obtain a list of tools in memory"

     # Create guidance based on user-provided response format
     if response_format in {"text", "lark_grammar"}:
         if response_format == "text":
-            assert text_output and not tool_output, "A response format of 'text' requires text_output = True and tool_output = False"
+            assert text_output and not tool_output, (
+                "A response format of 'text' requires text_output = True and tool_output = False"
+            )

             guidance_type = "lark_grammar"
             guidance_data = get_lark_grammar(
@@ -506,7 +541,9 @@
                 tool_call_end=tool_call_end,
             )
     elif response_format in {"json_schema", "json_object"}:
-        assert tool_output and not text_output, "A response format of 'json_schema' or 'json_object' requires text_output = False and tool_output = True"
+        assert tool_output and not text_output, (
+            "A response format of 'json_schema' or 'json_object' requires text_output = False and tool_output = True"
+        )

         guidance_type = "json_schema"
         guidance_data = get_json_schema(tools=tools, tool_output=tool_output)
@@ -515,6 +552,7 @@

     return guidance_type, guidance_data, json.dumps([asdict(tool) for tool in tools])

+
 def get_generator_params_args(parser: argparse.ArgumentParser) -> None:
     """
     Add an argument group for the generator params
@@ -525,16 +563,34 @@ def get_generator_params_args(parser: argparse.ArgumentParser) -> None:
     None
     """
     generator_params = parser.add_argument_group("Generator Params")
-    generator_params.add_argument('-c', '--chunk_size', type=int, default=0, help="Chunk size for prefill chunking during context processing (default: 0 = disabled, >0 = enabled)")
-    generator_params.add_argument('-s', '--do_sample', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
-    generator_params.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
-    generator_params.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
-    generator_params.add_argument('-b', '--num_beams', type=int, default=1, help='Number of beams to create')
-    generator_params.add_argument('-rs', '--num_return_sequences', type=int, default=1, help='Number of return sequences to produce')
-    generator_params.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
-    generator_params.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
-    generator_params.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
-    generator_params.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
+    generator_params.add_argument(
+        "-c",
+        "--chunk_size",
+        type=int,
+        default=0,
+        help="Chunk size for prefill chunking during context processing (default: 0 = disabled, >0 = enabled)",
+    )
+    generator_params.add_argument(
+        "-s",
+        "--do_sample",
+        action="store_true",
+        help="Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false",
+    )
+    generator_params.add_argument(
+        "-i", "--min_length", type=int, help="Min number of tokens to generate including the prompt"
+    )
+    generator_params.add_argument(
+        "-l", "--max_length", type=int, help="Max number of tokens to generate including the prompt"
+    )
+    generator_params.add_argument("-b", "--num_beams", type=int, default=1, help="Number of beams to create")
+    generator_params.add_argument(
+        "-rs", "--num_return_sequences", type=int, default=1, help="Number of return sequences to produce"
+    )
+    generator_params.add_argument("-r", "--repetition_penalty", type=float, help="Repetition penalty to sample with")
+    generator_params.add_argument("-t", "--temperature", type=float, help="Temperature to sample with")
+    generator_params.add_argument("-k", "--top_k", type=int, help="Top k tokens to sample from")
+    generator_params.add_argument("-p", "--top_p", type=float, help="Top p probability to sample with")
+

 def get_guidance_args(parser: argparse.ArgumentParser) -> None:
     """
@@ -546,9 +602,38 @@ def get_guidance_args(parser: argparse.ArgumentParser) -> None:
     None
     """
     guidance = parser.add_argument_group("Guidance Arguments")
-    guidance.add_argument('-rf', '--response_format', type=str, default="", choices=["", "text", "json_object", "json_schema", "lark_grammar"], help='Provide response format for the model')
-    guidance.add_argument('-tf', '--tools_file', type=str, default="", help='Path to file containing list of OpenAI-compatible tool definitions. Ex: test/test_models/tool-definitions/weather.json')
-    guidance.add_argument('-text', '--text_output', action='store_true', default=False, help='Produce a text response in the output')
-    guidance.add_argument('-tool', '--tool_output', action='store_true', default=False, help='Produce a tool call in the output')
-    guidance.add_argument('-tcs', '--tool_call_start', type=str, default="", help='String representation of tool call start (ex: <|tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.')
-    guidance.add_argument('-tce', '--tool_call_end', type=str, default="", help='String representation of tool call end (ex: <|/tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.')
+    guidance.add_argument(
+        "-rf",
+        "--response_format",
+        type=str,
+        default="",
+        choices=["", "text", "json_object", "json_schema", "lark_grammar"],
+        help="Provide response format for the model",
+    )
+    guidance.add_argument(
+        "-tf",
+        "--tools_file",
+        type=str,
+        default="",
+        help="Path to file containing list of OpenAI-compatible tool definitions. Ex: test/test_models/tool-definitions/weather.json",
+    )
+    guidance.add_argument(
+        "-text", "--text_output", action="store_true", default=False, help="Produce a text response in the output"
+    )
+    guidance.add_argument(
+        "-tool", "--tool_output", action="store_true", default=False, help="Produce a tool call in the output"
+    )
+    guidance.add_argument(
+        "-tcs",
+        "--tool_call_start",
+        type=str,
+        default="",
+        help="String representation of tool call start (ex: <|tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.",
+    )
+    guidance.add_argument(
+        "-tce",
+        "--tool_call_end",
+        type=str,
+        default="",
+        help="String representation of tool call end (ex: <|/tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.",
+    )
```
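For reference, with the `get_user_content` change above, a Gemma-style request with two images and one audio clip produces structured content like the sketch below. The `messages` wrapper follows the usual chat-template convention, and the prompt text is made up.

```python
# What get_user_content builds for a Gemma-3/4 style model with
# num_images=2 and num_audios=1 (mirrors the list comprehensions in the diff).
content = [
    {"type": "image"},
    {"type": "image"},
    {"type": "audio"},
    {"type": "text", "text": "Describe the pictures and transcribe the audio."},
]
messages = [{"role": "user", "content": content}]
```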

src/config.cpp
Lines changed: 6 additions & 0 deletions

```diff
@@ -648,6 +648,8 @@ struct VisionInputs_Element : JSON::Element {
   void OnValue(std::string_view name, JSON::Value value) override {
     if (name == "pixel_values") {
       v_.pixel_values = JSON::Get<std::string_view>(value);
+    } else if (name == "pixel_position_ids") {
+      v_.pixel_position_ids = JSON::Get<std::string_view>(value);
     } else if (name == "image_sizes") {
       v_.image_sizes = JSON::Get<std::string_view>(value);
     } else if (name == "image_grid_thw") {
@@ -1096,6 +1098,10 @@ struct Model_Element : JSON::Element {
       v_.sep_token_id = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "image_token_id") {
       v_.image_token_id = static_cast<int>(JSON::Get<double>(value));
+    } else if (name == "audio_token_id") {
+      v_.audio_token_id = static_cast<int>(JSON::Get<double>(value));
+    } else if (name == "boa_token_id") {
+      v_.boa_token_id = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "video_token_id") {
       v_.video_token_id = static_cast<int>(JSON::Get<double>(value));
     } else if (name == "vision_start_token_id") {
```
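To show where the new keys are read, here is a hypothetical `genai_config.json` fragment assembled in Python, matching the parsers above: `pixel_position_ids` under the vision model's inputs, and `audio_token_id`/`boa_token_id` at the model level. The token-ID values, the file name, and any nesting beyond what the diff shows are placeholders.

```python
# Hypothetical genai_config fragment; IDs and file names are placeholders.
import json

fragment = {
    "model": {
        "audio_token_id": 22222,  # parsed by Model_Element above
        "boa_token_id": 33333,  # parsed by Model_Element above
        "vision": {
            "inputs": {
                "pixel_values": "pixel_values",
                "pixel_position_ids": "pixel_position_ids",  # new vision input
            }
        },
        # Per the summary, the presence of speech.filename enables audio support.
        "speech": {"filename": "speech.onnx"},
    }
}
print(json.dumps(fragment, indent=2))
```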
