@@ -74,10 +74,18 @@ or masked-scatter like the original non-realtime Voxtral).
 
 ## Memory Footprint
 
-Decoder KV cache: 26 layers × 2 (K, V) × 4096 × 8 × 128 × bytes_per_elem.
-fp32: ≈ 832 MB, bf16: ≈ 416 MB. Encoder KV caches (streaming):
-32 layers × 2 × 1500 × 32 × 64 × bytes_per_elem. fp32: ≈ 786 MB,
-bf16: ≈ 393 MB.
+Decoder KV cache depends on mode:
+- **Offline:** flat buffer sized by `max_seq_len` (default 4096).
+  26 layers × 2 × 4096 × 8 × 128 × bytes_per_elem.
+  fp32: ≈ 832 MB, bf16: ≈ 416 MB.
+- **Streaming:** ring buffer sized to 2× `sliding_window` (default 8192
+  → 16384 slots; `--sliding-window 2048` → 4096 slots).
+  26 layers × 2 × 2×sliding_window × 8 × 128 × bytes_per_elem.
+  sw=8192 fp32: ≈ 3.3 GB, bf16: ≈ 1.7 GB.
+  sw=2048 fp32: ≈ 832 MB, bf16: ≈ 416 MB.
+
+Encoder KV caches (streaming only): 32 layers × 2 × 1500 × 32 × 64 ×
+bytes_per_elem. fp32: ≈ 786 MB, bf16: ≈ 393 MB.
 
 Runtime memory = model weights (from `.pte`) + KV caches + working
 memory. Weight sizes depend on quantization: ~16 GB (fp32), ~8 GB
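The arithmetic behind the quoted cache sizes can be checked with a few lines of Python (shape factors taken from the figures above; note the decoder numbers line up with binary mebibytes while the encoder number is a decimal megabyte count):

```python
# Sanity-check of the KV cache sizes quoted above. Shape factors from this
# doc: layers x 2 (K, V) x slots x heads x head_dim x bytes_per_elem.
def kv_cache_bytes(layers, slots, heads, head_dim, bytes_per_elem):
    return layers * 2 * slots * heads * head_dim * bytes_per_elem

MiB, GiB = 1024**2, 1024**3

decoder_offline_fp32 = kv_cache_bytes(26, 4096, 8, 128, 4)
decoder_stream_fp32 = kv_cache_bytes(26, 2 * 8192, 8, 128, 4)  # ring = 2x sliding_window
encoder_fp32 = kv_cache_bytes(32, 1500, 32, 64, 4)

print(decoder_offline_fp32 // MiB)          # 832  (the "≈ 832 MB" figure)
print(round(decoder_stream_fp32 / GiB, 2))  # 3.25 (quoted as "≈ 3.3 GB")
print(encoder_fp32 / 1e6)                   # 786.432 (the "≈ 786 MB" figure)
```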
@@ -103,7 +111,7 @@ VoxtralRealtimeModel
     attention_norm: RMSNorm
     attention: LMAttention
       wq/wk/wv/wo: Linear (no bias)
-      kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal/CUDA)
+      kv_cache: streaming: RingKVCache/StandardRingKVCache; offline: KVCache/StaticKVCache
       sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
     ffn_norm: RMSNorm
     ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
@@ -117,8 +125,8 @@ StreamingAudioEncoderExport
   layers: 32x CausalEncoderLayer (shared from encoder.layers)
   enc_norm: RMSNorm (shared from encoder.norm)
   adapter: AudioLanguageAdapter (shared from model.adapter)
-  kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal/CUDA)
-  sdpa: SDPA (XNNPACK) or MetalSDPA (Metal, transpose_kv=True) or StandardSDPA (CUDA, transpose_kv=True)
+  kv_caches: 32x RingKVCache (XNNPACK) or StandardRingKVCache (Metal/CUDA)
+  sdpa: SDPA (XNNPACK) or MetalSDPA (Metal) or StandardSDPA (CUDA)
   inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
 ```
 
@@ -145,51 +153,71 @@ flag (e.g., `"xnnpack"`, `"metal"`, `"cuda"`, `"portable"`).
 
 ### KV cache
 
-**XNNPACK/Portable:** `KVCache` with `[B, S, H, D]` layout. Uses
-`torch.ops.llama.update_cache(value, cache, start_pos)`, which mutates
-the cache in-place. This avoids the `index_put_` + `copy_` pattern that
-triggers a `requires_grad` bug in `SpecPropPass` during `to_executorch()`.
-The `[B, S, H, D]` layout matches what `update_cache` and `custom_sdpa`
-expect, so there are no transposes between cache update and attention.
+The decoder KV cache depends on the export mode:
 
-**Metal/CUDA:** `StaticKVCache` with `[B, H, S, D]` layout. Uses `index_copy_`
-for cache updates, which is compatible with `torch.export` and AOTI.
+**Streaming (`--streaming`):** Ring buffer KV cache for unlimited
+duration. The model's `params.json` specifies `sliding_window: 8192`
+for the decoder (overridable via `--sliding-window`). Each query attends
+only to the last `sliding_window` positions; old entries are overwritten
+when the buffer wraps. Position tracking is analytic (no mutable state).
+Sliding window masks are computed each step via `create_causal_mask`.
+
+- XNNPACK/Portable: `RingKVCache` with `[B, S, H, D]` layout, using
+  `torch.ops.llama.update_cache_with_indices` for scatter writes.
+- Metal/CUDA: `StandardRingKVCache` with `[B, H, S, D]` layout, using
+  `index_copy_` on dim=2 with wrapped indices.
+
+**Offline (default):** Flat KV cache bounded by `max_seq_len` (default
+4096). Full causal attention — each query attends to all prior positions.
+
+- XNNPACK/Portable: `KVCache` with `[B, S, H, D]` layout, using
+  `torch.ops.llama.update_cache`.
+- Metal/CUDA: `StaticKVCache` with `[B, H, S, D]` layout, using
+  `index_copy_`.
 
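The ring-buffer indexing described above can be sketched in a few lines. This is an illustrative stand-in, not the actual `RingKVCache`/`StandardRingKVCache` classes; shapes are arbitrary, and it mimics the Metal/CUDA `[B, H, slots, D]` path with `index_copy_` on dim=2:

```python
# Sketch: absolute positions map to ring slots via modulo, so old entries
# are overwritten once the buffer wraps. Position tracking is analytic:
# slot indices are derived from start_pos alone, with no mutable counter.
import torch

sliding_window = 8
num_slots = 2 * sliding_window            # ring capacity = 2x sliding_window
cache = torch.zeros(1, 1, num_slots, 4)   # [B, H, slots, D]

def write(cache, k_new, start_pos):
    seq_len = k_new.shape[2]
    positions = torch.arange(start_pos, start_pos + seq_len)
    slots = positions % num_slots         # wrapped indices
    cache.index_copy_(2, slots, k_new)    # scatter write on the slot dim
    return cache

# Write 20 single-token steps into a 16-slot ring; position 16 reuses slot 0.
for pos in range(20):
    write(cache, torch.full((1, 1, 1, 4), float(pos)), pos)
print(cache[0, 0, 0, 0].item())  # 16.0 — slot 0 last written at position 16
```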
 ### SDPA
 
 `SDPA` is its own module (not inline code), making it swappable for
 backend-specific implementations.
 
-**XNNPACK/Portable:** `SDPA` uses `torch.ops.llama.custom_sdpa` — a
-fused kernel with causal masking via `start_pos` + `is_causal=True`.
+**XNNPACK/Portable:** `SDPA` uses `torch.ops.llama.custom_sdpa`.
+In streaming mode, it receives a sliding window mask from the ring cache.
+In offline mode, it uses `is_causal=True` with no explicit mask.
 Handles GQA expansion internally and upcasts to float32.
 
 **Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`,
 which handles GQA natively (the kernel infers the group ratio from differing
 Q vs K/V head counts), avoiding the memory bandwidth overhead of
 `repeat_interleave`. Uses explicit additive attention masks
 that must match the Q/K/V dtype (the kernel reads masks as `device T*`).
-Used for both the decoder (GQA, `transpose_kv=False`) and the streaming encoder
-(no GQA, `transpose_kv=True`).
+Both streaming and offline use the `[B, H, S, D]` KV layout
+(`StandardRingKVCache` and `StaticKVCache` share this layout), so
+`transpose_kv=False` in all cases.
 
 **CUDA:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
-`repeat_interleave` for GQA expansion (32 query heads / 8 KV heads = 4x).
-Uses boolean attention masks (`True`=attend, `False`=masked) as required
-by the Triton SDPA kernel. The CUDA backend's Triton SDPA replacement
-pass optimizes the attention kernel at compile time.
+`enable_gqa=True`. Uses boolean attention masks (`True`=attend,
+`False`=masked) as required by the Triton SDPA kernel. Same
+`[B, H, S, D]` KV layout as Metal.
 
 ### Attention layout
 
-**XNNPACK/Portable:** Q/K/V projections produce `[B, T, H, D]` via
-`.view()`. RoPE operates on `[B, T, H, D]`. `KVCache` stores
-`[B, S, H, D]`. `SDPA` (custom_sdpa) receives both in this layout — no
-`transpose(1, 2)` in the attention hot path. This eliminates the need for
-the `RemoveRedundantTransposes` post-export pass that Llama/optimum-executorch
-require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.
+Q/K/V projections produce `[B, T, H, D]` via `.view()`. RoPE operates
+on `[B, T, H, D]`.
+
+**XNNPACK/Portable:** Both `KVCache` (offline) and `RingKVCache`
+(streaming) use `[B, S, H, D]`. `SDPA` (custom_sdpa) receives Q and the
+KV cache in this layout — no `transpose(1, 2)` in the attention hot path.
+
+**Metal/CUDA:** Both `StandardRingKVCache` (streaming) and `StaticKVCache`
+(offline) use the `[B, H, S, D]` layout. `MetalSDPA`/`StandardSDPA` only
+transpose Q from `[B, T, H, D]` to `[B, H, T, D]` — KV is already in
+the expected layout.
+
+### RoPE
 
-**Metal/CUDA:** Q/K/V projections still produce `[B, T, H, D]`, but
-`StaticKVCache` stores `[B, H, S, D]`, and `MetalSDPA`/`StandardSDPA` transpose Q to
-`[B, H, T, D]` for the SDPA kernel, then transpose back.
+RoPE frequencies are computed on-the-fly using the stored `inv_freq`
+(same pattern as the streaming encoder), enabling unlimited position
+indices without a precomputed table bound.
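The on-the-fly pattern can be sketched as follows. The base `theta` here is a placeholder, not a value stated in this doc, and the helper name is hypothetical; only the `inv_freq`-times-position structure is the point:

```python
# Sketch: angles are recomputed per call from a stored inv_freq buffer,
# so any position index works — there is no precomputed table to outgrow.
import torch

head_dim = 128
theta = 1_000_000.0  # assumed RoPE base, illustrative only
inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))

def rope_cos_sin(positions):
    # [T] positions -> [T, head_dim/2] rotation angles
    angles = positions.float()[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()

cos, sin = rope_cos_sin(torch.tensor([1_000_000]))  # far beyond any table bound
print(cos.shape)  # torch.Size([1, 64])
```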
 
 ### Adaptive RMSNorm
 
@@ -227,15 +255,15 @@ mel_chunk (1, 128, 8) + enc_input_pos (4,)
 -> audio_embeds (1, 1, 3072)
 ```
 
-**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
+**XNNPACK/Portable:** Uses `RingKVCache` (`update_cache_with_indices`
 custom op) and `SDPA` (`custom_sdpa`).
 
-**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
-buffer) and `MetalSDPA` (native MPS SDPA kernel with `transpose_kv=True`).
+**Metal:** Uses `StandardRingKVCache` (`index_copy_`-based ring
+buffer) and `MetalSDPA` (native MPS SDPA kernel).
 Masks are created in the model dtype to match the kernel's `device T*` expectation.
 
-**CUDA:** Uses `StandardEncoderRingKVCache` and `StandardSDPA`
-(`F.scaled_dot_product_attention` with `transpose_kv=True` and explicit
+**CUDA:** Uses `StandardRingKVCache` and `StandardSDPA`
+(`F.scaled_dot_product_attention` with explicit
 sliding window masks).
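The per-step sliding window masks mentioned above can be illustrated with a small stand-in. The doc's helper is `create_causal_mask`; its exact signature isn't shown here, so this version simply works over absolute query/key positions with the `True`=attend convention:

```python
# Sketch: causal mask limited to a sliding window. A query at absolute
# position p may attend to keys at positions (p - window, p].
import torch

def sliding_window_causal_mask(start_pos, seq_len, kv_len, window):
    q = torch.arange(start_pos, start_pos + seq_len)[:, None]  # [T, 1]
    k = torch.arange(kv_len)[None, :]                          # [1, S]
    return (k <= q) & (k > q - window)                         # True = attend

m = sliding_window_causal_mask(start_pos=5, seq_len=1, kv_len=8, window=4)
print(m.int().tolist())  # [[0, 0, 1, 1, 1, 1, 0, 0]]
```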
 
 ### Streaming decode loop
@@ -270,7 +298,7 @@ encoder — verified to within fp32 precision (max diff < 2e-5).
 ### Encoder KV cache
 
 Each of the 32 encoder transformer layers gets its own ring buffer KV
-cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache`
+cache (`RingKVCache` for XNNPACK/Portable, `StandardRingKVCache`
 for Metal/CUDA) that overwrites old entries when the window is exceeded,
 enabling streaming of arbitrary-length audio.
 