Commit 4e79ee4

Voxtral Realtime: unlimited streaming via decoder ring buffer KV cache (#18637)

The decoder previously used a flat KV cache capped at `max_seq_len` (default 4096, ~5.5 min at 80ms/step). The model's `params.json` specifies `sliding_window: 8192` for the decoder, matching how vLLM handles it, but ExecuTorch ignored it.

In streaming mode (`--streaming`), the decoder now uses `RingKVCache`/`StandardRingKVCache` (the same ring buffer classes used by the streaming encoder) with the decoder's `sliding_window=8192`, on-the-fly RoPE via `inv_freq`, and sliding window attention masks. This removes all position limits for streaming transcription.

Offline mode is unchanged: flat KV cache, precomputed RoPE table, full causal masks, bounded by `max_seq_len`.

Backwards compatible: old `.pte` files without `sliding_window` metadata are detected (`sliding_window_ == 0`) and the runner preserves the original `max_seq_len` hard stop.

cc @mattjcly
1 parent b24535b commit 4e79ee4
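The ring-buffer mechanics the commit message describes can be sketched in a few lines. This is a toy model with illustrative names, not the ExecuTorch `RingKVCache` API: positions map to slots modulo the buffer size, so arbitrarily large positions reuse storage while only the last `sliding_window` entries stay readable.

```python
# Toy sketch (illustrative, not the ExecuTorch API) of a ring-buffer KV cache:
# writes wrap modulo the slot count, reads cover only the sliding window.
class ToyRingKVCache:
    def __init__(self, window: int):
        self.window = window
        self.slots = 2 * window          # ring sized to 2x the sliding window
        self.k = [None] * self.slots     # stand-in for the K cache tensor

    def update(self, pos: int, k_new):
        # Slot index derived analytically from the absolute position;
        # old entries are overwritten once the buffer wraps.
        self.k[pos % self.slots] = k_new

    def visible(self, pos: int):
        # Keys for the last `window` positions up to and including pos.
        lo = max(0, pos - self.window + 1)
        return [self.k[p % self.slots] for p in range(lo, pos + 1)]

cache = ToyRingKVCache(window=4)
for pos in range(10):                 # stream far past any fixed max_seq_len
    cache.update(pos, f"k{pos}")
print(cache.visible(9))               # ['k6', 'k7', 'k8', 'k9']
```

Positions 8 and 9 landed in slots 0 and 1, overwriting positions 0 and 1; the sliding-window mask is what keeps attention from ever reading a stale slot.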

File tree

7 files changed: +258 −163 lines changed


.ci/scripts/export_model_artifact.sh

Lines changed: 1 addition & 1 deletion
@@ -358,7 +358,7 @@ if [ "$MODEL_NAME" = "voxtral_realtime" ]; then
   STREAMING_ARG=""
   PREPROCESSOR_ARGS="--feature_size 128 --output_file ${OUTPUT_DIR}/preprocessor.pte"
   if [ "$USE_STREAMING" = "true" ]; then
-    STREAMING_ARG="--streaming"
+    STREAMING_ARG="--streaming --sliding-window 2048"
     PREPROCESSOR_ARGS="$PREPROCESSOR_ARGS --streaming"
   else
     PREPROCESSOR_ARGS="$PREPROCESSOR_ARGS --stack_output --max_audio_len 300"

examples/models/voxtral_realtime/README.md

Lines changed: 14 additions & 7 deletions
@@ -15,9 +15,10 @@ conversion. At inference time, the C++ runner loads both `.pte` files
 and the Tekken tokenizer, then transcribes audio to text.
 
 Two modes are supported: **streaming** (process 80ms chunks in real time,
-including live microphone input) and **offline** (encode full audio, then
-decode). The examples below use streaming mode. Omit `--streaming` from
-export and run commands for offline mode.
+including live microphone input, with unlimited duration) and **offline**
+(encode full audio, then decode, bounded by `--max-seq-len`). The examples
+below use streaming mode. Omit `--streaming` from export and run commands
+for offline mode.
 
 ## Demo: streaming on Metal backend with microphone input
 
@@ -71,6 +72,7 @@ python export_voxtral_rt.py \
   --model-path ~/models/Voxtral-Mini-4B-Realtime-2602 \
   --backend xnnpack \
   --streaming \
+  --sliding-window 2048 \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder 8da4w \
   --qlinear 8da4w \
@@ -86,6 +88,7 @@ python export_voxtral_rt.py \
   --backend metal \
   --dtype bf16 \
   --streaming \
+  --sliding-window 2048 \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder fpa4w \
   --qlinear fpa4w
@@ -112,6 +115,7 @@ python export_voxtral_rt.py \
   --backend cuda \
   --dtype bf16 \
   --streaming \
+  --sliding-window 2048 \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder 4w \
   --qlinear-encoder-packing-format tile_packed_to_4d \
@@ -137,6 +141,7 @@ python export_voxtral_rt.py \
   --backend cuda-windows \
   --dtype bf16 \
   --streaming \
+  --sliding-window 2048 \
   --output-dir ./voxtral_rt_exports \
   --qlinear-encoder 4w \
   --qlinear-encoder-packing-format tile_packed_to_4d \
@@ -159,7 +164,7 @@ python export_voxtral_rt.py \
 | `--backend` | `xnnpack` | `xnnpack`, `metal`, `cuda`, `cuda-windows`, or `portable` |
 | `--dtype` | `fp32` | Model dtype: `fp32` or `bf16` |
 | `--output-dir` | `./voxtral_rt_exports` | Output directory |
-| `--max-seq-len` | `4096` | KV cache length |
+| `--max-seq-len` | `4096` | KV cache length (offline mode only; ignored with `--streaming`) |
 | `--delay-tokens` | `6` | Transcription delay in tokens (6 = 480ms) |
 | `--qlinear` | (none) | Decoder linear layer quantization (`4w`, `8w`, `8da4w`, `8da8w`, `fpa4w`) |
 | `--qlinear-group-size` | `32` | Group size for decoder linear quantization |
@@ -168,12 +173,14 @@ python export_voxtral_rt.py \
 | `--qlinear-encoder-group-size` | `32` | Group size for encoder linear quantization |
 | `--qlinear-encoder-packing-format` | (none) | Packing format for encoder 4w quantization (`tile_packed_to_4d` for CUDA) |
 | `--qembedding` | (none) | Embedding layer quantization (`8w`) |
-| `--streaming` | off | Export streaming encoder with KV cache |
+| `--streaming` | off | Export streaming model with ring buffer KV caches (unlimited duration) |
 | `--max-enc-len` | `750` | Encoder sliding window size (streaming only) |
+| `--sliding-window` | from `params.json` | Decoder sliding window size (streaming only; ignored in offline mode). Smaller values reduce memory and improve decode speed but limit context |
 
 **Notes:**
 - `fpa4w` quantization requires `--backend metal`.
 - The model was trained with `--delay-tokens 6`. Other values may degrade accuracy.
+- The decoder sliding window controls how far back the decoder can attend. At 80ms/step: 2048 = ~2.7 min, 4096 = ~5.5 min, 8192 = ~10.9 min.
 
 ## Build
 
@@ -278,8 +285,8 @@ Ctrl+C stops recording and flushes remaining text.
 
 - **Audio format**: Input must be 16kHz mono WAV. Convert with
   `ffmpeg -i input.mp3 -ar 16000 -ac 1 output.wav`.
-- **OOM during export**: Reduce `--max-seq-len` or skip encoder
-  quantization (`--qlinear-encoder`).
+- **OOM during export**: Reduce `--max-seq-len` (offline mode) or skip
+  encoder quantization (`--qlinear-encoder`).
 - **"Model was not exported with --streaming"**: Re-export with the
   `--streaming` flag. Both `--streaming` and `--mic` runner modes
   require a streaming-exported model.
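The window-to-minutes figures in the README note above follow from 80ms of audio per decoder step; a quick check:

```python
# Each decoder step consumes 80 ms of audio, so a sliding window of N
# positions covers N * 0.080 seconds of context.
for window in (2048, 4096, 8192):
    print(f"{window}: ~{window * 0.080 / 60:.1f} min")
# -> 2048: ~2.7 min, 4096: ~5.5 min, 8192: ~10.9 min
```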

examples/models/voxtral_realtime/export_voxtral_rt.py

Lines changed: 14 additions & 0 deletions
@@ -248,6 +248,7 @@ def export_all(
         "dim": model.config.dim,
         "vocab_size": model.config.vocab_size,
         "max_seq_len": max_seq_len,
+        "sliding_window": model.config.sliding_window,
     }
 
     return programs, metadata
@@ -346,6 +347,7 @@ def export_streaming(
         "enc_dim": model.config.enc_dim,
         "vocab_size": model.config.vocab_size,
         "max_seq_len": max_seq_len,
+        "sliding_window": model.config.sliding_window,
         "streaming": 1,
         "step_samples": step_samples,
         "chunk_mel_len": chunk_mel_len,
@@ -541,6 +543,14 @@ def main():
         default=750,
         help="Encoder sliding window size for streaming (default: 750).",
     )
+    parser.add_argument(
+        "--sliding-window",
+        type=int,
+        default=None,
+        help="Decoder sliding window size for streaming (default: from params.json, "
+        "typically 8192). Smaller values reduce memory and improve decode speed "
+        "but limit how far back the decoder can attend. Only used with --streaming.",
+    )
     parser.add_argument(
         "--dtype",
         default="fp32",
@@ -555,6 +565,8 @@
         parser.error("--qlinear=fpa4w can only be used with --backend=metal")
     if args.qlinear_encoder == "fpa4w" and backend_for_export != "metal":
         parser.error("--qlinear-encoder=fpa4w can only be used with --backend=metal")
+    if args.sliding_window is not None and not args.streaming:
+        parser.error("--sliding-window only applies to --streaming mode")
 
     os.makedirs(args.output_dir, exist_ok=True)
 
@@ -567,6 +579,8 @@
         n_delay_tokens=args.delay_tokens,
         dtype=model_dtype,
         backend=backend_for_export,
+        streaming=args.streaming,
+        sliding_window=args.sliding_window,
     )
 
     # Move to CUDA for CUDA backend export (AOTInductor needs CUDA tensors)

examples/models/voxtral_realtime/model.md

Lines changed: 66 additions & 38 deletions
@@ -74,10 +74,18 @@ or masked-scatter like the original non-realtime Voxtral).
 
 ## Memory Footprint
 
-Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
-fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
-32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
-bf16: ≈ 393 MB.
+Decoder KV cache depends on mode:
+- **Offline:** flat buffer sized by `max_seq_len` (default 4096).
+  26 layers × 2 × 4096 × 8 × 128 × bytes_per_elem.
+  fp32: ≈ 832 MB, bf16: ≈ 416 MB.
+- **Streaming:** ring buffer sized to 2× `sliding_window` (default 8192
+  → 16384 slots; `--sliding-window 2048` → 4096 slots).
+  26 layers × 2 × 2×sliding_window × 8 × 128 × bytes_per_elem.
+  sw=8192 fp32: ≈ 3.3 GB, bf16: ≈ 1.7 GB.
+  sw=2048 fp32: ≈ 832 MB, bf16: ≈ 416 MB.
+
+Encoder KV caches (streaming only): 32 layers × 2 × 1500 × 32 × 64 ×
+bytes_per_elem. fp32: ≈ 786 MB, bf16: ≈ 393 MB.
 
 Runtime memory = model weights (from `.pte`) + KV caches + working
 memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
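The cache sizes in the hunk above follow directly from the shapes; a quick arithmetic check (fp32 = 4 bytes per element; the decoder figures read as binary MiB/GiB, while the encoder's "786 MB" is decimal):

```python
# Reproduce the KV cache sizes quoted above. The factor 2 is the separate
# K and V buffers per layer.
def kv_bytes(layers, slots, kv_heads, head_dim, bytes_per_elem):
    return layers * 2 * slots * kv_heads * head_dim * bytes_per_elem

MiB, GiB = 2**20, 2**30

# Offline decoder: flat cache, max_seq_len = 4096 slots
print(kv_bytes(26, 4096, 8, 128, 4) / MiB)        # 832.0 MiB fp32
# Streaming decoder: ring buffer with 2 * sliding_window = 16384 slots
print(kv_bytes(26, 2 * 8192, 8, 128, 4) / GiB)    # 3.25 GiB fp32 (the "3.3 GB")
# Streaming encoder: 32 layers, 1500 slots, 32 heads, head_dim 64
print(kv_bytes(32, 1500, 32, 64, 4) / 1e6)        # 786.432 decimal MB fp32
```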
@@ -103,7 +111,7 @@ VoxtralRealtimeModel
     attention_norm: RMSNorm
     attention: LMAttention
       wq/wk/wv/wo: Linear (no bias)
-      kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA)
+      kv_cache: streaming: RingKVCache/StandardRingKVCache; offline: KVCache/StaticKVCache
       sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
     ffn_norm: RMSNorm
     ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
@@ -117,8 +125,8 @@ StreamingAudioEncoderExport
   layers: 32x CausalEncoderLayer (shared from encoder.layers)
   enc_norm: RMSNorm (shared from encoder.norm)
   adapter: AudioLanguageAdapter (shared from model.adapter)
-  kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA)
-  sdpa: SDPA (XNNPACK) or MetalSDPA (Metal, transpose_kv=True) or StandardSDPA (CUDA, transpose_kv=True)
+  kv_caches: 32x RingKVCache (XNNPACK) or StandardRingKVCache (Metal/CUDA)
+  sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
   inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
 ```
@@ -145,51 +153,71 @@ flag (e.g., `"xnnpack"`, `"metal"`, `"cuda"`, `"portable"`).
 
 ### KV cache
 
-**XNNPACK/Portable:** `KVCache` with `[B, S, H, D]` layout. Uses
-`torch.ops.llama.update_cache(value, cache, start_pos)` which mutates
-the cache in-place. This avoids the `index_put_` + `copy_` pattern that
-triggers a `requires_grad` bug in `SpecPropPass` during `to_executorch()`.
-The `[B, S, H, D]` layout matches what `update_cache` and `custom_sdpa`
-expect, so there are no transposes between cache update and attention.
+The decoder KV cache depends on the export mode:
 
-**Metal/CUDA:** `StaticKVCache` with `[B, H, S, D]` layout. Uses `index_copy_`
-for cache updates, which is compatible with `torch.export` and AOTI.
+**Streaming (`--streaming`):** Ring buffer KV cache for unlimited
+duration. The model's `params.json` specifies `sliding_window: 8192`
+for the decoder (overridable via `--sliding-window`). Each query attends
+to only the last `sliding_window` positions; old entries are overwritten
+when the buffer wraps. Position tracking is analytic (no mutable state).
+Sliding window masks are computed each step via `create_causal_mask`.
+
+- XNNPACK/Portable: `RingKVCache` with `[B, S, H, D]` layout, using
+  `torch.ops.llama.update_cache_with_indices` for scatter writes.
+- Metal/CUDA: `StandardRingKVCache` with `[B, H, S, D]` layout, using
+  `index_copy_` on dim=2 with wrapped indices.
+
+**Offline (default):** Flat KV cache bounded by `max_seq_len` (default
+4096). Full causal attention — each query attends to all prior positions.
+
+- XNNPACK/Portable: `KVCache` with `[B, S, H, D]` layout, using
+  `torch.ops.llama.update_cache`.
+- Metal/CUDA: `StaticKVCache` with `[B, H, S, D]` layout, using
+  `index_copy_`.
 
 ### SDPA
 
 `SDPA` is its own module (not inline code), making it swappable for
 backend-specific implementations.
 
-**XNNPACK/Portable:** `SDPA` uses `torch.ops.llama.custom_sdpa` — a
-fused kernel with causal masking via `start_pos` + `is_causal=True`.
+**XNNPACK/Portable:** `SDPA` uses `torch.ops.llama.custom_sdpa`.
+In streaming mode, receives a sliding window mask from the ring cache.
+In offline mode, uses `is_causal=True` with no explicit mask.
 Handles GQA expansion internally and upcasts to float32.
 
 **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`
 which handles GQA natively (the kernel infers the group ratio from differing
 Q vs K/V head counts), avoiding the memory bandwidth overhead of
 `repeat_interleave`. Uses explicit additive attention masks
 that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
-Used for both decoder (GQA, `transpose_kv=False`) and streaming encoder
-(no GQA, `transpose_kv=True`).
+Both streaming and offline use `[B, H, S, D]` KV layout
+(`StandardRingKVCache` and `StaticKVCache` share this layout), so
+`transpose_kv=False` in all cases.
 
 **CUDA:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
-`repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x).
-Uses boolean attention masks (`True`=attend, `False`=masked) as required
-by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement
-pass optimizes the attention kernel at compile time.
+`enable_gqa=True`. Uses boolean attention masks (`True`=attend,
+`False`=masked) as required by the Triton SDPA kernel. Same
+`[B, H, S, D]` KV layout as Metal.
 
 ### Attention layout
 
-**XNNPACK/Portable:** Q/K/V projections produce `[B, T, H, D]` via
-`.view()`. RoPE operates on `[B, T, H, D]`. `KVCache` stores
-`[B, S, H, D]`. `SDPA` (custom_sdpa) receives both in this layout — no
-`transpose(1, 2)` in the attention hot path. This eliminates the need for
-`RemoveRedundantTransposes` post-export pass that Llama/optimum-executorch
-require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.
+Q/K/V projections produce `[B, T, H, D]` via `.view()`. RoPE operates
+on `[B, T, H, D]`.
+
+**XNNPACK/Portable:** Both `KVCache` (offline) and `RingKVCache`
+(streaming) use `[B, S, H, D]`. `SDPA` (custom_sdpa) receives Q and
+KV cache in this layout — no `transpose(1, 2)` in the attention hot path.
+
+**Metal/CUDA:** Both `StandardRingKVCache` (streaming) and `StaticKVCache`
+(offline) use `[B, H, S, D]` layout. `MetalSDPA`/`StandardSDPA` only
+transpose Q from `[B, T, H, D]` to `[B, H, T, D]` — KV is already in
+the expected layout.
+
+### RoPE
 
-**Metal/CUDA:** Q/K/V projections still produce `[B, T, H, D]`, but
-`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA`/`StandardSDPA` transpose q to
-`[B, H, T, D]` for the SDPA kernel, then transpose back.
+RoPE frequencies are computed on-the-fly using stored `inv_freq`
+(same pattern as the streaming encoder), enabling unlimited position
+indices without a precomputed table bound.
 
 ### Adaptive RMSNorm
 
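The on-the-fly RoPE added in the hunk above can be illustrated with a small sketch. This is not the actual ExecuTorch module (which operates on tensors); it is a pure-Python pairwise rotation showing the key property: angles come from the stored `inv_freq`, so any absolute position works without a precomputed table bound.

```python
# Illustrative on-the-fly RoPE: angles derived from inv_freq at an arbitrary
# absolute position, with no table bounding the position index.
import math

def rope_angles(pos: int, head_dim: int, theta: float = 10000.0):
    # inv_freq[i] = theta^(-2i/d); in the real export inv_freq is stored once.
    inv_freq = [theta ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    return [pos * f for f in inv_freq]  # works for arbitrarily large pos

def apply_rope(x, pos: int):
    # Rotate each pair (x[2i], x[2i+1]) by its per-frequency angle.
    angles = rope_angles(pos, len(x))
    out = []
    for i, a in enumerate(angles):
        c, s = math.cos(a), math.sin(a)
        x0, x1 = x[2 * i], x[2 * i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out

# Position 0 is the identity rotation; large positions need no table.
print(apply_rope([1.0, 0.0, 1.0, 0.0], 0))   # [1.0, 0.0, 1.0, 0.0]
```

Because rotation preserves vector norms, arbitrarily large positions stay numerically well behaved, which is what makes the table-free streaming decoder possible.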
@@ -227,15 +255,15 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,)
 -> audio_embeds (1, 1, 3072)
 ```
 
-**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
+**XNNPACK/Portable:** Uses `RingKVCache` (`update_cache_with_indices`
 custom op) and `SDPA` (`custom_sdpa`).
 
-**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
-buffer) and `MetalSDPA` (native MPS SDPA kernel with `transpose_kv=True`).
+**Metal:** Uses `StandardRingKVCache` (`index_copy_`-based ring
+buffer) and `MetalSDPA` (native MPS SDPA kernel).
 Masks are created in the model dtype to match the kernel's `device T*` expectation.
 
-**CUDA:** Uses `StandardEncoderRingKVCache` and `StandardSDPA`
-(`F.scaled_dot_product_attention` with `transpose_kv=True` and explicit
+**CUDA:** Uses `StandardRingKVCache` and `StandardSDPA`
+(`F.scaled_dot_product_attention` with explicit
 sliding window masks).
 
 ### Streaming decode loop
@@ -270,7 +298,7 @@ encoder — verified to within fp32 precision (max diff < 2e-5).
 ### Encoder KV cache
 
 Each of the 32 encoder transformer layers gets its own ring buffer KV
-cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache`
+cache (`RingKVCache` for XNNPACK/Portable, `StandardRingKVCache`
 for Metal/CUDA) that overwrites old entries when the window is exceeded,
 enabling streaming of arbitrary length audio.
