Skip to content

Commit 8f047dd

Browse files
committed
up
1 parent fd01e12 commit 8f047dd

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -640,7 +640,15 @@ def main():
640640
"--backend",
641641
type=str,
642642
default="xnnpack",
643-
choices=["portable", "xnnpack", "metal", "mlx", "cuda", "cuda-windows", "vulkan"],
643+
choices=[
644+
"portable",
645+
"xnnpack",
646+
"metal",
647+
"mlx",
648+
"cuda",
649+
"cuda-windows",
650+
"vulkan",
651+
],
644652
help="Backend for acceleration (default: xnnpack)",
645653
)
646654
parser.add_argument(

examples/models/voxtral_realtime/model.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,10 @@ def __init__(self, config: VoxtralRealtimeConfig):
700700
if self.backend == "mlx":
701701
cache_dtype = self.wq.weight.dtype
702702
self.kv_cache = MLXKVCache(
703-
config.sliding_window, self.n_kv_heads, self.head_dim, dtype=cache_dtype
703+
config.sliding_window,
704+
self.n_kv_heads,
705+
self.head_dim,
706+
dtype=cache_dtype,
704707
)
705708
self.sdpa = MLXSDPA(self.n_heads, self.head_dim)
706709
elif self.backend == "metal":
@@ -1170,7 +1173,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11701173
elif config.backend == "metal":
11711174
self.kv_caches = nn.ModuleList(
11721175
[
1173-
StandardEncoderRingKVCache(
1176+
StandardRingKVCache(
11741177
max_enc_len, config.enc_n_heads, config.enc_head_dim
11751178
)
11761179
for _ in range(config.enc_n_layers)
@@ -1184,7 +1187,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11841187
elif config.backend == "cuda":
11851188
self.kv_caches = nn.ModuleList(
11861189
[
1187-
StandardEncoderRingKVCache(
1190+
StandardRingKVCache(
11881191
max_enc_len, config.enc_n_heads, config.enc_head_dim
11891192
)
11901193
for _ in range(config.enc_n_layers)
@@ -1198,9 +1201,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11981201
else:
11991202
self.kv_caches = nn.ModuleList(
12001203
[
1201-
EncoderRingKVCache(
1202-
max_enc_len, config.enc_n_heads, config.enc_head_dim
1203-
)
1204+
RingKVCache(max_enc_len, config.enc_n_heads, config.enc_head_dim)
12041205
for _ in range(config.enc_n_layers)
12051206
]
12061207
)

0 commit comments

Comments
 (0)