Commit 799bf5a

Qwen 3.5 MoE: Add --backend metal export path (#18880)
Adds Metal backend support to export.py via a --backend metal flag:

- _prepare_and_quantize_metal: applies source transforms, quantizes experts to MLX affine INT4, and quantizes non-expert linear layers with fpa4w (skipping shared_expert_gate, whose N=1 output dim fails the N % 4 == 0 check required for prefill compatibility)
- _export_metal: exports decode and prefill methods via MetalBackend / MetalPartitioner

CUDA and MLX paths are unchanged.
1 parent 9600f63 commit 799bf5a

2 files changed

Lines changed: 144 additions & 7 deletions
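With this change, a Metal export can be invoked along these lines (the --qlinear-group-size spelling is inferred from args.qlinear_group_size in the diff; the model path and group size are illustrative):

python examples/models/qwen3_5_moe/export.py \
    --model-dir /path/to/Qwen3.5-MoE \
    --backend metal \
    --qlinear fpa4w \
    --qlinear-group-size 32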

File tree

examples/models/qwen3_5_moe/export.py
extension/llm/export/quantize.py

examples/models/qwen3_5_moe/export.py

Lines changed: 141 additions & 7 deletions
@@ -68,6 +68,34 @@ def _prepare_and_quantize_mlx(model, config, args):
     pack_all_switch_linears(model)


+def _prepare_and_quantize_metal(model, config, args):
+    """Metal: apply source transforms, quantize experts + non-expert layers."""
+    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
+    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
+    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
+        metal_source_transformations,
+        quantize_experts_metal,
+    )
+
+    # Quantize expert weights to Metal-compatible INT4 format
+    if args.qlinear:
+        quantize_experts_metal(model, config, args.qlinear_group_size)
+
+    if args.qlinear:
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        # skip_incompatible_shapes skips shared_expert_gate (N=1, N % 4 != 0)
+        quantize_model_(
+            model,
+            qlinear_config=args.qlinear,
+            qlinear_group_size=args.qlinear_group_size,
+            skip_incompatible_shapes=True,
+        )
+
+    _materialize_buffers(model, config)
+    metal_source_transformations(model, config=config)
+
+
 def load_and_quantize(args):  # noqa: C901
     """Load model from checkpoint, optionally quantize.
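The expert quantization itself lives in metal_source_transformations and is not part of this diff. As background, here is a generic sketch of group-wise affine INT4 quantization of a 2D weight (an illustration of the format the commit message refers to, not the actual quantize_experts_metal implementation):

import torch

def affine_int4_quantize(w: torch.Tensor, group_size: int):
    # Group-wise affine quantization: each run of `group_size` values
    # along the input dim (K) gets its own scale and zero point, and
    # values map onto the 16 levels of a uint4.
    n, k = w.shape
    assert k % group_size == 0
    wg = w.reshape(n, k // group_size, group_size)
    w_min = wg.amin(dim=-1, keepdim=True)
    w_max = wg.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-6) / 15.0
    zero_point = (-w_min / scale).round()
    q = ((wg / scale) + zero_point).round().clamp(0, 15).to(torch.uint8)
    return q.reshape(n, k), scale.squeeze(-1), zero_point.squeeze(-1)

Dequantization then recovers approximately (q - zero_point) * scale per group.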
@@ -152,6 +180,11 @@ def load_and_quantize(args):  # noqa: C901
         )
         _prepare_and_quantize_mlx(model, config, args)

+    elif backend == "metal":
+        if args.prequantized:
+            raise ValueError("Metal backend does not support --prequantized.")
+        _prepare_and_quantize_metal(model, config, args)
+
     elif backend == "cuda":
         if args.prequantized:
             return load_prequantized_model(
@@ -516,6 +549,8 @@ def export_and_lower(model, config, args):

     if backend == "mlx":
         _export_mlx(model, config, args)
+    elif backend == "metal":
+        _export_metal(model, config, args)
     else:
         _export_cuda(model, config, args)


@@ -600,6 +635,100 @@ def _export_mlx(model, config, args):
     print("Done!")


+def _export_metal(model, config, args):
+    """Export model to .pte via torch.export + Metal backend."""
+    import torch._inductor.config as inductor_config
+
+    from executorch.backends.apple.metal.metal_backend import MetalBackend
+    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
+    from executorch.exir import (
+        EdgeCompileConfig,
+        ExecutorchBackendConfig,
+        to_edge_transform_and_lower,
+    )
+    from executorch.exir.passes import MemoryPlanningPass
+    from torch.export import Dim, export
+
+    inductor_config.coordinate_descent_tuning = False
+    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
+
+    # --- Decode method (T=1, static shape) ---
+    print("Exporting decode method...")
+    decode_tokens = torch.tensor([[0]], dtype=torch.long)
+    decode_pos = torch.tensor([0], dtype=torch.long)
+    with torch.no_grad():
+        decode_ep = export(model, (decode_tokens, decode_pos), strict=True)
+    print("Decode export successful!")
+
+    # --- Prefill method (T>=2, dynamic shape) ---
+    print("Exporting prefill method...")
+    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
+    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
+    prefill_dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
+    with torch.no_grad():
+        prefill_ep = export(
+            model,
+            (prefill_tokens, prefill_pos),
+            dynamic_shapes=prefill_dynamic_shapes,
+            strict=True,
+        )
+    print("Prefill export successful!")
+
+    # Lower with Metal backend
+    print("Lowering to ExecuTorch with Metal...")
+    metadata = {
+        "get_max_seq_len": config.max_seq_len,
+        "get_vocab_size": config.vocab_size,
+        "get_n_layers": config.num_hidden_layers,
+        "use_kv_cache": True,
+        "use_sdpa_with_kv_cache": False,
+        "enable_dynamic_shape": True,
+    }
+    et_prog = to_edge_transform_and_lower(
+        {"decode": decode_ep, "prefill": prefill_ep},
+        partitioner={
+            "decode": [
+                MetalPartitioner(
+                    [MetalBackend.generate_method_name_compile_spec("decode")]
+                )
+            ],
+            "prefill": [
+                MetalPartitioner(
+                    [MetalBackend.generate_method_name_compile_spec("prefill")]
+                )
+            ],
+        },
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+            _skip_dim_order=True,
+        ),
+        constant_methods=metadata,
+    )
+    et_program = et_prog.to_executorch(
+        config=ExecutorchBackendConfig(
+            extract_delegate_segments=True,
+            do_quant_fusion_and_const_prop=True,
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+        ),
+    )
+
+    # Save .pte
+    os.makedirs(args.output_dir, exist_ok=True)
+    pte_path = os.path.join(args.output_dir, "model.pte")
+    print(f"Saving to {pte_path}...")
+    with open(pte_path, "wb") as f:
+        et_program.write_to_file(f)
+    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
+    print(f"Saved {size_mb:.1f} MB")
+
+    if et_program._tensor_data:
+        et_program.write_tensor_data_to_file(args.output_dir)
+        print(f"Saved tensor data to {args.output_dir}/")
+
+    print("Done!")
+
+
 def _export_cuda(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
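Because decode is exported with a static T=1 shape while prefill accepts any seq_len in [2, max_seq_len - 1], the two ExportedPrograms can be sanity-checked eagerly before lowering. A minimal sketch, assuming the (tokens, pos) calling convention used above:

import torch

# decode: exactly one token (static shape)
logits = decode_ep.module()(
    torch.tensor([[0]], dtype=torch.long),
    torch.tensor([0], dtype=torch.long),
)

# prefill: any length from 2 up to max_seq_len - 1 (dynamic shape)
logits = prefill_ep.module()(
    torch.tensor([[0, 1, 2]], dtype=torch.long),
    torch.tensor([0, 1, 2], dtype=torch.long),
)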
@@ -739,10 +868,8 @@ def _export_cuda(model, config, args):
 # ---------------------------------------------------------------------------


-def main():
-    parser = argparse.ArgumentParser(
-        description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
-    )
+def main():  # noqa: C901
+    parser = argparse.ArgumentParser(description="Export Qwen3.5 MoE to ExecuTorch")
     parser.add_argument(
         "--model-dir",
         default=None,
@@ -760,13 +887,13 @@ def main():
     parser.add_argument(
         "--backend",
         default="cuda",
-        choices=["cuda", "mlx"],
-        help="Backend for export: cuda (default) or mlx.",
+        choices=["cuda", "mlx", "metal"],
+        help="Backend for export: cuda (default), mlx, or metal.",
     )
     parser.add_argument(
         "--qlinear",
         default=None,
-        choices=["4w", "8w", "8da4w", "8da8w"],
+        choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
         help="Quantize linear layers.",
     )
     parser.add_argument(
@@ -841,6 +968,13 @@ def main():
         if args.turboquant:
            parser.error("--turboquant is not supported with --backend mlx")

+    if args.backend == "metal":
+        if args.turboquant:
+            parser.error("--turboquant is not supported with --backend metal")
+
+    if args.qlinear == "fpa4w" and args.backend != "metal":
+        parser.error("--qlinear=fpa4w can only be used with --backend=metal")
+
     model, config = load_and_quantize(args)

     if args.backend == "cuda":

extension/llm/export/quantize.py

Lines changed: 3 additions & 0 deletions
@@ -111,6 +111,9 @@ def _check_shape_compatible(m, fqn, config_name, group_size, skip_incompatible_s
     shape = m.weight.shape
     if config_name == "nvfp4":
         compatible = shape[-2] % group_size == 0 and shape[-1] % group_size == 0
+    elif config_name == "fpa4w":
+        # MPS UIntx kernel requires N % 4 == 0 when M > 1 (e.g. prefill)
+        compatible = shape[-1] % group_size == 0 and shape[-2] % 4 == 0
     elif group_size != 0:
         compatible = shape[-1] % group_size == 0
     else:
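To make the fpa4w rule concrete, a worked example (the group_size of 32 and the expert shape are assumed for illustration; shared_expert_gate's N=1 comes from the commit message):

def fpa4w_compatible(shape, group_size=32):
    # K (shape[-1]) must be a multiple of group_size; N (shape[-2]) of 4.
    return shape[-1] % group_size == 0 and shape[-2] % 4 == 0

assert fpa4w_compatible((4096, 2048))       # typical expert projection: quantized
assert not fpa4w_compatible((1, 2048))      # shared_expert_gate (N=1): skipped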
