@@ -33,7 +33,7 @@
 if TYPE_CHECKING:
     from transformers import AutoModelForCausalLM

-# Import custom MLX ops for rms_norm and apply_rope
+# Import MLX ops to register handlers
 import executorch.backends.apple.mlx.ops  # noqa: F401
 import torch
 import torch.nn as nn
@@ -51,12 +51,18 @@


 # =============================================================================
-# Custom RMSNorm using MLX op
+# Custom RMSNorm using aten op
 # =============================================================================


 class CustomRMSNorm(nn.Module):
-    """RMSNorm using the custom mlx::rms_norm op for efficient MLX execution."""
+    """RMSNorm using torch.nn.functional.rms_norm for efficient MLX execution.
+
+    This replaces the HuggingFace LlamaRMSNorm (which uses manual variance + rsqrt
+    computation) with the aten rms_norm op. The MLX backend has a handler that maps
+    aten.rms_norm directly to MLX's fast::rms_norm, so this gives us fused execution
+    without needing a custom op.
+    """

     def __init__(self, base_rms: nn.Module):
         super().__init__()
@@ -66,7 +72,7 @@ def __init__(self, base_rms: nn.Module):
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.ops.mlx.rms_norm(x, self.weight, self.eps)
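+        # normalized_shape is the trailing dim(s) to normalize over; here it is
+        # derived from the weight so it always matches the hidden size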
+        return F.rms_norm(x, (self.weight.shape[0],), self.weight, self.eps)


 # =============================================================================
@@ -290,7 +296,7 @@ def forward(self, hidden_states: torch.Tensor, pos_int: int) -> torch.Tensor:
 class LlamaWithFunctionalKV(nn.Module):
     """
     Wrapper around HuggingFace Llama that:
-    1. Replaces RMSNorm with custom mlx::rms_norm op
+    1. Replaces RMSNorm with aten rms_norm (mapped to MLX's fast::rms_norm)
     2. Replaces attention with KVCacheAttention (using mlx::apply_rope)
     3. Provides a trace-friendly forward that takes (token_ids, input_pos)
     """
@@ -299,7 +305,7 @@ def __init__(
         self,
         base: "AutoModelForCausalLM",
         time_axis: int = 1,
-        max_seq_len: int = 4096,
+        max_seq_len: int = 1024,
         rope_base: float = 500000.0,
     ):
         super().__init__()
@@ -383,19 +389,21 @@ def forward(self, token_ids: torch.Tensor, input_pos: torch.Tensor) -> torch.Ten
 def export_llama_to_mlx(
     model_id: str,
     output_path: str,
-    quantize: Optional[str] = None,
-    max_seq_len: int = 4096,
+    max_seq_len: int = 1024,
     dtype: str = "fp32",
+    quantize_linear: Optional[str] = None,
+    quantize_embeddings: Optional[str] = None,
 ) -> None:
     """
     Export a Llama model to MLX delegate.

     Args:
         model_id: HuggingFace model ID
         output_path: Path to save the .pte file
-        quantize: Quantization method ("int4", "int8", or None)
         max_seq_len: Maximum sequence length for KV cache
         dtype: Model dtype ("fp32", "fp16", "bf16")
+        quantize_linear: Quantization method for linear layers ("int4", "int8", or None)
+        quantize_embeddings: Quantization method for embedding layers ("int4", "int8", or None)
     """
     from transformers import AutoModelForCausalLM, AutoTokenizer

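A hypothetical call against the new signature (the model ID is illustrative):

    export_llama_to_mlx(
        model_id="meta-llama/Llama-3.2-1B",
        output_path="llama_mlx.pte",
        max_seq_len=1024,
        dtype="bf16",
        quantize_linear="int4",
        quantize_embeddings="int4",
    )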
@@ -416,47 +424,51 @@ def export_llama_to_mlx(
     model.eval()

     # Apply quantization if requested
-    if quantize:
-        logger.info(f"Applying {quantize} quantization...")
+    if quantize_linear or quantize_embeddings:
+        logger.info("Applying quantization with TorchAO...")
         try:
             from torchao.quantization.granularity import PerGroup
             from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

-            if quantize == "int4":
-                # Quantize embeddings with group size 32
-                quantize_(
-                    model,
-                    IntxWeightOnlyConfig(
-                        weight_dtype=torch.int4, granularity=PerGroup(32)
-                    ),
-                    lambda m, fqn: isinstance(m, torch.nn.Embedding),
+            # Quantize embedding layers
+            if quantize_embeddings:
+                embed_dtype = (
+                    torch.int4 if quantize_embeddings == "int4" else torch.int8
                 )
-                # Quantize linear layers with group size 64
-                quantize_(
-                    model,
-                    IntxWeightOnlyConfig(
-                        weight_dtype=torch.int4, granularity=PerGroup(64)
-                    ),
+                embed_group_size = 32 if quantize_embeddings == "int4" else 128
+                logger.info(
+                    f"Quantizing embedding layers with {quantize_embeddings} (group size {embed_group_size})..."
                 )
-            elif quantize == "int8":
                 quantize_(
                     model,
                     IntxWeightOnlyConfig(
-                        weight_dtype=torch.int8, granularity=PerGroup(32)
+                        weight_dtype=embed_dtype,
+                        granularity=PerGroup(embed_group_size),
                     ),
-                    lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Embedding),
+                )
+
+            # Quantize linear layers
+            if quantize_linear:
+                linear_dtype = torch.int4 if quantize_linear == "int4" else torch.int8
+                linear_group_size = 32 if quantize_linear == "int4" else 128
+                logger.info(
+                    f"Quantizing linear layers with {quantize_linear} (group size {linear_group_size})..."
                 )
                 quantize_(
                     model,
                     IntxWeightOnlyConfig(
-                        weight_dtype=torch.int8, granularity=PerGroup(64)
+                        weight_dtype=linear_dtype,
+                        granularity=PerGroup(linear_group_size),
                     ),
+                    filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
                 )
-            else:
-                logger.warning(f"Unknown quantization method: {quantize}")

             # Tie lm_head weights to embedding after quantization
-            model.model.lm_head.weight = model.model.model.embed_tokens.weight
+            if quantize_embeddings:
+                model.model.lm_head.weight = model.model.model.embed_tokens.weight
+
+            logger.info("Applied quantization successfully")
         except ImportError:
             logger.error("TorchAO not installed. Run: pip install torchao")
             raise
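A side note on the TorchAO API used above: quantize_ mutates the module tree in place, and filter_fn selects which submodules the config applies to. A minimal standalone sketch (layer sizes illustrative; in_features must be divisible by the group size):

    import torch
    from torchao.quantization.granularity import PerGroup
    from torchao.quantization.quant_api import IntxWeightOnlyConfig, quantize_

    lin = torch.nn.Linear(256, 128)
    quantize_(
        lin,
        IntxWeightOnlyConfig(weight_dtype=torch.int4, granularity=PerGroup(32)),
        filter_fn=lambda m, fqn: isinstance(m, torch.nn.Linear),
    )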
@@ -484,11 +496,8 @@ def export_llama_to_mlx(
     import executorch.exir as exir
     from executorch.backends.apple.mlx import MLXPartitioner
     from executorch.exir import EdgeCompileConfig
-    from executorch.exir.backend.backend_details import CompileSpec
     from executorch.exir.capture._config import ExecutorchBackendConfig

-    compile_specs = [CompileSpec("use_fp16", bytes([False]))]
-
     # Allow repeat_interleave and sdpa ops - they will be handled by MLX backend
     edge_config = EdgeCompileConfig(
         _core_aten_ops_exception_list=[
@@ -499,7 +508,7 @@ def export_llama_to_mlx(
499508
500509 edge_program = exir .to_edge_transform_and_lower (
501510 ep ,
502- partitioner = [MLXPartitioner (compile_specs = compile_specs )],
511+ partitioner = [MLXPartitioner ()],
503512 compile_config = edge_config ,
504513 )
505514
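Downstream of this hunk (unchanged, so not shown in the diff), the lowered program would be serialized in the usual exir way, roughly:

    exec_prog = edge_program.to_executorch(ExecutorchBackendConfig())
    with open(output_path, "wb") as f:
        f.write(exec_prog.buffer)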
@@ -538,35 +547,43 @@ def main():
         default="llama_mlx.pte",
         help="Output .pte file path",
     )
-    parser.add_argument(
-        "--quantize",
-        type=str,
-        choices=["int4", "int8"],
-        default=None,
-        help="Quantization method",
-    )
     parser.add_argument(
         "--max-seq-len",
         type=int,
-        default=4096,
+        default=1024,
         help="Maximum sequence length for KV cache",
     )
     parser.add_argument(
         "--dtype",
         type=str,
         choices=["fp32", "fp16", "bf16"],
-        default="fp32",
+        default="bf16",
         help="Model dtype (fp32, fp16, bf16)",
     )
+    parser.add_argument(
+        "--quantize-linear",
+        type=str,
+        choices=["int4", "int8"],
+        default=None,
+        help="Quantization method for linear layers",
+    )
+    parser.add_argument(
+        "--quantize-embeddings",
+        type=str,
+        choices=["int4", "int8"],
+        default=None,
+        help="Quantization method for embedding layers",
+    )

     args = parser.parse_args()

     export_llama_to_mlx(
         model_id=args.model_id,
         output_path=args.output,
-        quantize=args.quantize,
         max_seq_len=args.max_seq_len,
         dtype=args.dtype,
+        quantize_linear=args.quantize_linear,
+        quantize_embeddings=args.quantize_embeddings,
     )

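With the new flags, a typical export invocation might look like this (script path, --model-id spelling, and model ID are illustrative; the flag name is inferred from args.model_id):

    python export_llama_to_mlx.py \
        --model-id meta-llama/Llama-3.2-1B \
        --output llama_mlx.pte \
        --max-seq-len 1024 \
        --dtype bf16 \
        --quantize-linear int4 \
        --quantize-embeddings int4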