@@ -41,8 +41,9 @@ The model exports three methods (offline mode):
 | `token_embedding` | token IDs `(1, seq_len)` | embeddings `(1, seq_len, 3072)` |
 
 With `--streaming`, `audio_encoder` is replaced by `encode_audio_chunk`
-which takes a mel chunk `(1, 128, 8)` + conv states + encoder positions
-and returns audio embeddings `(1, 1, 3072)` + updated conv states.
+which takes a mel chunk `(1, 128, 8)` + encoder positions `(4,)` and
+returns audio embeddings `(1, 1, 3072)`. Conv states are maintained as
+internal buffers.
 
 Audio and text embeddings are **summed** at each position (not concatenated
 or masked-scatter like the original non-realtime Voxtral).
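The position-wise summing described in this hunk can be sketched in a few lines (shapes are illustrative; the real hidden size is 3072):

```python
import torch

# Sketch of the realtime fusion scheme: audio and text embeddings are
# summed per position. Hidden size is 8 here for brevity (3072 in the
# real model); tensor names are illustrative.
text_emb = torch.randn(1, 5, 8)    # from token_embedding
audio_emb = torch.randn(1, 5, 8)   # from the audio encoder + adapter
h = text_emb + audio_emb           # summed, not masked_scatter'd into slots
assert h.shape == (1, 5, 8)
```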
@@ -101,21 +102,21 @@ VoxtralRealtimeModel
     attention: LMAttention
       wq/wk/wv/wo: Linear (no bias)
       kv_cache: KVCache (XNNPACK) or StaticKVCache (Metal)
-      sdpa: SDPA (XNNPACK) or StandardSDPA (Metal)
+      sdpa: SDPA (XNNPACK) or MetalSDPA (Metal)
     ffn_norm: RMSNorm
     ada_rms_norm_t_cond: Sequential(Linear, GELU, Linear)
     feed_forward: LMMLP (w1/w2/w3)
   norm: RMSNorm
   output: Linear (tied to tok_embeddings)
 
-StreamingAudioEncoderExport (XNNPACK/Portable only)
+StreamingAudioEncoderExport
   conv1: nn.Conv1d (shared from encoder.conv_layers[0].conv)
   conv2: nn.Conv1d (shared from encoder.conv_layers[1].conv)
   layers: 32x CausalEncoderLayer (shared from encoder.layers)
   enc_norm: RMSNorm (shared from encoder.norm)
   adapter: AudioLanguageAdapter (shared from model.adapter)
-  kv_caches: 32x EncoderRingKVCache (ring buffer for sliding window attention)
-  sdpa: SDPA (for streaming attention with custom_sdpa op)
+  kv_caches: 32x EncoderRingKVCache (XNNPACK) or StandardEncoderRingKVCache (Metal)
+  sdpa: SDPA (XNNPACK) or StandardEncoderSDPA (Metal)
   inv_freq: RoPE inverse frequencies (owned, on-the-fly computation)
 ```
 
@@ -137,11 +138,8 @@ than 750 encoder frames (~15s), full causal is equivalent.
 
 The text decoder (`MistralDecoder`) is a 26-layer Mistral decoder with
 GQA (32 query heads, 8 KV heads). Backend selection is controlled by the
-`use_standard_attention` config flag, set by the export script:
-
-```python
-use_standard_attention = (args.backend == "metal")
-```
+`backend` config field, passed through from the export script's `--backend`
+flag (e.g., `"xnnpack"`, `"metal"`, `"portable"`).
 
 ### KV cache
 
@@ -164,9 +162,10 @@ backend-specific implementations.
 fused kernel with causal masking via `start_pos` + `is_causal=True`.
 Handles GQA expansion internally and upcasts to float32.
 
-**Metal:** `StandardSDPA` uses `F.scaled_dot_product_attention` with
-explicit attention masks. AOTInductor has compatibility issues with the
-`custom_sdpa` custom op.
+**Metal:** `MetalSDPA` uses `torch.ops.aten._scaled_dot_product_attention_math_for_mps`,
+which handles GQA natively via `gqa_factor`, avoiding the memory-bandwidth
+overhead of `repeat_interleave`. It uses explicit additive attention masks.
+AOTInductor has compatibility issues with the `custom_sdpa` custom op.
 
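The GQA point in this hunk can be illustrated with a small equivalence check: expanding KV heads via `repeat_interleave` materializes copies, while a grouped view computes the same result without them. This is only an illustration under assumed shapes (mirroring the decoder's 32 query / 8 KV heads, head dim 64), not the actual `MetalSDPA` implementation:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, Hq, Hkv, T, S, D = 1, 32, 8, 1, 16, 64
q = torch.randn(B, Hq, T, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

rep = Hq // Hkv  # 4 query heads share each KV head

# Baseline: materialize per-query-head KV copies (memory-bandwidth heavy).
out_expanded = F.scaled_dot_product_attention(
    q, k.repeat_interleave(rep, dim=1), v.repeat_interleave(rep, dim=1)
)

# Grouped computation: view q as (B, Hkv, rep, T, D) so each KV head
# serves its group of query heads without copying K/V.
qg = q.view(B, Hkv, rep, T, D)
scores = torch.einsum("bhgtd,bhsd->bhgts", qg, k) / D ** 0.5
out_grouped = torch.einsum(
    "bhgts,bhsd->bhgtd", scores.softmax(dim=-1), v
).reshape(B, Hq, T, D)

assert torch.allclose(out_expanded, out_grouped, atol=1e-4)
```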
 ### Attention layout
 
@@ -178,8 +177,9 @@ explicit attention masks. AOTInductor has compatibility issues with the
 require when using `[B, H, S, D]` attention with `[B, S, H, D]` cache.
 
 **Metal:** Q/K/V projections still produce `[B, T, H, D]`, but
-`StaticKVCache` stores `[B, H, S, D]` and `StandardSDPA` transposes q to
-`[B, H, T, D]` for `F.scaled_dot_product_attention`, then transposes back.
+`StaticKVCache` stores `[B, H, S, D]` and `MetalSDPA` transposes q to
+`[B, H, T, D]` for `_scaled_dot_product_attention_math_for_mps`, then
+transposes back.
 
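The two layouts in this hunk can be sketched directly; in PyTorch the `[B, T, H, D]` to `[B, H, T, D]` change is a zero-copy view, so the transposes themselves are cheap (shapes below are illustrative):

```python
import torch

# Sketch of the two attention layouts: projections produce [B, T, H, D],
# the SDPA step views them as [B, H, T, D]. transpose() returns a view,
# so nothing is copied until a kernel needs contiguous memory.
B, T, H, D = 1, 4, 32, 64
q_bthd = torch.randn(B, T, H, D)   # projection output layout
q_bhtd = q_bthd.transpose(1, 2)    # attention layout [B, H, T, D]

assert q_bhtd.shape == (B, H, T, D)
assert q_bhtd.data_ptr() == q_bthd.data_ptr()       # same storage, no copy
assert torch.equal(q_bhtd.transpose(1, 2), q_bthd)  # round-trips exactly
```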
 ### Adaptive RMSNorm
 
@@ -205,28 +205,25 @@ mel at once. It shares all weights with the offline encoder but uses a
 different forward path:
 
 ```
-mel_chunk (1, 128, 8)
-  + conv1_state (1, 128, 2) + conv2_state (1, 1280, 2)
+mel_chunk (1, 128, 8) + enc_input_pos (4,)
+  conv1_state (1, 128, 2) and conv2_state (1, 1280, 2) are internal buffers
   -> cat(state, chunk) -> raw Conv1d (no CausalConv1d padding) -> GELU
   -> cat(state, conv1_out) -> raw Conv1d -> GELU
 (1, 1280, 4) -> transpose -> (1, 4, 1280)
-  -> 32x streaming encoder layer (EncoderRingKVCache + custom_sdpa)
+  -> 32x streaming encoder layer (ring KV cache + SDPA)
   -> RMSNorm
 (1, 4, 1280)
   -> Reshape downsample (1, 1, 5120) -> Adapter (1, 1, 3072)
-  -> audio_embeds, new_conv1_state, new_conv2_state
+  -> audio_embeds (1, 1, 3072)
 ```
 
-**XNNPACK/Portable only.** Metal does not yet support streaming mode.
-The custom ops used by `StreamingAudioEncoderExport`
-(`update_cache_with_indices`, `custom_sdpa`) are incompatible with AOTI.
-Adding Metal streaming support would require:
+**XNNPACK/Portable:** Uses `EncoderRingKVCache` (`update_cache_with_indices`
+custom op) and `SDPA` (`custom_sdpa`).
 
-- Replace `EncoderRingKVCache` with an `index_copy_`-based ring buffer
-  (similar to `StaticKVCache` but with modular index arithmetic)
-- Replace `SDPA` (`custom_sdpa`) with `StandardSDPA` using explicit
-  sliding window masks
-- These are the same patterns already used in the Metal text decoder
+**Metal:** Uses `StandardEncoderRingKVCache` (`index_copy_`-based ring
+buffer) and `StandardEncoderSDPA` (`F.scaled_dot_product_attention` with
+explicit sliding window masks) — the same patterns used in the Metal
+text decoder.
 
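The conv-state mechanics in the pipeline above can be sketched with a toy single-conv version. The kernel size (3), stride (1), and channel count (4) here are assumptions for brevity; the real encoder caches `conv1_state (1, 128, 2)` and `conv2_state (1, 1280, 2)` as internal buffers. The sketch checks that chunked processing with a cached input tail reproduces the offline result:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv1d(4, 4, kernel_size=3)   # raw conv, no padding
state = torch.zeros(1, 4, 2)            # cache last kernel_size - 1 frames

def encode_chunk(chunk, state):
    x = torch.cat([state, chunk], dim=-1)   # prepend cached left context
    out = conv(x)                           # causal: sees only past frames
    new_state = x[..., -2:]                 # keep tail for the next chunk
    return out, new_state

# Streaming two chunks matches one offline pass over the concatenation
# (with the same zero left-padding the initial state represents).
a, b = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
o1, state = encode_chunk(a, state)
o2, state = encode_chunk(b, state)
full = conv(torch.cat([torch.zeros(1, 4, 2), a, b], dim=-1))
assert torch.allclose(torch.cat([o1, o2], dim=-1), full, atol=1e-5)
```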
 ### Streaming decode loop
 
@@ -258,9 +255,10 @@ encoder — verified to within fp32 precision (max diff < 2e-5).
 
 ### Encoder KV cache
 
-Each of the 32 encoder transformer layers gets its own `EncoderRingKVCache`
-instance — a ring buffer that overwrites old entries when the window is
-exceeded, enabling streaming of arbitrary length audio.
+Each of the 32 encoder transformer layers gets its own ring-buffer KV
+cache (`EncoderRingKVCache` for XNNPACK/Portable, `StandardEncoderRingKVCache`
+for Metal) that overwrites old entries when the window is exceeded,
+enabling streaming of arbitrary-length audio.
 
 - Cache shape: `(1, 2*max_enc_len, 32, 64)` per layer. The buffer is 2x the
   window size because writes happen *before* attention. With a 1x buffer
@@ -281,10 +279,14 @@ ring buffer. This is unrelated to `max_enc_len=16384` in
 `CausalWhisperEncoder.__init__`, which is the RoPE frequency table size
 for the offline encoder.
 
-Cache writes use `torch.ops.llama.update_cache_with_indices` (a custom op
-that scatter-writes via an indices tensor). Write indices are computed
-analytically: `(arange(seq_len) + start_pos) % buf_size`. No mutable
-position state is needed.
+**XNNPACK/Portable:** Cache writes use `torch.ops.llama.update_cache_with_indices`
+(a custom op that scatter-writes via an indices tensor). Write indices are
+computed analytically: `(arange(seq_len) + start_pos) % buf_size`.
+
+**Metal:** Cache writes use `index_copy_` with wrapped indices
+(`input_pos % buf_size`).
+
+No mutable position state is needed in either variant.
 
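The shared analytic indexing can be sketched with the Metal-style `index_copy_` write. Buffer size and row width below are illustrative, not the real `(1, 2*max_enc_len, 32, 64)` cache; the XNNPACK path feeds the same wrapped indices to the `update_cache_with_indices` custom op instead:

```python
import torch

# Ring-buffer write with analytically computed, wrapped indices.
buf_size, seq_len, start_pos = 8, 3, 6
cache = torch.zeros(buf_size, 4)
new_kv = torch.randn(seq_len, 4)

# Write positions derived from start_pos alone; no mutable counter needed.
idx = (torch.arange(seq_len) + start_pos) % buf_size   # tensor([6, 7, 0])
cache.index_copy_(0, idx, new_kv)                      # Metal-style write

assert torch.equal(cache[6], new_kv[0])
assert torch.equal(cache[0], new_kv[2])   # wrapped around the buffer end
```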
 Position tracking is analytic — no mutable state buffer. For buffer
 slot `j` after `total_written` frames have been stored: