
Commit 25f2a3f

Voxtral Realtime: (1) convert use_standard_sdpa to backend flag (2) consistent md files (#17749)
1 parent dae7a02 commit 25f2a3f

3 files changed

Lines changed: 61 additions & 56 deletions

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 6 additions & 6 deletions
@@ -12,15 +12,16 @@
 - token_embedding: token_ids (1, seq_len) -> embeds (1, seq_len, 3072)

 With --streaming, produces a streaming .pte instead:
-- encode_audio_chunk: mel_chunk (1,128,8) + conv states + enc_pos -> audio_embeds + new states
+- encode_audio_chunk: mel_chunk (1,128,8) + enc_pos (4,) -> audio_embeds (1,1,3072)
 - text_decoder: same as above
 - token_embedding: same as above

 Backend support:
 - XNNPACK (default): Uses custom SDPA op (torch.ops.llama.custom_sdpa) for optimal performance
-- Metal/AOTI: Automatically switches to standard PyTorch SDPA (F.scaled_dot_product_attention)
-  for text_decoder to avoid AOTI compilation issues. Uses Dim.AUTO for audio encoder
-  dynamic shapes (explicit bounds cause issues with AOTI). All components run on Metal GPU.
+- Metal/AOTI: Uses MetalSDPA (_scaled_dot_product_attention_math_for_mps) for text_decoder
+  and StandardEncoderSDPA (F.scaled_dot_product_attention) for streaming encoder,
+  avoiding custom_sdpa which is incompatible with AOTI. Uses Dim.AUTO for audio
+  encoder dynamic shapes (explicit bounds cause issues with AOTI).
 - Portable: Uses custom SDPA like XNNPACK

 Usage:
@@ -475,12 +476,11 @@ def main():

     # Load model
     print("Loading model...")
-    use_standard_attention = args.backend == "metal"
     model = load_model(
         args.model_path,
         max_seq_len=args.max_seq_len,
         n_delay_tokens=args.delay_tokens,
-        use_standard_attention=use_standard_attention,
+        backend=args.backend,
     )

     # Untie output/embedding weights before quantization so each layer gets

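The hunk above wires `args.backend` straight through to `load_model`. As a minimal, hypothetical sketch of how such a flag can be constrained at the CLI layer (only the three backend names come from this commit; the parser itself is illustrative, not the real export script):

```python
import argparse

# Hypothetical stand-in for the export script's CLI; only the backend
# names ("xnnpack", "metal", "portable") are taken from this commit.
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="export sketch")
    parser.add_argument(
        "--backend",
        default="xnnpack",
        choices=("xnnpack", "metal", "portable"),
        help="delegate backend baked into the exported .pte",
    )
    parser.add_argument("--streaming", action="store_true")
    return parser

args = build_parser().parse_args(["--backend", "metal", "--streaming"])
print(args.backend, args.streaming)
```

Using `choices` makes an unsupported backend fail at argument parsing rather than deep inside model construction.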
examples/models/voxtral_realtime/model.md

Lines changed: 38 additions & 36 deletions
@@ -41,8 +41,9 @@ The model exports three methods (offline mode):
 | `token_embedding` | token IDs `(1, seq_len)` | embeddings `(1, seq_len, 3072)` |

 With `--streaming`, `audio_encoder` is replaced by `encode_audio_chunk`
-which takes a mel chunk `(1, 128, 8)` + conv states + encoder positions
-and returns audio embeddings `(1, 1, 3072)` + updated conv states.
+which takes a mel chunk `(1, 128, 8)` + encoder positions `(4,)` and
+returns audio embeddings `(1, 1, 3072)`. Conv states are maintained as
+internal buffers.

 Audio and text embeddings are **summed** at each position (not concatenated
 or masked-scatter like the original non-realtime Voxtral).
@@ -101,21 +102,21 @@ VoxtralRealtimeModel
     attention: LMAttention
       wq/wk/wv/wo: Linear (no bias)
       kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal)
-      sdpa: SDPA (XNNPACK) or StandardSDPA (Metal)
+      sdpa: SDPA (XNNPACK) or MetalSDPA (Metal)
     ffn_norm: RMSNorm
     ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
     feed_forward: LMMLP (w1/w2/w3)
   norm: RMSNorm
   output: Linear (tied to tok_embeddings)

-StreamingAudioEncoderExport (XNNPACK/Portable only)
+StreamingAudioEncoderExport
   conv1: nn.Conv1d (shared from encoder.conv_layers[0].conv)
   conv2: nn.Conv1d (shared from encoder.conv_layers[1].conv)
   layers: 32x CausalEncoderLayer (shared from encoder.layers)
   enc_norm: RMSNorm (shared from encoder.norm)
   adapter: AudioLanguageAdapter (shared from model.adapter)
-  kv_caches: 32x EncoderRingKVCache (ring buffer for sliding window attention)
-  sdpa: SDPA (for streaming attention with custom_sdpa op)
+  kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal)
+  sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal)
   inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
 ```

@@ -137,11 +138,8 @@ than 750 encoder frames (~15s), full causal is equivalent.

 The text decoder (`MistralDecoder`) is a 26-layer Mistral decoder with
 GQA (32 query heads, 8 KV heads). Backend selection is controlled by the
-`use_standard_attention` config flag, set by the export script:
-
-```python
-use_standard_attention = (args.backend == "metal")
-```
+`backend` config field, passed through from the export script's `--backend`
+flag (e.g., `"xnnpack"`, `"metal"`, `"portable"`).

 ### KV cache

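Backend selection in the decoder, as described above, is a string-keyed choice of cache and SDPA implementations. A minimal illustrative sketch of that dispatch pattern (the class names come from this doc, but the table and function here are hypothetical, not the actual model code):

```python
# Illustrative stand-ins for the real cache/SDPA classes named in the doc.
class KVCache: ...
class StaticKVCache: ...
class SDPA: ...
class MetalSDPA: ...

# Per the doc: Metal gets StaticKVCache + MetalSDPA; XNNPACK and
# Portable share the custom-op path (KVCache + SDPA).
_DECODER_BACKEND_TABLE = {
    "xnnpack": (KVCache, SDPA),
    "metal": (StaticKVCache, MetalSDPA),
    "portable": (KVCache, SDPA),
}

def select_decoder_classes(backend: str):
    """Map a backend string to (cache class, sdpa class)."""
    try:
        return _DECODER_BACKEND_TABLE[backend]
    except KeyError:
        raise ValueError(f"Unknown backend '{backend}'") from None
```

In the real model this choice happens inside `__init__` via `if self.backend == "metal":` branches, as shown in the model.py diff further down.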
@@ -164,9 +162,10 @@ backend-specific implementations.
 fused kernel with causal masking via `start_pos` + `is_causal=True`.
 Handles GQA expansion internally and upcasts to float32.

-**Metal:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
-explicit attention masks. AOTInductor has compatibility issues with the
-`custom_sdpa` custom op.
+**Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
+which handles GQA natively via `gqa_factor`, avoiding the memory bandwidth
+overhead of `repeat_interleave`. Uses explicit additive attention masks.
+AOTInductor has compatibility issues with the `custom_sdpa` custom op.

 ### Attention layout

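The hunk above contrasts `gqa_factor`-style native GQA with materializing KV heads via `repeat_interleave`. The equivalence being exploited (32 query heads over 8 KV heads, so a grouping factor of 4) is that query head `h` reads KV head `h // gqa_factor`; a toy sketch with plain lists, not the actual kernels:

```python
# GQA layout from the doc: 32 query heads share 8 KV heads.
N_HEADS, N_KV_HEADS = 32, 8
GQA_FACTOR = N_HEADS // N_KV_HEADS  # 4

kv_heads = [f"kv{h}" for h in range(N_KV_HEADS)]

# repeat_interleave-style expansion: materialize one KV copy per query head,
# paying memory bandwidth for 4x duplicated K/V data.
expanded = [kv for kv in kv_heads for _ in range(GQA_FACTOR)]

# gqa_factor-style indexing: map each query head to its KV head on the fly,
# without materializing the expanded tensor.
def kv_for_query(h: int) -> str:
    return kv_heads[h // GQA_FACTOR]

# Both strategies pair every query head with the same KV head.
assert all(expanded[h] == kv_for_query(h) for h in range(N_HEADS))
```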
@@ -178,8 +177,9 @@ explicit attention masks. AOTInductor has compatibility issues with the
 require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.

 **Metal:** Q/K/V projections still produce `[B, T, H, D]`, but
-`StaticKVCache` stores `[B, H, S, D]` and `StandardSDPA` transposes q to
-`[B, H, T, D]` for `F.scaled_dot_product_attention`, then transposes back.
+`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA` transposes q to
+`[B, H, T, D]` for `_scaled_dot_product_attention_math_for_mps`, then
+transposes back.

 ### Adaptive RMSNorm
@@ -205,28 +205,25 @@ mel at once. It shares all weights with the offline encoder but uses a
 different forward path:

 ```
-mel_chunk (1, 128, 8)
-  + conv1_state (1, 128, 2) + conv2_state (1, 1280, 2)
+mel_chunk (1, 128, 8) + enc_input_pos (4,)
+  conv1_state (1, 128, 2) and conv2_state (1, 1280, 2) are internal buffers
 -> cat(state, chunk) -> raw Conv1d (no CausalConv1d padding) -> GELU
 -> cat(state, conv1_out) -> raw Conv1d -> GELU
 (1, 1280, 4) -> transpose -> (1, 4, 1280)
--> 32x streaming encoder layer (EncoderRingKVCache + custom_sdpa)
+-> 32x streaming encoder layer (ring KV cache + SDPA)
 -> RMSNorm
 (1, 4, 1280)
 -> Reshape downsample (1, 1, 5120) -> Adapter (1, 1, 3072)
--> audio_embeds, new_conv1_state, new_conv2_state
+-> audio_embeds (1, 1, 3072)
 ```

-**XNNPACK/Portable only.** Metal does not yet support streaming mode.
-The custom ops used by `StreamingAudioEncoderExport`
-(`update_cache_with_indices`, `custom_sdpa`) are incompatible with AOTI.
-Adding Metal streaming support would require:
+**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
+custom op) and `SDPA` (`custom_sdpa`).

-- Replace `EncoderRingKVCache` with an `index_copy_`-based ring buffer
-  (similar to `StaticKVCache` but with modular index arithmetic)
-- Replace `SDPA` (`custom_sdpa`) with `StandardSDPA` using explicit
-  sliding window masks
-- These are the same patterns already used in the Metal text decoder
+**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
+buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with
+explicit sliding window masks) — the same patterns used in the Metal
+text decoder.

 ### Streaming decode loop

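The `cat(state, chunk) -> raw Conv1d (no padding)` pattern in the forward path above can be sketched in miniature: carrying the last `kernel_size - 1` inputs as state makes per-chunk "valid" convolution reproduce a causal convolution over the whole signal. This is a toy single-channel, kernel-3 example in pure Python, not the model's actual conv stack:

```python
K = 3  # kernel size; state length K - 1 = 2 matches the (.., .., 2) conv states

def valid_conv(x, w):
    # 1-D cross-correlation with no padding: output length len(x) - K + 1.
    return [sum(w[j] * x[i + j] for j in range(K)) for i in range(len(x) - K + 1)]

def streaming_conv(chunks, w):
    state = [0.0] * (K - 1)        # zero initial state = implicit left padding
    out = []
    for chunk in chunks:
        x = state + chunk          # cat(state, chunk)
        out += valid_conv(x, w)    # raw conv, no extra padding
        state = x[-(K - 1):]       # carry last K-1 samples into the next chunk
    return out

w = [0.5, -1.0, 2.0]
signal = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
chunks = [signal[i:i + 4] for i in range(0, len(signal), 4)]

# Offline reference: causal conv over the full signal (left-pad by K-1 zeros).
offline = valid_conv([0.0] * (K - 1) + signal, w)
assert streaming_conv(chunks, w) == offline
```

The same identity is why the streaming encoder can share weights with the offline encoder while processing 8-frame mel chunks one at a time.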
@@ -258,9 +255,10 @@ encoder — verified to within fp32 precision (max diff < 2e-5).

 ### Encoder KV cache

-Each of the 32 encoder transformer layers gets its own `EncoderRingKVCache`
-instance — a ring buffer that overwrites old entries when the window is
-exceeded, enabling streaming of arbitrary length audio.
+Each of the 32 encoder transformer layers gets its own ring buffer KV
+cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache`
+for Metal) that overwrites old entries when the window is exceeded,
+enabling streaming of arbitrary length audio.

 - Cache shape: `(1, 2*max_enc_len, 32, 64)` per layer. The buffer is 2x the
   window size because writes happen *before* attention. With a 1x buffer
@@ -281,10 +279,14 @@ ring buffer. This is unrelated to `max_enc_len=16384` in
 `CausalWhisperEncoder.__init__`, which is the RoPE frequency table size
 for the offline encoder.

-Cache writes use `torch.ops.llama.update_cache_with_indices` (a custom op
-that scatter-writes via an indices tensor). Write indices are computed
-analytically: `(arange(seq_len) + start_pos) % buf_size`. No mutable
-position state is needed.
+**XNNPACK/Portable:** Cache writes use `torch.ops.llama.update_cache_with_indices`
+(a custom op that scatter-writes via an indices tensor). Write indices are
+computed analytically: `(arange(seq_len) + start_pos) % buf_size`.
+
+**Metal:** Cache writes use `index_copy_` with wrapped indices
+(`input_pos % buf_size`).
+
+No mutable position state is needed in either variant.

 Position tracking is analytic — no mutable state buffer. For buffer
 slot `j` after `total_written` frames have been stored:

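The two ring-buffer write strategies in the hunk above can be sketched side by side. This is a toy scalar buffer in pure Python, not the real `(1, 2*max_enc_len, 32, 64)` cache, but it shows why both variants land frames in identical slots without any mutable position state:

```python
BUF_SIZE = 8  # toy window; the real buffer is 2 * max_enc_len slots

def write_with_indices(buf, start_pos, frames):
    # XNNPACK/Portable style: scatter-write through analytically computed
    # indices, (arange(seq_len) + start_pos) % buf_size.
    for i, frame in enumerate(frames):
        buf[(i + start_pos) % BUF_SIZE] = frame
    return buf

def write_index_copy(buf, input_pos, frames):
    # Metal style: index_copy_-like write with pre-wrapped positions,
    # input_pos % buf_size.
    for pos, frame in zip(input_pos, frames):
        buf[pos % BUF_SIZE] = frame
    return buf

# Write 12 frames into an 8-slot buffer: the oldest 4 are overwritten.
frames = list(range(12))
a = write_with_indices([None] * BUF_SIZE, 0, frames)
b = write_index_copy([None] * BUF_SIZE, list(range(12)), frames)
assert a == b
# Slot j holds the newest frame congruent to j mod BUF_SIZE:
assert a == [8, 9, 10, 11, 4, 5, 6, 7]
```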
examples/models/voxtral_realtime/model.py

Lines changed: 17 additions & 14 deletions
@@ -50,9 +50,7 @@ class VoxtralRealtimeConfig:
     downsample_factor: int = 4
     # Runtime
     max_seq_len: int = 4096
-    use_standard_attention: bool = (
-        False  # Use standard PyTorch attention instead of custom ops
-    )
+    backend: str = "xnnpack"  # "xnnpack", "metal", or "portable"

     @staticmethod
     def from_params_json(path: str) -> "VoxtralRealtimeConfig":
@@ -563,15 +561,15 @@ def __init__(self, config: VoxtralRealtimeConfig, max_seq_len: int):
         self.n_kv_heads = config.n_kv_heads
         self.head_dim = config.head_dim
         self.dim = config.dim
-        self.use_standard_attention = config.use_standard_attention
+        self.backend = config.backend

         self.wq = nn.Linear(config.dim, self.n_heads * self.head_dim, bias=False)
         self.wk = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wv = nn.Linear(config.dim, self.n_kv_heads * self.head_dim, bias=False)
         self.wo = nn.Linear(self.n_heads * self.head_dim, config.dim, bias=False)

         # Choose KV cache and SDPA based on backend
-        if self.use_standard_attention:
+        if self.backend == "metal":
             self.kv_cache = StaticKVCache(max_seq_len, self.n_kv_heads, self.head_dim)
             self.sdpa = MetalSDPA(self.n_heads, self.n_kv_heads, self.head_dim)
         else:
@@ -595,7 +593,7 @@ def forward(

         k, v = self.kv_cache.update(input_pos, k, v)

-        if self.use_standard_attention:
+        if self.backend == "metal":
             y = self.sdpa(input_pos, q, k, v, B, T, attn_mask)
         else:
             y = self.sdpa(input_pos, q, k, v, B, T)
@@ -685,7 +683,7 @@ def forward(

         # Compute attention mask once for all 26 layers (P3 optimization).
         attn_mask: torch.Tensor | None = None
-        if self.config.use_standard_attention:
+        if self.config.backend == "metal":
             max_seq_len = self.freqs_cos.shape[0]
             attn_mask = _build_attn_mask(input_pos, max_seq_len, input_embeds.device)

@@ -909,7 +907,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
         # Choose cache implementation based on backend
         cache_class = (
             StandardEncoderRingKVCache
-            if config.use_standard_attention
+            if config.backend == "metal"
             else EncoderRingKVCache
         )
         self.kv_caches = nn.ModuleList(
@@ -920,7 +918,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
         )

         # Choose SDPA based on backend
-        if config.use_standard_attention:
+        if config.backend == "metal":
             self.sdpa = StandardEncoderSDPA(config.enc_n_heads, config.enc_head_dim)
         else:
             self.sdpa = SDPA(config.enc_n_heads, config.enc_head_dim)
@@ -1067,7 +1065,7 @@ def load_model(
     max_seq_len: int = 4096,
     n_delay_tokens: int = 6,
     dtype: torch.dtype = torch.float32,
-    use_standard_attention: bool = False,
+    backend: str = "xnnpack",
 ) -> VoxtralRealtimeModel:
     """Load VoxtralRealtimeModel from a Mistral consolidated checkpoint.
@@ -1080,20 +1078,25 @@ def load_model(
         max_seq_len: Maximum sequence length for KV cache.
         n_delay_tokens: Transcription delay in tokens (default 6 = 480ms).
         dtype: Weight dtype (default: float32).
-        use_standard_attention: Use standard PyTorch attention instead of custom ops
-            (required for Metal/AOTI backends).
+        backend: Backend for acceleration ("xnnpack", "metal", or "portable").
     """
+    _VALID_BACKENDS = ("xnnpack", "metal", "portable")
+    if backend not in _VALID_BACKENDS:
+        raise ValueError(
+            f"Unknown backend '{backend}'. Must be one of {_VALID_BACKENDS}."
+        )
+
     from safetensors import safe_open

     model_dir = Path(model_path)
     config = VoxtralRealtimeConfig.from_params_json(str(model_dir / "params.json"))
     config.max_seq_len = max_seq_len
-    config.use_standard_attention = use_standard_attention
+    config.backend = backend

     print(
         f"Building model on meta device (dim={config.dim}, enc_dim={config.enc_dim}, "
         f"layers={config.n_layers}, enc_layers={config.enc_n_layers}, "
-        f"attention={'standard' if use_standard_attention else 'custom'})..."
+        f"backend={backend})..."
     )
     with torch.device("meta"):
         model = VoxtralRealtimeModel(config, max_seq_len)
