
Commit fdd9ea0

Commit message: up

1 parent 3bb3be2 commit fdd9ea0

32 files changed

Lines changed: 3909 additions & 5151 deletions

backends/apple/mlx/README.md

Lines changed: 403 additions & 0 deletions
Large diffs are not rendered by default.

backends/apple/mlx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 """MLX backend for ExecuTorch - executes models on Apple Silicon using MLX."""

-# Import custom_ops module to register custom ATen ops before anything else
+# Import custom_ops module to register custom ATen ops (rope, etc.)
 from executorch.backends.apple.mlx import custom_ops as _custom_ops  # noqa: F401
 from executorch.backends.apple.mlx.partitioner import MLXPartitioner
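As a usage note (an editor-added sketch, not part of this commit): importing the backend package is what triggers the custom-op registration that the comment above refers to, which is why the `noqa: F401` import exists.

```python
# Minimal sketch (assumed usage, not code from this commit): importing the backend
# package runs backends/apple/mlx/__init__.py, which imports custom_ops and thereby
# registers the mlx custom ATen ops (rope, etc.) before export/partitioning.
import executorch.backends.apple.mlx  # noqa: F401  (side effect: op registration)
from executorch.backends.apple.mlx import MLXPartitioner

partitioner = MLXPartitioner()  # ready to pass to to_edge_transform_and_lower(...)
```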

backends/apple/mlx/custom_ops.py

Lines changed: 0 additions & 32 deletions
@@ -14,7 +14,6 @@
 can execute efficiently but may not have direct PyTorch equivalents.

 The ops are registered using torch.library and include:
-- rms_norm: RMSNorm normalization
 - rope: Rotary Position Embedding (single tensor)
 """

@@ -24,37 +23,6 @@
 from torch import Tensor


-# =============================================================================
-# rms_norm: RMSNorm normalization
-# =============================================================================
-
-
-@torch.library.custom_op("mlx::rms_norm", mutates_args=())
-def rms_norm(x: Tensor, weight: Tensor, eps: float = 1e-5) -> Tensor:
-    """
-    RMSNorm normalization.
-
-    Args:
-        x: Input tensor of shape (..., hidden_dim)
-        weight: Weight tensor of shape (hidden_dim,)
-        eps: Small constant for numerical stability
-
-    Returns:
-        Normalized tensor of the same shape as x
-    """
-    x_f = x.to(torch.float32)
-    var = x_f.pow(2).mean(dim=-1, keepdim=True)
-    y = x_f * torch.rsqrt(var + eps)
-    y = y.to(x.dtype)
-    return y * weight.to(x.dtype)
-
-
-@torch.library.register_fake("mlx::rms_norm")
-def rms_norm_fake(x: Tensor, weight: Tensor, eps: float = 1e-5) -> Tensor:
-    """Fake implementation for tracing."""
-    return x.new_empty(x.shape)
-
-
 # =============================================================================
 # rope: Rotary Position Embedding (single tensor)
 # =============================================================================
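Context for this deletion (an editor-added sketch, not part of the commit): the removed `mlx::rms_norm` custom op is replaced elsewhere in this commit by the built-in aten `rms_norm`. Assuming a PyTorch version that provides `torch.nn.functional.rms_norm`, the two computations agree for fp32 inputs, which is what makes the custom op unnecessary:

```python
# Not part of this commit: a minimal sanity check that the deleted mlx::rms_norm math
# matches torch.nn.functional.rms_norm, which the export path now uses instead
# (assumes a PyTorch release that ships F.rms_norm).
import torch
import torch.nn.functional as F


def manual_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Same math as the deleted custom op: normalize in fp32, then rescale by weight.
    x_f = x.to(torch.float32)
    var = x_f.pow(2).mean(dim=-1, keepdim=True)
    y = (x_f * torch.rsqrt(var + eps)).to(x.dtype)
    return y * weight.to(x.dtype)


x = torch.randn(2, 8, 64)
w = torch.randn(64)
ref = manual_rms_norm(x, w)
out = F.rms_norm(x, (64,), w, eps=1e-5)  # aten rms_norm
print(torch.allclose(ref, out, atol=1e-5))  # expected: True (up to float rounding)
```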

backends/apple/mlx/examples/llama/README.md

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@ This example demonstrates how to export and run Llama models using the MLX delegate
 - **Export**: Convert HuggingFace Llama models to ExecuTorch format with MLX delegate
 - **Quantization**: Optional INT4/INT8 weight quantization via TorchAO
 - **KV Cache**: Efficient KV cache implementation for autoregressive generation
-- **Custom Ops**: Uses `mlx::rms_norm` and `mlx::apply_rope` for optimal MLX execution
+- **Custom Ops**: Uses `mlx::apply_rope` for optimal MLX execution
 - **Pybindings**: Run inference using ExecuTorch Python bindings

 ## Requirements
@@ -89,7 +89,7 @@ python -m executorch.backends.apple.mlx.examples.llama.run_llama \

 The example uses a custom model wrapper (`LlamaWithFunctionalKV`) that:

-1. **Replaces RMSNorm** with `torch.ops.mlx.rms_norm` - a custom op that maps directly to MLX's efficient RMSNorm implementation
+1. **Replaces RMSNorm** with `torch.nn.functional.rms_norm` - which maps to MLX's efficient RMSNorm implementation via the aten.rms_norm handler

 2. **Replaces Attention** with `KVCacheAttention` which:
    - Uses `torch.ops.mlx.apply_rope` for rotary position embeddings
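For readers unfamiliar with the wrapper, here is an illustrative sketch (hypothetical helper names, not code from this commit) of how RMSNorm modules can be swapped for a functional version that the MLX backend lowers through its aten.rms_norm handler; the commit's actual `CustomRMSNorm` appears in export_llama.py below:

```python
# Illustrative only: swap HuggingFace's LlamaRMSNorm (manual variance + rsqrt) for a
# module that calls torch.nn.functional.rms_norm. FunctionalRMSNorm and swap_rms_norms
# are hypothetical names, not part of this commit.
import torch
import torch.nn as nn
import torch.nn.functional as F


class FunctionalRMSNorm(nn.Module):
    def __init__(self, weight: torch.Tensor, eps: float):
        super().__init__()
        self.weight = nn.Parameter(weight)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.rms_norm(x, (self.weight.shape[0],), self.weight, self.eps)


def swap_rms_norms(model: nn.Module) -> None:
    # HF's LlamaRMSNorm exposes `weight` and `variance_epsilon`; replace each match.
    for parent in model.modules():
        for name, child in list(parent.named_children()):
            if hasattr(child, "variance_epsilon") and hasattr(child, "weight"):
                setattr(
                    parent,
                    name,
                    FunctionalRMSNorm(child.weight.data, child.variance_epsilon),
                )
```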

backends/apple/mlx/examples/llama/export_llama.py

Lines changed: 63 additions & 46 deletions
@@ -33,7 +33,7 @@
 if TYPE_CHECKING:
     from transformers import AutoModelForCausalLM

-# Import custom MLX ops for rms_norm and apply_rope
+# Import MLX ops to register handlers
 import executorch.backends.apple.mlx.ops  # noqa: F401
 import torch
 import torch.nn as nn
@@ -51,12 +51,18 @@


 # =============================================================================
-# Custom RMSNorm using MLX op
+# Custom RMSNorm using aten op
 # =============================================================================


 class CustomRMSNorm(nn.Module):
-    """RMSNorm using the custom mlx::rms_norm op for efficient MLX execution."""
+    """RMSNorm using torch.nn.functional.rms_norm for efficient MLX execution.
+
+    This replaces the HuggingFace LlamaRMSNorm (which uses manual variance + rsqrt
+    computation) with the aten rms_norm op. The MLX backend has a handler that maps
+    aten.rms_norm directly to MLX's fast::rms_norm, so this gives us fused execution
+    without needing a custom op.
+    """

     def __init__(self, base_rms: nn.Module):
         super().__init__()
@@ -66,7 +72,7 @@ def __init__(self, base_rms: nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.ops.mlx.rms_norm(x, self.weight, self.eps)
+        return F.rms_norm(x, (self.weight.shape[0],), self.weight, self.eps)


 # =============================================================================
@@ -290,7 +296,7 @@ def forward(self, hidden_states: torch.Tensor, pos_int: int) -> torch.Tensor:
 class LlamaWithFunctionalKV(nn.Module):
     """
     Wrapper around HuggingFace Llama that:
-    1. Replaces RMSNorm with custom mlx::rms_norm op
+    1. Replaces RMSNorm with aten rms_norm (mapped to MLX's fast::rms_norm)
     2. Replaces attention with KVCacheAttention (using mlx::apply_rope)
     3. Provides a trace-friendly forward that takes (token_ids, input_pos)
     """
@@ -299,7 +305,7 @@ def __init__(
         self,
         base: "AutoModelForCausalLM",
         time_axis: int = 1,
-        max_seq_len: int = 4096,
+        max_seq_len: int = 1024,
         rope_base: float = 500000.0,
     ):
         super().__init__()
@@ -383,19 +389,21 @@ def forward(self, token_ids: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 def export_llama_to_mlx(
     model_id: str,
     output_path: str,
-    quantize: Optional[str] = None,
-    max_seq_len: int = 4096,
+    max_seq_len: int = 1024,
     dtype: str = "fp32",
+    quantize_linear: Optional[str] = None,
+    quantize_embeddings: Optional[str] = None,
 ) -> None:
     """
     Export a Llama model to MLX delegate.

     Args:
         model_id: HuggingFace model ID
         output_path: Path to save the .pte file
-        quantize: Quantization method ("int4", "int8", or None)
         max_seq_len: Maximum sequence length for KV cache
         dtype: Model dtype ("fp32", "fp16", "bf16")
+        quantize_linear: Quantization method for linear layers ("int4", "int8", or None)
+        quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None)
     """
     from transformers import AutoModelForCausalLM, AutoTokenizer

@@ -416,47 +424,51 @@ def export_llama_to_mlx(
     model.eval()

     # Apply quantization if requested
-    if quantize:
-        logger.info(f"Applying {quantize} quantization...")
+    if quantize_linear or quantize_embeddings:
+        logger.info("Applying quantization with TorchAO...")
         try:
             from torchao.quantization.granularity import PerGroup
             from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

-            if quantize == "int4":
-                # Quantize embeddings with group size 32
-                quantize_(
-                    model,
-                    IntxWeightOnlyConfig(
-                        weight_dtype=torch.int4, granularity=PerGroup(32)
-                    ),
-                    lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            # Quantize embedding layers
+            if quantize_embeddings:
+                embed_dtype = (
+                    torch.int4 if quantize_embeddings == "int4" else torch.int8
                 )
-                # Quantize linear layers with group size 64
-                quantize_(
-                    model,
-                    IntxWeightOnlyConfig(
-                        weight_dtype=torch.int4, granularity=PerGroup(64)
-                    ),
+                embed_group_size = 32 if quantize_embeddings == "int4" else 128
+                logger.info(
+                    f"Quantizing embedding layers with {quantize_embeddings} (group size {embed_group_size})..."
                 )
-            elif quantize == "int8":
                 quantize_(
                     model,
                     IntxWeightOnlyConfig(
-                        weight_dtype=torch.int8, granularity=PerGroup(32)
+                        weight_dtype=embed_dtype,
+                        granularity=PerGroup(embed_group_size),
                     ),
-                    lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                )
+
+            # Quantize linear layers
+            if quantize_linear:
+                linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8
+                linear_group_size = 32 if quantize_linear == "int4" else 128
+                logger.info(
+                    f"Quantizing linear layers with {quantize_linear} (group size {linear_group_size})..."
                 )
                 quantize_(
                     model,
                     IntxWeightOnlyConfig(
-                        weight_dtype=torch.int8, granularity=PerGroup(64)
+                        weight_dtype=linear_dtype,
+                        granularity=PerGroup(linear_group_size),
                     ),
+                    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
                 )
-            else:
-                logger.warning(f"Unknown quantization method: {quantize}")

             # Tie lm_head weights to embedding after quantization
-            model.model.lm_head.weight = model.model.model.embed_tokens.weight
+            if quantize_embeddings:
+                model.model.lm_head.weight = model.model.model.embed_tokens.weight
+
+            logger.info("Applied quantization successfully")
         except ImportError:
             logger.error("TorchAO not installed. Run: pip install torchao")
             raise
@@ -484,11 +496,8 @@ def export_llama_to_mlx(
     import executorch.exir as exir
     from executorch.backends.apple.mlx import MLXPartitioner
    from executorch.exir import EdgeCompileConfig
-    from executorch.exir.backend.backend_details import CompileSpec
     from executorch.exir.capture._config import ExecutorchBackendConfig

-    compile_specs = [CompileSpec("use_fp16", bytes([False]))]
-
     # Allow repeat_interleave and sdpa ops - they will be handled by MLX backend
     edge_config = EdgeCompileConfig(
         _core_aten_ops_exception_list=[
@@ -499,7 +508,7 @@ def export_llama_to_mlx(

     edge_program = exir.to_edge_transform_and_lower(
         ep,
-        partitioner=[MLXPartitioner(compile_specs=compile_specs)],
+        partitioner=[MLXPartitioner()],
         compile_config=edge_config,
     )

@@ -538,35 +547,43 @@ def main():
         default="llama_mlx.pte",
         help="Output .pte file path",
     )
-    parser.add_argument(
-        "--quantize",
-        type=str,
-        choices=["int4", "int8"],
-        default=None,
-        help="Quantization method",
-    )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=4096,
+        default=1024,
         help="Maximum sequence length for KV cache",
     )
     parser.add_argument(
         "--dtype",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        default="fp32",
+        default="bf16",
         help="Model dtype (fp32, fp16, bf16)",
     )
+    parser.add_argument(
+        "--quantize-linear",
+        type=str,
+        choices=["int4", "int8"],
+        default=None,
+        help="Quantization method for linear layers",
+    )
+    parser.add_argument(
+        "--quantize-embeddings",
+        type=str,
+        choices=["int4", "int8"],
+        default=None,
+        help="Quantization method for embedding layers",
+    )

     args = parser.parse_args()

     export_llama_to_mlx(
         model_id=args.model_id,
         output_path=args.output,
-        quantize=args.quantize,
         max_seq_len=args.max_seq_len,
         dtype=args.dtype,
+        quantize_linear=args.quantize_linear,
+        quantize_embeddings=args.quantize_embeddings,
     )
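A hedged usage sketch of the updated export API (the HuggingFace model ID is a placeholder and the module path is inferred from the file layout; only the keyword arguments come from this diff):

```python
# Example invocation of the updated API (not code from this commit): export with
# bf16 weights, int4 linear layers, and int4 embeddings using the new per-layer-type
# quantization arguments.
from executorch.backends.apple.mlx.examples.llama.export_llama import export_llama_to_mlx

export_llama_to_mlx(
    model_id="meta-llama/Llama-3.2-1B",  # placeholder HF model ID
    output_path="llama_mlx.pte",
    max_seq_len=1024,
    dtype="bf16",
    quantize_linear="int4",       # int4 linear weights, group size 32 per the diff above
    quantize_embeddings="int4",   # int4 embedding weights, group size 32
)
```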