@@ -538,11 +538,12 @@ def forward(
538538 return y .view (bsz , seqlen , self .dim )
539539
540540
541- class MLXKVCache (nn .Module ):
542- """Wrapper that adapts MLX BHSD KV cache for model's BSHD convention.
541+ class MLXStaticKVCache (nn .Module ):
542+ """Wrapper that adapts MLX static KV cache for model's BSHD convention.
543543
544- The model's QKV projections produce [B, S, H, D] tensors, but MLX's
545- KVCache expects [B, H, S, D]. This wrapper transposes on the way in.
544+ For offline (non-streaming) mode. The model's QKV projections produce
545+ [B, S, H, D] tensors, but MLX's KVCache expects [B, H, S, D].
546+ This wrapper transposes on the way in.
546547 """
547548
548549 def __init__ (
@@ -569,12 +570,13 @@ def update(
569570 return self .cache .update (input_pos , k_val , v_val )
570571
571572
572- class MLXEncoderRingKVCache (nn .Module ):
573- """Wrapper that adapts MLX RingBufferKVCache for the encoder 's BSHD convention.
573+ class MLXRingKVCache (nn .Module ):
574+ """Wrapper that adapts MLX RingBufferKVCache for model 's BSHD convention.
574575
575- The encoder's QKV projections produce [B, S, H, D] tensors, but MLX's
576- RingBufferKVCache expects [B, H, S, D]. This wrapper transposes on the
577- way in and delegates ring buffer semantics to the MLX implementation.
576+ For streaming mode (both encoder and decoder). The model's QKV projections
577+ produce [B, S, H, D] tensors, but MLX's RingBufferKVCache expects
578+ [B, H, S, D]. This wrapper transposes on the way in and delegates
579+ ring buffer semantics to the MLX implementation.
578580 """
579581
580582 def __init__ (
@@ -603,7 +605,9 @@ def update(
603605 v_val = v_val .transpose (1 , 2 )
604606 return self .ring_cache .update (input_pos , k_val , v_val )
605607
606- def create_causal_mask (self , start_pos , seq_len , bool_mask = False ) -> torch .Tensor :
608+ def create_causal_mask (
609+ self , start_pos , seq_len , bool_mask = False , ** kwargs
610+ ) -> torch .Tensor :
607611 return self .ring_cache .create_sliding_window_mask (start_pos , seq_len )
608612
609613
@@ -637,9 +641,10 @@ def forward(
637641 return y .transpose (1 , 2 ).contiguous ().view (bsz , seqlen , self .dim )
638642
639643
640- class MLXEncoderSDPA (nn .Module ):
641- """SDPA for streaming encoder with MLX ring buffer KV cache.
644+ class MLXMaskedSDPA (nn .Module ):
645+ """SDPA with explicit mask for MLX ring buffer KV cache.
642646
647+ Used with MLXRingKVCache for streaming mode (both encoder and decoder).
643648 Uses F.scaled_dot_product_attention with explicit attn_mask from the
644649 ring buffer. KV cache is in BHSD layout, queries are in BSHD.
645650 """
@@ -662,7 +667,7 @@ def forward(
662667 Args:
663668 input_pos: (seq_len,) position indices (unused, kept for interface).
664669 q: (B, seq_len, n_heads, head_dim) in BSHD layout.
665- k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXEncoderRingKVCache .
670+ k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXRingKVCache .
666671 bsz, seqlen: batch size and query length.
667672 mask: (1, 1, seq_len, buf_size) additive attention mask from ring buffer.
668673 """
@@ -699,7 +704,7 @@ def __init__(self, config: VoxtralRealtimeConfig):
699704 # Ring buffer KV cache for unlimited streaming.
700705 if self .backend == "mlx" :
701706 cache_dtype = self .wq .weight .dtype
702- self .kv_cache = MLXKVCache (
707+ self .kv_cache = MLXRingKVCache (
703708 config .sliding_window ,
704709 self .n_kv_heads ,
705710 self .head_dim ,
@@ -723,7 +728,16 @@ def __init__(self, config: VoxtralRealtimeConfig):
723728 self .sdpa = SDPA (self .n_heads , self .head_dim )
724729 else :
725730 # Flat KV cache for offline mode (capped at max_seq_len).
726- if self .backend == "metal" :
731+ if self .backend == "mlx" :
732+ cache_dtype = self .wq .weight .dtype
733+ self .kv_cache = MLXStaticKVCache (
734+ config .max_seq_len ,
735+ self .n_kv_heads ,
736+ self .head_dim ,
737+ dtype = cache_dtype ,
738+ )
739+ self .sdpa = MLXSDPA (self .n_heads , self .head_dim )
740+ elif self .backend == "metal" :
727741 self .kv_cache = StaticKVCache (
728742 config .max_seq_len , self .n_kv_heads , self .head_dim
729743 )
@@ -1160,7 +1174,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11601174 cache_dtype = self .layers [0 ].attention .wq .weight .dtype
11611175 self .kv_caches = nn .ModuleList (
11621176 [
1163- MLXEncoderRingKVCache (
1177+ MLXRingKVCache (
11641178 max_enc_len ,
11651179 config .enc_n_heads ,
11661180 config .enc_head_dim ,
@@ -1169,7 +1183,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11691183 for _ in range (config .enc_n_layers )
11701184 ]
11711185 )
1172- self .sdpa = MLXEncoderSDPA (config .enc_n_heads , config .enc_head_dim )
1186+ self .sdpa = MLXMaskedSDPA (config .enc_n_heads , config .enc_head_dim )
11731187 elif config .backend == "metal" :
11741188 self .kv_caches = nn .ModuleList (
11751189 [
0 commit comments