@@ -68,6 +68,81 @@ def _prepare_and_quantize_mlx(model, config, args):
     pack_all_switch_linears(model)
 
 
+def _prepare_and_quantize_metal(model, config, args):
+    """Metal: apply source transforms, quantize experts + non-expert layers."""
+    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
+    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
+    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
+        metal_source_transformations,
+        quantize_experts_metal,
+    )
+
+    # Quantize expert weights to Metal-compatible INT4 format
+    if args.qlinear:
+        quantize_experts_metal(model, config, args.qlinear_group_size)
+
+    # Untie lm_head/embedding for independent quantization.
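+    # Tied weights share one storage, so data_ptr() equality detects the tie; clone() gives lm_head its own tensor.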
+    if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr():
+        model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone())
+
+    # Quantize non-expert layers with fpa4w (Metal-compatible, no CUDA needed).
+    # Custom filter skips shared_expert_gate (N=1), which violates fpa4w's
+    # N % 4 == 0 constraint during prefill (M > 1).
+    if args.qlinear:
+        from torchao.quantization.quant_api import quantize_
+
+        import torchao.experimental.ops.mps  # noqa: F401
+        from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+
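+        # 4-bit weight-only config; "hqq" picks scales/zero-points via Half-Quadratic Quantization rather than min/max.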
+        fpa4w_config = UIntxWeightOnlyConfig(
+            group_size=args.qlinear_group_size,
+            bitwidth=4,
+            uintx_choose_qparams_algorithm="hqq",
+        )
+
+        def _fpa4w_filter(mod, fqn):
+            if not isinstance(mod, nn.Linear):
+                return False
+            n, k = mod.weight.shape
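+            # Grouped quantization needs K divisible by the group size; the N >= 4 check excludes shared_expert_gate (N=1).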
+            if k % args.qlinear_group_size != 0:
+                return False
+            if n < 4:
+                return False
+            return True
+
+        for i, layer in enumerate(model.layers):
+            layer.to(dtype=torch.bfloat16)
+            quantize_(layer, fpa4w_config, filter_fn=_fpa4w_filter)
+            print(f" Quantized layer {i + 1}/{config.num_hidden_layers} (fpa4w)", end="\r")
+        print()
+
+        # Quantize lm_head
+        print("Quantizing lm_head (fpa4w)...")
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        model.lm_head.to(dtype=torch.bfloat16)
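+        # The ModuleDict wrapper gives quantize_model_ a named child ("lm_head") it can swap in place.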
+        wrapper = nn.ModuleDict({"lm_head": model.lm_head})
+        quantize_model_(wrapper, qlinear_config="fpa4w", qlinear_group_size=args.qlinear_group_size)
+        model.lm_head = wrapper.lm_head
+
+    # Quantize embedding
+    if args.qembedding:
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        print(f"Quantizing embeddings ({args.qembedding})...")
+        model.embed_tokens.to(dtype=torch.bfloat16)
+        quantize_model_(model, qembedding_config=args.qembedding)
+
+    model.norm.to(dtype=torch.bfloat16)
+
+    _materialize_buffers(model, config)
+    metal_source_transformations(model, config=config)
+
+
 def load_and_quantize(args):
     """Load model from checkpoint, optionally quantize.
 
@@ -146,6 +221,11 @@ def load_and_quantize(args):
         )
         _prepare_and_quantize_mlx(model, config, args)
 
+    elif backend == "metal":
+        if args.prequantized:
+            return load_prequantized_model(args.prequantized, args.max_seq_len)
+        _prepare_and_quantize_metal(model, config, args)
+
     elif backend == "cuda":
         if args.prequantized:
             return load_prequantized_model(args.prequantized, args.max_seq_len)
@@ -497,6 +577,8 @@ def export_and_lower(model, config, args):
 
     if backend == "mlx":
         _export_mlx(model, config, args)
+    elif backend == "metal":
+        _export_metal(model, config, args)
     else:
         _export_cuda(model, config, args)
 
@@ -581,6 +663,105 @@ def _export_mlx(model, config, args):
     print("Done!")
 
 
+def _export_metal(model, config, args):
+    """Export model to .pte via torch.export + Metal backend."""
+    import torch._inductor.config as inductor_config
+
+    from executorch.backends.apple.metal.metal_backend import MetalBackend
+    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
+    from executorch.exir import (
+        EdgeCompileConfig,
+        ExecutorchBackendConfig,
+        to_edge_transform_and_lower,
+    )
+    from executorch.exir.passes import MemoryPlanningPass
+    from torch.export import Dim, export
+
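+    # Inductor knobs that presumably trade kernel tuning for faster AOT compiles: no coordinate-descent autotuning, wrapper built at -O0.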
+    inductor_config.coordinate_descent_tuning = False
+    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
+
+    # --- Decode method (T=1, static shape) ---
+    print("Exporting decode method...")
+    decode_tokens = torch.tensor([[0]], dtype=torch.long)
+    decode_pos = torch.tensor([0], dtype=torch.long)
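+    # Single-token step: tokens are [1, 1] and the position is [1], so no dynamic dims are needed.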
+    with torch.no_grad():
+        decode_ep = export(model, (decode_tokens, decode_pos), strict=True)
+    print("Decode export successful!")
+
+    # --- Prefill method (T>=2, dynamic shape) ---
+    print("Exporting prefill method...")
+    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
+    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
+    prefill_dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
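+    # The shared Dim ties tokens' dim 1 to pos' dim 0, so both inputs must carry the same seq_len at runtime.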
+    with torch.no_grad():
+        prefill_ep = export(
+            model, (prefill_tokens, prefill_pos),
+            dynamic_shapes=prefill_dynamic_shapes, strict=True,
+        )
+    print("Prefill export successful!")
+
+    # Lower with Metal backend
+    print("Lowering to ExecuTorch with Metal...")
+    metadata = {
+        "get_max_seq_len": config.max_seq_len,
+        "get_vocab_size": config.vocab_size,
+        "get_n_layers": config.num_hidden_layers,
+        "use_kv_cache": True,
+        "use_sdpa_with_kv_cache": False,
+        "enable_dynamic_shape": True,
+    }
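+    # Baked in as constant_methods below, so a runner can query these values straight from the .pte.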
+    et_prog = to_edge_transform_and_lower(
+        {"decode": decode_ep, "prefill": prefill_ep},
+        partitioner={
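+            # One partitioner per exported method; the compile spec records which method each Metal delegate serves.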
713+ "decode" : [
714+ MetalPartitioner (
715+ [MetalBackend .generate_method_name_compile_spec ("decode" )]
716+ )
717+ ],
718+ "prefill" : [
719+ MetalPartitioner (
720+ [MetalBackend .generate_method_name_compile_spec ("prefill" )]
721+ )
722+ ],
723+ },
724+ compile_config = EdgeCompileConfig (
725+ _check_ir_validity = False ,
726+ _skip_dim_order = True ,
727+ ),
728+ constant_methods = metadata ,
729+ )
730+ et_program = et_prog .to_executorch (
731+ config = ExecutorchBackendConfig (
732+ extract_delegate_segments = True ,
733+ do_quant_fusion_and_const_prop = True ,
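+            # Graph inputs are not memory-planned, presumably so the runner can supply its own input buffers.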
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+        ),
+    )
+
+    # Save .pte
+    os.makedirs(args.output_dir, exist_ok=True)
+    pte_path = os.path.join(args.output_dir, "model.pte")
+    print(f"Saving to {pte_path}...")
+    with open(pte_path, "wb") as f:
+        et_program.write_to_file(f)
+    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
+    print(f"Saved {size_mb:.1f} MB")
+
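+    # Externally-stored tensor data (delegate weights) lands next to the .pte; note _tensor_data is a private field and may change across ExecuTorch versions.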
+    if et_program._tensor_data:
+        et_program.write_tensor_data_to_file(args.output_dir)
+        print(f"Saved tensor data to {args.output_dir}/")
+
+    print("Done!")
+
+
 def _export_cuda(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
 
@@ -710,7 +891,7 @@ def _export_cuda(model, config, args):
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
+        description="Export Qwen3.5 MoE to ExecuTorch"
     )
     parser.add_argument(
         "--model-dir",
@@ -729,8 +910,8 @@ def main():
     parser.add_argument(
         "--backend",
         default="cuda",
-        choices=["cuda", "mlx"],
-        help="Backend for export: cuda (default) or mlx.",
+        choices=["cuda", "mlx", "metal"],
+        help="Backend for export: cuda (default), mlx, or metal.",
     )
     parser.add_argument(
         "--qlinear",
@@ -805,6 +986,10 @@ def main():
         if args.turboquant:
             parser.error("--turboquant is not supported with --backend mlx")
 
+    if args.backend == "metal":
+        if args.turboquant:
+            parser.error("--turboquant is not supported with --backend metal")
+
     model, config = load_and_quantize(args)
 
     if args.backend == "cuda":