@@ -700,7 +700,10 @@ def __init__(self, config: VoxtralRealtimeConfig):
700700 if self .backend == "mlx" :
701701 cache_dtype = self .wq .weight .dtype
702702 self .kv_cache = MLXKVCache (
703- config .sliding_window , self .n_kv_heads , self .head_dim , dtype = cache_dtype
703+ config .sliding_window ,
704+ self .n_kv_heads ,
705+ self .head_dim ,
706+ dtype = cache_dtype ,
704707 )
705708 self .sdpa = MLXSDPA (self .n_heads , self .head_dim )
706709 elif self .backend == "metal" :
@@ -1170,7 +1173,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11701173 elif config .backend == "metal" :
11711174 self .kv_caches = nn .ModuleList (
11721175 [
1173- StandardEncoderRingKVCache (
1176+ StandardRingKVCache (
11741177 max_enc_len , config .enc_n_heads , config .enc_head_dim
11751178 )
11761179 for _ in range (config .enc_n_layers )
@@ -1184,7 +1187,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11841187 elif config .backend == "cuda" :
11851188 self .kv_caches = nn .ModuleList (
11861189 [
1187- StandardEncoderRingKVCache (
1190+ StandardRingKVCache (
11881191 max_enc_len , config .enc_n_heads , config .enc_head_dim
11891192 )
11901193 for _ in range (config .enc_n_layers )
@@ -1198,9 +1201,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11981201 else :
11991202 self .kv_caches = nn .ModuleList (
12001203 [
1201- EncoderRingKVCache (
1202- max_enc_len , config .enc_n_heads , config .enc_head_dim
1203- )
1204+ RingKVCache (max_enc_len , config .enc_n_heads , config .enc_head_dim )
12041205 for _ in range (config .enc_n_layers )
12051206 ]
12061207 )
0 commit comments