# Add Gemma4 multimodal support (vision + audio) (microsoft#2103)
## Summary
Adds end-to-end support for Google Gemma 4 multimodal models in ORT
GenAI, covering text-only (`gemma4_text`), vision-language, and
any-to-any (vision + audio + text) variants.
## Changes
### Model registration
- Register `gemma4_text` as LLM and `gemma4` as MMM (multi-modal model)
- MMM auto-detects speech support from `speech.filename` in genai_config, so
no separate `gemma4_any_to_any` type is needed (see the sketch below)
- Register `Gemma4MultiModalProcessor` in the processor factory
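
To picture the auto-detection rule, here is a minimal Python sketch. It assumes the usual `genai_config.json` layout with the speech model declared under `model.speech.filename`; the actual check lives in the C++ registration path.

```python
import json

# Sketch only: mirrors the rule "speech support is enabled iff the config
# declares a speech model file". Field placement is assumed, not verified.
with open("genai_config.json") as f:
    cfg = json.load(f)

speech_cfg = cfg.get("model", {}).get("speech", {})
has_speech = bool(speech_cfg.get("filename"))
print("audio branch enabled:", has_speech)
```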
### Gemma4 multimodal processor (`gemma4_multimodal_processor.cpp/h`)
- **Vision**: Preprocesses images via `Gemma4ImageTransform`
(onnxruntime-extensions), trims padded patches to actual count using
`num_soft_tokens` from preprocessor, produces `pixel_values` +
`pixel_position_ids`
- **Audio**: Extracts mel features via `Gemma4LogMel`, computes
`audio_sizes` for the pipeline, generates `input_features_mask`
(all-True for single-clip inference), and expands `<|audio|>`
placeholder tokens in the prompt
- **Prompt handling**: Expands both `<|image|>` and `<|audio|>` tokens
from the chat template into the correct number of soft tokens before
encoding. Handles template-inserted tokens and auto-insertion when no
template is available
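
As a rough picture of the prompt handling, the expansion step behaves like the Python sketch below; token IDs and soft-token counts are illustrative, and the real logic sits in `Gemma4MultiModalProcessor`.

```python
def expand_placeholders(token_ids, placeholder_id, num_soft_tokens):
    # Replace each placeholder token (e.g. <|image|> or <|audio|>) with
    # num_soft_tokens copies so the embedding model sees one slot per
    # image patch / audio frame produced by the preprocessor.
    expanded = []
    for tok in token_ids:
        if tok == placeholder_id:
            expanded.extend([placeholder_id] * num_soft_tokens)
        else:
            expanded.append(tok)
    return expanded

# e.g. one <|image|> placeholder expanded to 256 soft tokens (count illustrative)
# prompt_ids = expand_placeholders(prompt_ids, image_token_id, 256)
```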
### KV cache — per-layer head_dim (`kv_cache.cpp/h`)
- Auto-detects varying `head_dim` across layers from ONNX session input
shapes; Gemma4 uses 256 for sliding-window layers and 512 for global
attention layers (see the sketch after this list)
- Creates per-layer `empty_pasts_` with correct head dimensions
- Handles `layer_shapes_[i][2] == 0` (unconstrained) in Update to avoid
zero-size allocation
- Updates `layer_shapes_` sequence dimension for
`past_present_share_buffer` mode
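
The per-layer detection can be pictured with the ONNX Runtime Python API. This is an illustrative restatement rather than the `kv_cache.cpp` implementation, and the `past_key_values.<layer>.key` naming is an assumption.

```python
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")

head_dims = {}
for inp in sess.get_inputs():
    # KV inputs are typically named past_key_values.<layer>.key / .value with
    # shape [batch, num_kv_heads, past_seq_len, head_dim]
    if inp.name.startswith("past_key_values.") and inp.name.endswith(".key"):
        layer = int(inp.name.split(".")[1])
        head_dims[layer] = inp.shape[-1]  # 256 or 512 for Gemma4

print(head_dims)  # e.g. {0: 256, 1: 256, ..., 5: 512, ...}
```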
### Position inputs — int64 support (`position_inputs.cpp`)
- `WindowedPositionInputs` now supports both `int32_t` and `int64_t` for
`position_ids` and `attention_mask`
- Type-dispatching lambdas for all data access points (first window,
subsequent windows, token generation)
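
In effect, the position inputs are built with whichever integer width the model declares. A small numpy sketch of the idea, assuming the usual ORT GenAI input names:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx")
declared = {i.name: i.type for i in sess.get_inputs()}

seq_len = 16  # example prompt length
pos_dtype = np.int64 if declared.get("position_ids") == "tensor(int64)" else np.int32

position_ids = np.arange(seq_len, dtype=pos_dtype)[np.newaxis, :]
attention_mask = np.ones((1, seq_len), dtype=pos_dtype)
```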
### Multi-modal pipeline (`multi_modal.cpp/h`)
- **DecoderState**: Optional `decoder_input_ids_` for models requiring
`input_ids` alongside `inputs_embeds`
- **EmbeddingState**: Handles empty `audio_features` tensor when
embedding model requires it but no speech session exists
(`AllocateEmptyFeatures`)
- **SpeechState**: Manages 3D→2D reshape of speech output
(`ReshapeFeatures`) before passing to embedding model
- **Pipeline**: Conditional audio feature reshape and empty audio
fallback based on `num_audio_tokens_`
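
A simplified numpy picture of the audio branch of the pipeline; shapes and the feature dimension are illustrative, and the real logic lives in `SpeechState`/`EmbeddingState`.

```python
import numpy as np

feature_dim = 1536    # illustrative value
num_audio_tokens = 0  # e.g. no <|audio|> placeholders in this request

if num_audio_tokens > 0:
    # speech model output arrives as 3D [num_clips, frames, feature_dim];
    # the embedding model expects 2D, so collapse the leading dimensions
    speech_out = np.zeros((1, 188, feature_dim), dtype=np.float32)
    audio_features = speech_out.reshape(-1, feature_dim)
else:
    # the embedding model still declares audio_features, so feed an empty tensor
    audio_features = np.zeros((0, feature_dim), dtype=np.float32)
```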
### MultiModalFeatures (`multi_modal_features.cpp/h`)
- `AllocateEmptyFeatures()` — pre-allocates empty tensor for optional
inputs
- `ReshapeFeatures()` — in-place reshape with data copy and state
pointer update
- `batch_size <= 0` support — skip batch dimension for 3D model outputs
### Config (`config.h/cpp`)
- Added `pixel_position_ids` to vision inputs
- Added `audio_token_id` and `boa_token_id` to model config
- Added `PixelPositionIdsName` default constant
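
Roughly, the new fields surface in `genai_config.json` as sketched below (shown as Python dicts for illustration; placement and numeric values are assumptions, only the field names come from this change):

```python
# Illustrative shape of the new config fields; numeric values are placeholders.
model_config_additions = {
    "audio_token_id": 262273,  # id of the <|audio|> placeholder token
    "boa_token_id": 256001,    # begin-of-audio token id
}
vision_inputs_additions = {
    "pixel_position_ids": "pixel_position_ids",  # default from PixelPositionIdsName
}
```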
### Example script (`common.py`)
- Added `{"type": "audio"}` entries for Gemma-style structured content
in `get_user_content`
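
The structured content produced for a mixed request then looks roughly like this (a sketch of the shape only; the surrounding code in `common.py` is elided):

```python
# Gemma-style structured content with the new audio entries; values illustrative.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},  # one entry per attached image
            {"type": "audio"},  # new: one entry per attached audio clip
            {"type": "text", "text": "Describe the image and transcribe the audio."},
        ],
    }
]
```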
## Testing
Tested with Gemma4 E2B model exported via mobius:
- ✅ Text-only generation
- ✅ Image description (detailed landscape analysis)
- ✅ Audio transcription (Windows SAPI TTS → model correctly identifies
speech content)
- ✅ Image-only with any-to-any config (empty audio_features handled)
- ✅ Mixed GQA + standard Attention with
`past_present_share_buffer=false`
- ✅ Per-layer head_dim KV cache (256/512)
- ✅ int64 position_ids with `WindowedPositionInputs`
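
For reference, a minimal end-to-end invocation along the lines used for these checks. This is a sketch that assumes the existing ORT GenAI Python multimodal API (`og.Images`/`og.Audios`, `create_multimodal_processor`); the model path, prompt template, and placeholder strings are illustrative.

```python
import onnxruntime_genai as og

model = og.Model("gemma4-e2b")  # path illustrative
processor = model.create_multimodal_processor()
stream = processor.create_stream()

images = og.Images.open("landscape.jpg")
audios = og.Audios.open("speech.wav")
prompt = ("<start_of_turn>user\n<|image|><|audio|>"
          "Describe the image and transcribe the audio.<end_of_turn>\n"
          "<start_of_turn>model\n")

inputs = processor(prompt, images=images, audios=audios)
params = og.GeneratorParams(model)
params.set_search_options(max_length=512)
# On newer releases the inputs may be attached via generator.set_inputs(inputs)
# after constructing the generator instead of params.set_inputs(inputs).
params.set_inputs(inputs)

generator = og.Generator(model, params)
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
print()
```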
Related diff excerpt (the argparse declarations in the Python example were reformatted into multi-line calls with double-quoted strings):

```diff
-generator_params.add_argument('-c', '--chunk_size', type=int, default=0, help="Chunk size for prefill chunking during context processing (default: 0 = disabled, >0 = enabled)")
-generator_params.add_argument('-s', '--do_sample', action='store_true', help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
-generator_params.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
-generator_params.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
-generator_params.add_argument('-b', '--num_beams', type=int, default=1, help='Number of beams to create')
-generator_params.add_argument('-rs', '--num_return_sequences', type=int, default=1, help='Number of return sequences to produce')
-generator_params.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
-generator_params.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
-generator_params.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
-generator_params.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
+generator_params.add_argument(
+    "-c",
+    "--chunk_size",
+    type=int,
+    default=0,
+    help="Chunk size for prefill chunking during context processing (default: 0 = disabled, >0 = enabled)",
+)
+generator_params.add_argument(
+    "-s",
+    "--do_sample",
+    action="store_true",
+    help="Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false",
+)
+generator_params.add_argument(
+    "-i", "--min_length", type=int, help="Min number of tokens to generate including the prompt"
+)
+generator_params.add_argument(
+    "-l", "--max_length", type=int, help="Max number of tokens to generate including the prompt"
+)
+generator_params.add_argument("-b", "--num_beams", type=int, default=1, help="Number of beams to create")
+generator_params.add_argument(
+    "-rs", "--num_return_sequences", type=int, default=1, help="Number of return sequences to produce"
+)
+generator_params.add_argument("-r", "--repetition_penalty", type=float, help="Repetition penalty to sample with")
+generator_params.add_argument("-t", "--temperature", type=float, help="Temperature to sample with")
+generator_params.add_argument("-k", "--top_k", type=int, help="Top k tokens to sample from")
+generator_params.add_argument("-p", "--top_p", type=float, help="Top p probability to sample with")

-guidance.add_argument('-rf', '--response_format', type=str, default="", choices=["", "text", "json_object", "json_schema", "lark_grammar"], help='Provide response format for the model')
-guidance.add_argument('-tf', '--tools_file', type=str, default="", help='Path to file containing list of OpenAI-compatible tool definitions. Ex: test/test_models/tool-definitions/weather.json')
-guidance.add_argument('-text', '--text_output', action='store_true', default=False, help='Produce a text response in the output')
-guidance.add_argument('-tool', '--tool_output', action='store_true', default=False, help='Produce a tool call in the output')
-guidance.add_argument('-tcs', '--tool_call_start', type=str, default="", help='String representation of tool call start (ex: <|tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.')
-guidance.add_argument('-tce', '--tool_call_end', type=str, default="", help='String representation of tool call end (ex: <|/tool_call|>). Needs to be marked as special in tokenizer.json for guidance to work.')
```