@@ -68,6 +68,34 @@ def _prepare_and_quantize_mlx(model, config, args):
6868 pack_all_switch_linears (model )
6969
7070
def _prepare_and_quantize_metal(model, config, args):
    """Metal: apply source transforms, quantize experts + non-expert layers.

    Registers the custom Metal ops (imported for side effect), optionally
    quantizes expert weights and remaining linear layers when --qlinear is
    set, then materializes buffers and applies the Metal source transforms.
    """
    # Imported purely for their op-registration side effects.
    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
        metal_source_transformations,
        quantize_experts_metal,
    )

    # Single guard: the original checked args.qlinear twice back-to-back.
    if args.qlinear:
        # Quantize expert weights to Metal-compatible INT4 format.
        quantize_experts_metal(model, config, args.qlinear_group_size)

        from executorch.extension.llm.export.quantize import quantize_model_

        # skip_incompatible_shapes skips shared_expert_gate (N=1, N%4!=0)
        quantize_model_(
            model,
            qlinear_config=args.qlinear,
            qlinear_group_size=args.qlinear_group_size,
            skip_incompatible_shapes=True,
        )

    _materialize_buffers(model, config)
    metal_source_transformations(model, config=config)
98+
7199def load_and_quantize (args ): # noqa: C901
72100 """Load model from checkpoint, optionally quantize.
73101
@@ -152,6 +180,11 @@ def load_and_quantize(args): # noqa: C901
152180 )
153181 _prepare_and_quantize_mlx (model , config , args )
154182
183+ elif backend == "metal" :
184+ if args .prequantized :
185+ raise ValueError ("Metal backend does not support --prequantized." )
186+ _prepare_and_quantize_metal (model , config , args )
187+
155188 elif backend == "cuda" :
156189 if args .prequantized :
157190 return load_prequantized_model (
@@ -516,6 +549,8 @@ def export_and_lower(model, config, args):
516549
517550 if backend == "mlx" :
518551 _export_mlx (model , config , args )
552+ elif backend == "metal" :
553+ _export_metal (model , config , args )
519554 else :
520555 _export_cuda (model , config , args )
521556
@@ -600,6 +635,100 @@ def _export_mlx(model, config, args):
600635 print ("Done!" )
601636
602637
def _export_metal(model, config, args):
    """Export model to .pte via torch.export + Metal backend.

    Exports two methods — "decode" (T=1, static shape) and "prefill"
    (T>=2, dynamic sequence length) — lowers both through the Metal
    partitioner, and writes model.pte (plus any external tensor data)
    into args.output_dir.
    """
    import torch._inductor.config as inductor_config

    from executorch.backends.apple.metal.metal_backend import MetalBackend
    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
    from executorch.exir import (
        EdgeCompileConfig,
        ExecutorchBackendConfig,
        to_edge_transform_and_lower,
    )
    from executorch.exir.passes import MemoryPlanningPass
    from torch.export import Dim, export

    # Keep AOTInductor wrapper compilation simple and predictable.
    inductor_config.coordinate_descent_tuning = False
    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"

    # --- Decode method (T=1, static shape) ---
    print("Exporting decode method...")
    decode_tokens = torch.tensor([[0]], dtype=torch.long)
    decode_pos = torch.tensor([0], dtype=torch.long)
    with torch.no_grad():
        decode_ep = export(model, (decode_tokens, decode_pos), strict=True)
    print("Decode export successful!")

    # --- Prefill method (T>=2, dynamic shape) ---
    print("Exporting prefill method...")
    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
    seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
    # tokens are (1, T) -> dim 1 is dynamic; positions are (T,) -> dim 0.
    prefill_dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
    with torch.no_grad():
        prefill_ep = export(
            model,
            (prefill_tokens, prefill_pos),
            dynamic_shapes=prefill_dynamic_shapes,
            strict=True,
        )
    print("Prefill export successful!")

    # Lower with Metal backend
    print("Lowering to ExecuTorch with Metal...")
    metadata = {
        "get_max_seq_len": config.max_seq_len,
        "get_vocab_size": config.vocab_size,
        "get_n_layers": config.num_hidden_layers,
        "use_kv_cache": True,
        "use_sdpa_with_kv_cache": False,
        "enable_dynamic_shape": True,
    }

    def _method_partitioner(method_name):
        # One Metal partitioner per exported method, tagged with its name.
        return [
            MetalPartitioner(
                [MetalBackend.generate_method_name_compile_spec(method_name)]
            )
        ]

    et_prog = to_edge_transform_and_lower(
        {"decode": decode_ep, "prefill": prefill_ep},
        partitioner={
            "decode": _method_partitioner("decode"),
            "prefill": _method_partitioner("prefill"),
        },
        compile_config=EdgeCompileConfig(
            _check_ir_validity=False,
            _skip_dim_order=True,
        ),
        constant_methods=metadata,
    )
    et_program = et_prog.to_executorch(
        config=ExecutorchBackendConfig(
            extract_delegate_segments=True,
            do_quant_fusion_and_const_prop=True,
            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
        ),
    )

    # Save .pte
    os.makedirs(args.output_dir, exist_ok=True)
    pte_path = os.path.join(args.output_dir, "model.pte")
    print(f"Saving to {pte_path}...")
    with open(pte_path, "wb") as f:
        et_program.write_to_file(f)
    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
    print(f"Saved {size_mb:.1f} MB")

    # NOTE(review): _tensor_data is a private attribute of the ExecuTorch
    # program; guard with getattr so a future API change degrades to
    # "no external tensor data" instead of raising AttributeError.
    if getattr(et_program, "_tensor_data", None):
        et_program.write_tensor_data_to_file(args.output_dir)
        print(f"Saved tensor data to {args.output_dir}/")

    print("Done!")
730+
731+
603732def _export_cuda (model , config , args ):
604733 """Export model to .pte via torch.export + CUDA backend.
605734
@@ -739,10 +868,8 @@ def _export_cuda(model, config, args):
739868# ---------------------------------------------------------------------------
740869
741870
742- def main ():
743- parser = argparse .ArgumentParser (
744- description = "Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
745- )
871+ def main (): # noqa: C901
872+ parser = argparse .ArgumentParser (description = "Export Qwen3.5 MoE to ExecuTorch" )
746873 parser .add_argument (
747874 "--model-dir" ,
748875 default = None ,
@@ -760,13 +887,13 @@ def main():
760887 parser .add_argument (
761888 "--backend" ,
762889 default = "cuda" ,
763- choices = ["cuda" , "mlx" ],
764- help = "Backend for export: cuda (default) or mlx ." ,
890+ choices = ["cuda" , "mlx" , "metal" ],
891+ help = "Backend for export: cuda (default), mlx, or metal ." ,
765892 )
766893 parser .add_argument (
767894 "--qlinear" ,
768895 default = None ,
769- choices = ["4w" , "8w" , "8da4w" , "8da8w" ],
896+ choices = ["4w" , "8w" , "8da4w" , "8da8w" , "fpa4w" ],
770897 help = "Quantize linear layers." ,
771898 )
772899 parser .add_argument (
@@ -841,6 +968,13 @@ def main():
841968 if args .turboquant :
842969 parser .error ("--turboquant is not supported with --backend mlx" )
843970
971+ if args .backend == "metal" :
972+ if args .turboquant :
973+ parser .error ("--turboquant is not supported with --backend metal" )
974+
975+ if args .qlinear == "fpa4w" and args .backend != "metal" :
976+ parser .error ("--qlinear=fpa4w can only be used with --backend=metal" )
977+
844978 model , config = load_and_quantize (args )
845979
846980 if args .backend == "cuda" :
0 commit comments