Commit ffeeafc

Commit message: up
1 parent ab15e4f commit ffeeafc

23 files changed

Lines changed: 41 additions & 779 deletions

backends/mlx/README.md

Lines changed: 13 additions & 3 deletions
@@ -85,7 +85,7 @@ logging from the partitioner and preprocessor — including ops-to-not-decompose
 lists, graph dumps, per-node support decisions, and serialization details:
 
 ```bash
-ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llama ...
+ET_MLX_DEBUG=1 python -m executorch.backends.mlx.examples.llm.export_llm_hf ...
 ```
 
 ---
@@ -107,19 +107,29 @@ backends/mlx/
 │   ├── MLXExecutor.h             # ExecutionState, constant loading, helpers
 │   ├── MLXInterpreter.h          # Op dispatch loop + per-op exec_* functions
 │   └── schema_generated.h        # [GENERATED] FlatBuffer C++ bindings (flatc)
+├── llm/                          # LLM infrastructure (KV cache, attention, etc.)
+│   ├── cache.py                  # KV cache implementations (ET + HF static cache)
+│   ├── et_attention.py           # ExecuTorch custom SDPA attention
+│   ├── hf_attention.py           # HuggingFace custom SDPA attention
+│   ├── quantization.py           # TorchAO quantization helpers
+│   └── source_transformation.py  # Source transforms for MLX export
 ├── ops.py                        # Op handlers (ATen target → MLX IR node)
 ├── patterns.py                   # Pattern handlers (multi-node fusions)
+├── passes.py                     # Graph passes (RMSNorm fusion, CSE, etc.)
+├── pattern_utils.py              # Pattern matching utilities for passes
 ├── program_builder.py            # MLXProgramBuilder + REGISTRY
 ├── partitioner.py                # Decides which ops to delegate to MLX
 ├── preprocess.py                 # BackendDetails.preprocess() entry point
-├── custom_ops.py                 # Custom torch ops (rope, etc.)
+├── custom_ops.py                 # Custom torch ops (kv_cache_update, custom_sdpa, rope)
+├── _logging.py                   # Debug logging utilities (ET_MLX_DEBUG)
+├── pte_inspector.py              # .pte file inspection/debugging tool
 ├── test/
 │   ├── test_ops.py               # Op test definitions (models + configs)
 │   ├── test_utils.py             # OpTestCase base class + helpers
 │   ├── op_test_runner.cpp        # C++ test runner (loads .pte, runs, compares)
 │   └── run_all_tests.py          # End-to-end: export → C++ run → compare
 └── examples/
-    ├── llm/                      # LLM export + run (Llama, etc.)
+    ├── llm/                      # LLM export + run via HuggingFace
     └── whisper/                  # Whisper export + run
 ```
 
backends/mlx/custom_ops.py

Lines changed: 0 additions & 16 deletions
@@ -20,11 +20,6 @@
 from torch import Tensor
 
 
-# =============================================================================
-# kv_cache_update: Functional KV cache update for BHSD layout
-# =============================================================================
-
-
 @torch.library.custom_op("mlx::kv_cache_update", mutates_args=("cache",))
 def kv_cache_update(
     cache: Tensor,  # [B, H, S_max, D] - mutated in place
@@ -73,7 +68,6 @@ def kv_cache_update(
     )
     cache[:, :, start_pos:end_pos, :] = new_values
 
-    # Return dummy tensor like llama.update_cache does
     return torch.empty((1,), dtype=new_values.dtype, device=new_values.device)
 
 
@@ -88,11 +82,6 @@ def kv_cache_update_fake(
     return torch.empty((1,), dtype=new_values.dtype, device="meta")
 
 
-# =============================================================================
-# custom_sdpa: Scaled Dot-Product Attention with KV cache slicing
-# =============================================================================
-
-
 @torch.library.custom_op("mlx::custom_sdpa", mutates_args=())
 def mlx_custom_sdpa(
     query: Tensor,  # [B, num_heads, seq_len, head_dim] - BHSD
@@ -194,11 +183,6 @@ def mlx_custom_sdpa_fake(
     return query.new_empty(query.shape)
 
 
-# =============================================================================
-# rope: Rotary Position Embedding (single tensor)
-# =============================================================================
-
-
 @torch.library.custom_op("mlx::rope", mutates_args=())
 def rope(
     x: Tensor,  # (B, H, T, D)
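
Note: each custom op above is paired with a `*_fake` kernel so `torch.export` can trace shapes without running the real implementation. A minimal self-contained sketch of that pattern, using a hypothetical `demo::` namespace rather than the backend's actual `mlx::` ops:

```python
import torch
from torch import Tensor

@torch.library.custom_op("demo::kv_cache_update", mutates_args=("cache",))
def kv_cache_update(cache: Tensor, new_values: Tensor, start_pos: int) -> Tensor:
    # Write new_values into cache[:, :, start_pos:start_pos+seq_len, :] in place
    seq_len = new_values.size(2)
    cache[:, :, start_pos : start_pos + seq_len, :] = new_values
    # Return a dummy tensor so the op has a functional output for tracing
    return torch.empty((1,), dtype=new_values.dtype, device=new_values.device)

@kv_cache_update.register_fake
def _(cache: Tensor, new_values: Tensor, start_pos: int) -> Tensor:
    # Shape-only kernel used during export; no real computation happens here
    return torch.empty((1,), dtype=new_values.dtype, device="meta")
```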

backends/mlx/examples/whisper/export_whisper.py

Lines changed: 0 additions & 44 deletions
@@ -52,11 +52,6 @@
 logger = logging.getLogger(__name__)
 
 
-# =============================================================================
-# Whisper Encoder Wrapper
-# =============================================================================
-
-
 class WhisperEncoderExportable(nn.Module):
     """
     Wrapper around Whisper's encoder for export.
@@ -72,11 +67,6 @@ def forward(self, input_features: torch.Tensor) -> torch.Tensor:
         return self.encoder(input_features=input_features).last_hidden_state
 
 
-# =============================================================================
-# Whisper Decoder Self-Attention with KV Cache
-# =============================================================================
-
-
 class WhisperSelfAttentionWithCache(nn.Module):
     """
     Whisper self-attention layer with static KV cache.
@@ -147,11 +137,6 @@ def forward(
         return self.out_proj(attn_out)
 
 
-# =============================================================================
-# Whisper Cross-Attention (no cache update - K/V pre-computed)
-# =============================================================================
-
-
 class WhisperCrossAttention(nn.Module):
     """
     Whisper cross-attention layer.
@@ -192,11 +177,6 @@ def forward(
         return self.out_proj(attn_out)
 
 
-# =============================================================================
-# Whisper Decoder Layer Wrapper
-# =============================================================================
-
-
 class WhisperDecoderLayerWithCache(nn.Module):
     """
     Wrapper for a single Whisper decoder layer with KV cache.
@@ -254,11 +234,6 @@ def forward(
         return hidden_states
 
 
-# =============================================================================
-# Whisper Decoder Wrapper
-# =============================================================================
-
-
 class WhisperDecoderWithCache(nn.Module):
     """
     Whisper decoder wrapper with static KV cache.
@@ -335,11 +310,6 @@ def forward(
         return logits
 
 
-# =============================================================================
-# Cross-KV Projection Module
-# =============================================================================
-
-
 class WhisperCrossKVProjection(nn.Module):
     """
     Compute cross-attention K/V projections from encoder hidden states.
@@ -393,11 +363,6 @@ def forward(
         return tuple(k_list), tuple(v_list)
 
 
-# =============================================================================
-# Export Functions
-# =============================================================================
-
-
 def export_whisper_to_mlx(
     model_id: str,
     output_dir: str,
@@ -528,9 +493,6 @@ def export_whisper_to_mlx(
         logger.error("TorchAO not installed. Run: pip install torchao")
         raise
 
-    # =========================================================================
-    # Export Encoder
-    # =========================================================================
     logger.info("Exporting encoder...")
 
     with torch.no_grad():
@@ -541,9 +503,6 @@ def export_whisper_to_mlx(
 
     _save_to_pte(encoder_ep, os.path.join(output_dir, "encoder.pte"), "encoder")
 
-    # =========================================================================
-    # Export Cross-KV Projection
-    # =========================================================================
    logger.info("Exporting cross-KV projection...")
 
     with torch.no_grad():
@@ -561,9 +520,6 @@ def export_whisper_to_mlx(
 
     _save_to_pte(cross_kv_ep, os.path.join(output_dir, "cross_kv.pte"), "cross_kv")
 
-    # =========================================================================
-    # Export Decoder
-    # =========================================================================
     logger.info("Exporting decoder...")
 
     # Example inputs for decoder
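
For context, a minimal sketch of the export flow these wrappers feed into: `torch.export` each wrapper, lower through the MLX partitioner, and serialize to `.pte`. The `MLXPartitioner` class name and the input shape here are assumptions for illustration, not confirmed by this diff:

```python
import torch
from executorch.exir import to_edge_transform_and_lower
from executorch.backends.mlx.partitioner import MLXPartitioner  # class name assumed

encoder = WhisperEncoderExportable(model.get_encoder()).eval()
example_inputs = (torch.randn(1, 80, 3000),)  # (batch, mel bins, frames), illustrative

with torch.no_grad():
    encoder_ep = torch.export.export(encoder, example_inputs)

# Partition the graph, delegate supported ops to MLX, and emit a .pte file
lowered = to_edge_transform_and_lower(encoder_ep, partitioner=[MLXPartitioner()])
with open("encoder.pte", "wb") as f:
    f.write(lowered.to_executorch().buffer)
```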

backends/mlx/examples/whisper/run_whisper.py

Lines changed: 0 additions & 12 deletions
@@ -115,9 +115,6 @@ def run_whisper_inference( # noqa: C901
     )
     decoder_forward = decoder_program.load_method("forward")
 
-    # =========================================================================
-    # Step 1: Run encoder
-    # =========================================================================
     logger.info("Running encoder...")
     overall_start = time.time()
     start_time = time.time()
@@ -129,9 +126,6 @@ def run_whisper_inference( # noqa: C901
     logger.info(f"Encoder time: {encoder_time:.3f}s")
     logger.info(f"Encoder output shape: {encoder_hidden_states.shape}")
 
-    # =========================================================================
-    # Step 2: Compute cross-attention K/V
-    # =========================================================================
     logger.info("Computing cross-attention K/V...")
     start_time = time.time()
 
@@ -145,9 +139,6 @@ def run_whisper_inference( # noqa: C901
     logger.info(f"Cross-KV time: {cross_kv_time:.3f}s")
     logger.info(f"Cross K/V: {num_layers} layers, each shape {cross_k_tuple[0].shape}")
 
-    # =========================================================================
-    # Step 3: Setup decoder generation
-    # =========================================================================
     # Get forced decoder IDs for language/task
     forced_decoder_ids = processor.get_decoder_prompt_ids(
         language=language,
@@ -177,9 +168,6 @@ def run_whisper_inference( # noqa: C901
 
     generated_tokens: List[int] = [sot_id]
 
-    # =========================================================================
-    # Step 4: Token-by-token decoder generation
-    # =========================================================================
     logger.info(f"Generating up to {max_new_tokens} tokens...")
     decode_start = time.time()
 
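
The `decoder_program.load_method("forward")` call above uses ExecuTorch's Python runtime. A minimal sketch of loading and running one of the exported programs (file name and input shape are illustrative):

```python
import torch
from executorch.runtime import Runtime

runtime = Runtime.get()
encoder_program = runtime.load_program("encoder.pte")
encoder_forward = encoder_program.load_method("forward")

# Assumed Whisper mel-spectrogram input: (batch, mel bins, frames)
input_features = torch.randn(1, 80, 3000)
(encoder_hidden_states,) = encoder_forward.execute([input_features])
print(encoder_hidden_states.shape)
```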

backends/mlx/llm/cache.py

Lines changed: 2 additions & 27 deletions
@@ -126,7 +126,7 @@ def update(
         Returns:
             Tuple of (k_cache, v_cache) - slices of the FULL cache buffers
         """
-        # Extract start position as int (SymInt during tracing)
+
         if isinstance(input_pos, torch.Tensor):
             start_pos = input_pos[0].item()
             seq_len = k_val.size(2)
@@ -136,7 +136,6 @@ def update(
         else:
             start_pos = input_pos
 
-        # Use MLX custom op for cache update (mutates in place)
         torch.ops.mlx.kv_cache_update(self.k_cache, k_val, start_pos)
         torch.ops.mlx.kv_cache_update(self.v_cache, v_val, start_pos)
 
@@ -145,11 +144,6 @@ def update(
         return self.k_cache[:, :, :, :], self.v_cache[:, :, :, :]
 
 
-# =============================================================================
-# RingBufferKVCache - Sliding Window KV Cache
-# =============================================================================
-
-
 class RingBufferKVCache(nn.Module):
     """
     Ring buffer KV cache for sliding window attention.
@@ -235,8 +229,6 @@ def update(
             start_pos = input_pos
         seq_len = k_val.size(2)
 
-        # Use MLX custom op for ring buffer cache update (mutates in place)
-        # ring_size enables wrapping: write_pos = start_pos % ring_size
         torch.ops.mlx.kv_cache_update(
             self.k_cache, k_val, start_pos, ring_size=self.buffer_size
         )
@@ -287,10 +279,6 @@ def create_sliding_window_mask(self, start_pos: int, seq_len: int) -> torch.Tens
         return torch.where(attn_mask, 0.0, float("-inf")).unsqueeze(0).unsqueeze(0)
 
 
-# =============================================================================
-# HFStaticCache - Standalone HuggingFace-compatible Static Cache
-# =============================================================================
-
 from transformers.cache_utils import StaticCache
 
 
@@ -424,28 +412,15 @@ def update(
         return self.kv_cache[layer_idx].update(cache_position, key_states, value_states)
 
     def get_seq_length(self, layer_idx: int = 0) -> int:
-        """
-        Get the current sequence length in the cache.
-
-        Note: This is approximate - returns the number of non-zero positions.
-
-        Args:
-            layer_idx: Layer index to check (default: 0)
-
-        Returns:
-            Approximate sequence length
-        """
-        # Check how many positions have been filled by looking for non-zero values
+        """Approximate sequence length (counts non-zero cache positions)."""
         k_cache = self.kv_cache[layer_idx].k_cache
         # Check if any value in the head_dim is non-zero for each position
         return (k_cache[0, 0, :, 0] != 0).sum().item()
 
     def get_max_cache_shape(self, layer_idx: int = 0) -> int:
-        """Get the maximum cache length."""
         return self.max_cache_len
 
     def reset(self):
-        """Reset all cache buffers to zero."""
         for layer_cache in self.kv_cache:
             layer_cache.k_cache.zero_()
             layer_cache.v_cache.zero_()
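
The removed comments described the ring-buffer semantics: with `ring_size` set, the write position wraps, so a fixed-size buffer can serve an unbounded token stream. A toy illustration of that indexing (the real update happens inside `torch.ops.mlx.kv_cache_update`):

```python
import torch

def ring_write(cache: torch.Tensor, new_values: torch.Tensor,
               start_pos: int, ring_size: int) -> None:
    # cache: [B, H, ring_size, D]; new_values: [B, H, seq_len, D]
    seq_len = new_values.size(2)
    for i in range(seq_len):
        # write_pos = start_pos % ring_size, as the removed comment noted
        cache[:, :, (start_pos + i) % ring_size, :] = new_values[:, :, i, :]
```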

backends/mlx/llm/hf_attention.py

Lines changed: 2 additions & 21 deletions
@@ -99,12 +99,7 @@ def sdpa_mask_passthrough(
     allow_torch_fix: bool = True,
     **kwargs,
 ) -> Optional[torch.Tensor]:
-    """
-    Returns None — the custom SDPA op handles causal masking internally.
-
-    Returning None avoids materializing a mask tensor during export, which
-    would create a bounded tensor that fails at runtime with longer sequences.
-    """
+    """Returns None — custom SDPA handles causal masking, avoiding bounded mask tensors."""
     return None
 
 
@@ -128,11 +123,6 @@ def register_mlx_attention(name: str = "mlx") -> None:
     )
 
 
-# =============================================================================
-# Sliding Window Attention (Ring Buffer)
-# =============================================================================
-
-
 def get_mlx_sliding_window_sdpa(exportable_module) -> Callable:
     """
     Create a closure-based SDPA function for sliding window attention.
@@ -219,16 +209,7 @@ def _sliding_window_sdpa_forward(
 def register_mlx_sliding_window_attention(
     exportable_module, name: str = "mlx_sliding_window"
 ) -> None:
-    """
-    Register MLX sliding window attention with HuggingFace's attention interfaces.
-
-    Creates a closure that captures the model reference for lazy mask creation,
-    following optimum-executorch's get_custom_sdpa_for_ring_kv_cache pattern.
-
-    Args:
-        exportable_module: The model module containing ring buffer caches.
-        name: Name to register the attention implementation under.
-    """
+    """Register MLX sliding window attention with HuggingFace's attention interfaces."""
     try:
         from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
         from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
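
`ALL_ATTENTION_FUNCTIONS` is HuggingFace's pluggable attention registry. A minimal sketch of the registration pattern (the stand-in below calls PyTorch's stock SDPA so the example is self-contained; the real implementation dispatches to `torch.ops.mlx.custom_sdpa`, and the exact registry API varies across transformers versions):

```python
import torch
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

def demo_sdpa_forward(module, query, key, value, attention_mask, **kwargs):
    # query/key/value arrive as [B, H, S, D]; HF expects (attn_output, attn_weights)
    out = torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask,
        is_causal=attention_mask is None,
    )
    return out.transpose(1, 2).contiguous(), None

ALL_ATTENTION_FUNCTIONS.register("demo_mlx", demo_sdpa_forward)
# Then select it when loading a model, e.g.:
# AutoModelForCausalLM.from_pretrained(..., attn_implementation="demo_mlx")
```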

backends/mlx/llm/quantization.py

Lines changed: 0 additions & 2 deletions
@@ -19,7 +19,6 @@
 
 
 def add_quantization_args(parser: argparse.ArgumentParser) -> None:
-    """Add common quantization arguments to an argparse parser."""
     parser.add_argument(
         "--quantize-linear",
         type=str,
@@ -58,7 +57,6 @@ def add_quantization_args(parser: argparse.ArgumentParser) -> None:
 
 
 def _default_group_size(dtype_str: str) -> int:
-    """Return the default group size for a given quantization dtype."""
     return 32 if dtype_str == "int4" else 128
 
 
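
For reference, how these helpers compose (a sketch inferred from this diff; only the `--quantize-linear` flag is visible here):

```python
import argparse

parser = argparse.ArgumentParser()
add_quantization_args(parser)
args = parser.parse_args(["--quantize-linear", "int4"])

# int4 defaults to group size 32; other dtypes default to 128
print(_default_group_size(args.quantize_linear))  # -> 32
```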

backends/mlx/llm/source_transformation.py

Lines changed: 1 addition & 12 deletions
@@ -30,18 +30,7 @@ def _replace_modules(
     factory: Callable[[nn.Module], nn.Module],
     label: str,
 ) -> nn.Module:
-    """
-    Recursively replace all instances of target_type using factory.
-
-    Args:
-        module: Root module to modify (in place)
-        target_type: Type to match against children
-        factory: Callable that takes the original child and returns a replacement
-        label: Human-readable label for logging (e.g. "RMSNorm → FunctionalRMSNorm")
-
-    Returns:
-        The modified module (same reference, mutated in place)
-    """
+    """Recursively replace all instances of target_type using factory."""
 
     def _recurse(parent: nn.Module) -> int:
         count = 0
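
The recursion body isn't shown in this hunk; a minimal self-contained sketch of the same replace-in-place pattern (without the logging `label`), plus a toy usage:

```python
import torch.nn as nn
from typing import Callable, Type

def replace_modules(module: nn.Module, target_type: Type[nn.Module],
                    factory: Callable[[nn.Module], nn.Module]) -> nn.Module:
    for name, child in module.named_children():
        if isinstance(child, target_type):
            setattr(module, name, factory(child))  # mutate the parent in place
        else:
            replace_modules(child, target_type, factory)
    return module

# Toy usage: swap every nn.GELU for nn.ReLU
model = nn.Sequential(nn.Linear(4, 4), nn.GELU())
replace_modules(model, nn.GELU, lambda old: nn.ReLU())
```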
