Skip to content

Commit 984de05

Browse files
committed
up
1 parent 3cbfa30 commit 984de05

4 files changed

Lines changed: 45 additions & 27 deletions

File tree

.github/workflows/mlx.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ jobs:
181181
182182
echo "::group::Install Voxtral requirements"
183183
${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0"
184-
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
184+
${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')"
185185
${CONDA_RUN} pip install mistral_common librosa soundfile datasets
186186
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
187187
${CONDA_RUN} pip install "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}"
@@ -240,13 +240,13 @@ jobs:
240240
241241
echo "::group::Install Voxtral Realtime requirements"
242242
${CONDA_RUN} pip install -U "huggingface_hub[cli]" safetensors
243-
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
243+
${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')"
244244
echo "::endgroup::"
245245
246246
${CONDA_RUN} pip list
247247
248248
echo "::group::Download model"
249-
${CONDA_RUN} huggingface-cli download mistralai/Voxtral-Mini-4B-Realtime-2602
249+
${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602')"
250250
MODEL_PATH=$(${CONDA_RUN} python -c "from huggingface_hub import snapshot_download; print(snapshot_download('mistralai/Voxtral-Mini-4B-Realtime-2602'))")
251251
echo "Model path: ${MODEL_PATH}"
252252
echo "::endgroup::"
@@ -313,7 +313,7 @@ jobs:
313313
314314
echo "::group::Install Whisper requirements"
315315
${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0"
316-
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
316+
${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')"
317317
${CONDA_RUN} pip install transformers soundfile datasets librosa
318318
echo "::endgroup::"
319319
@@ -447,7 +447,7 @@ jobs:
447447
448448
echo "::group::Install LLM requirements"
449449
${CONDA_RUN} pip install -U "huggingface_hub[cli]<1.0"
450-
${CONDA_RUN} huggingface-cli login --token $SECRET_EXECUTORCH_HF_TOKEN
450+
${CONDA_RUN} python -c "from huggingface_hub import login; login(token='$SECRET_EXECUTORCH_HF_TOKEN')"
451451
OPTIMUM_ET_VERSION=$(cat .ci/docker/ci_commit_pins/optimum-executorch.txt)
452452
${CONDA_RUN} pip install transformers "optimum-executorch @ git+https://github.com/huggingface/optimum-executorch.git@${OPTIMUM_ET_VERSION}"
453453
echo "::endgroup::"

backends/mlx/examples/llm/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
4444
--use-custom-sdpa \
4545
--use-custom-kv-cache
4646

47-
# With INT4 quantization
47+
# With 4-bit quantization
4848
python -m executorch.backends.mlx.examples.llm.export_llm_hf \
4949
--model-id "unsloth/Llama-3.2-1B-Instruct" \
5050
--output llama_hf_int4.pte \
5151
--use-custom-sdpa \
5252
--use-custom-kv-cache \
53-
--quantize-linear int4 \
54-
--quantize-embeddings int4
53+
--qlinear 4w \
54+
--qembedding 4w
5555
```
5656

5757
### Options
@@ -62,8 +62,8 @@ python -m executorch.backends.mlx.examples.llm.export_llm_hf \
6262
| `--output` | *(required)* | Output .pte file path |
6363
| `--max-seq-len` | `1024` | Maximum sequence length for KV cache |
6464
| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) |
65-
| `--quantize-linear` | None | Quantization for linear layers (`int4`, `int8`) |
66-
| `--quantize-embeddings` | None | Quantization for embedding layers (`int4`, `int8`) |
65+
| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) |
66+
| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) |
6767
| `--no-tie-word-embeddings` | `False` | Disable re-tying lm_head to embedding after quantization |
6868
| `--use-custom-sdpa` | `False` | Use MLX custom SDPA (`mlx::custom_sdpa`) |
6969
| `--use-custom-kv-cache` | `False` | Use MLX custom KV cache (`mlx::kv_cache_update`) |

backends/mlx/examples/whisper/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ python -m executorch.backends.mlx.examples.whisper.export_whisper \
3434
| `--output-dir` | `whisper_mlx` | Output directory for `.pte` files |
3535
| `--max-decoder-seq-len` | `256` | Maximum decoder sequence length |
3636
| `--dtype` | `bf16` | Model dtype (`fp32`, `fp16`, `bf16`) |
37+
| `--qlinear` | None | Quantization for linear layers (`4w`, `8w`, `nvfp4`) |
38+
| `--qembedding` | None | Quantization for embedding layers (`4w`, `8w`, `nvfp4`) |
39+
| `--qlinear-group-size` | auto | Group size for linear quantization |
40+
| `--qembedding-group-size` | auto | Group size for embedding quantization |
3741

3842

3943
## Run

examples/models/voxtral_realtime/model.py

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,12 @@ def forward(
538538
return y.view(bsz, seqlen, self.dim)
539539

540540

541-
class MLXKVCache(nn.Module):
542-
"""Wrapper that adapts MLX BHSD KV cache for model's BSHD convention.
541+
class MLXStaticKVCache(nn.Module):
542+
"""Wrapper that adapts MLX static KV cache for model's BSHD convention.
543543
544-
The model's QKV projections produce [B, S, H, D] tensors, but MLX's
545-
KVCache expects [B, H, S, D]. This wrapper transposes on the way in.
544+
For offline (non-streaming) mode. The model's QKV projections produce
545+
[B, S, H, D] tensors, but MLX's KVCache expects [B, H, S, D].
546+
This wrapper transposes on the way in.
546547
"""
547548

548549
def __init__(
@@ -569,12 +570,13 @@ def update(
569570
return self.cache.update(input_pos, k_val, v_val)
570571

571572

572-
class MLXEncoderRingKVCache(nn.Module):
573-
"""Wrapper that adapts MLX RingBufferKVCache for the encoder's BSHD convention.
573+
class MLXRingKVCache(nn.Module):
574+
"""Wrapper that adapts MLX RingBufferKVCache for model's BSHD convention.
574575
575-
The encoder's QKV projections produce [B, S, H, D] tensors, but MLX's
576-
RingBufferKVCache expects [B, H, S, D]. This wrapper transposes on the
577-
way in and delegates ring buffer semantics to the MLX implementation.
576+
For streaming mode (both encoder and decoder). The model's QKV projections
577+
produce [B, S, H, D] tensors, but MLX's RingBufferKVCache expects
578+
[B, H, S, D]. This wrapper transposes on the way in and delegates
579+
ring buffer semantics to the MLX implementation.
578580
"""
579581

580582
def __init__(
@@ -603,7 +605,9 @@ def update(
603605
v_val = v_val.transpose(1, 2)
604606
return self.ring_cache.update(input_pos, k_val, v_val)
605607

606-
def create_causal_mask(self, start_pos, seq_len, bool_mask=False) -> torch.Tensor:
608+
def create_causal_mask(
609+
self, start_pos, seq_len, bool_mask=False, **kwargs
610+
) -> torch.Tensor:
607611
return self.ring_cache.create_sliding_window_mask(start_pos, seq_len)
608612

609613

@@ -637,9 +641,10 @@ def forward(
637641
return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
638642

639643

640-
class MLXEncoderSDPA(nn.Module):
641-
"""SDPA for streaming encoder with MLX ring buffer KV cache.
644+
class MLXMaskedSDPA(nn.Module):
645+
"""SDPA with explicit mask for MLX ring buffer KV cache.
642646
647+
Used with MLXRingKVCache for streaming mode (both encoder and decoder).
643648
Uses F.scaled_dot_product_attention with explicit attn_mask from the
644649
ring buffer. KV cache is in BHSD layout, queries are in BSHD.
645650
"""
@@ -662,7 +667,7 @@ def forward(
662667
Args:
663668
input_pos: (seq_len,) position indices (unused, kept for interface).
664669
q: (B, seq_len, n_heads, head_dim) in BSHD layout.
665-
k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXEncoderRingKVCache.
670+
k, v: (B, n_heads, buf_size, head_dim) in BHSD from MLXRingKVCache.
666671
bsz, seqlen: batch size and query length.
667672
mask: (1, 1, seq_len, buf_size) additive attention mask from ring buffer.
668673
"""
@@ -699,7 +704,7 @@ def __init__(self, config: VoxtralRealtimeConfig):
699704
# Ring buffer KV cache for unlimited streaming.
700705
if self.backend == "mlx":
701706
cache_dtype = self.wq.weight.dtype
702-
self.kv_cache = MLXKVCache(
707+
self.kv_cache = MLXRingKVCache(
703708
config.sliding_window,
704709
self.n_kv_heads,
705710
self.head_dim,
@@ -723,7 +728,16 @@ def __init__(self, config: VoxtralRealtimeConfig):
723728
self.sdpa = SDPA(self.n_heads, self.head_dim)
724729
else:
725730
# Flat KV cache for offline mode (capped at max_seq_len).
726-
if self.backend == "metal":
731+
if self.backend == "mlx":
732+
cache_dtype = self.wq.weight.dtype
733+
self.kv_cache = MLXStaticKVCache(
734+
config.max_seq_len,
735+
self.n_kv_heads,
736+
self.head_dim,
737+
dtype=cache_dtype,
738+
)
739+
self.sdpa = MLXSDPA(self.n_heads, self.head_dim)
740+
elif self.backend == "metal":
727741
self.kv_cache = StaticKVCache(
728742
config.max_seq_len, self.n_kv_heads, self.head_dim
729743
)
@@ -1160,7 +1174,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11601174
cache_dtype = self.layers[0].attention.wq.weight.dtype
11611175
self.kv_caches = nn.ModuleList(
11621176
[
1163-
MLXEncoderRingKVCache(
1177+
MLXRingKVCache(
11641178
max_enc_len,
11651179
config.enc_n_heads,
11661180
config.enc_head_dim,
@@ -1169,7 +1183,7 @@ def __init__(self, model: VoxtralRealtimeModel, max_enc_len: int = 750):
11691183
for _ in range(config.enc_n_layers)
11701184
]
11711185
)
1172-
self.sdpa = MLXEncoderSDPA(config.enc_n_heads, config.enc_head_dim)
1186+
self.sdpa = MLXMaskedSDPA(config.enc_n_heads, config.enc_head_dim)
11731187
elif config.backend == "metal":
11741188
self.kv_caches = nn.ModuleList(
11751189
[

0 commit comments

Comments
 (0)