Commit 638edaa

Update
[ghstack-poisoned]
1 parent 5306c5a commit 638edaa

1 file changed: examples/models/qwen3_5_moe/export.py
Lines changed: 177 additions & 3 deletions

@@ -68,6 +68,77 @@ def _prepare_and_quantize_mlx(model, config, args):
     pack_all_switch_linears(model)
 
 
+def _prepare_and_quantize_metal(model, config, args):
+    """Metal: apply source transforms, quantize experts + non-expert layers."""
+    import executorch.backends.apple.metal.ops.gather_qmv  # noqa: F401
+    import executorch.backends.apple.metal.ops.gated_delta_rule  # noqa: F401
+    from executorch.examples.models.qwen3_5_moe.metal_source_transformations import (
+        metal_source_transformations,
+        quantize_experts_metal,
+    )
+
+    # Quantize expert weights to Metal-compatible INT4 format
+    if args.qlinear:
+        quantize_experts_metal(model, config, args.qlinear_group_size)
+
+    # Untie lm_head/embedding for independent quantization
+    if model.lm_head.weight.data_ptr() == model.embed_tokens.weight.data_ptr():
+        model.lm_head.weight = nn.Parameter(model.embed_tokens.weight.clone())
+
+    # Quantize non-expert layers with fpa4w (Metal-compatible, no CUDA needed).
+    # Custom filter skips shared_expert_gate (N=1) which violates fpa4w's
+    # N%4==0 constraint during prefill (M>1).
+    if args.qlinear:
+        from torchao.quantization.quant_api import quantize_
+
+        import torchao.experimental.ops.mps  # noqa: F401
+        from torchao.experimental.quant_api import UIntxWeightOnlyConfig
+
+        fpa4w_config = UIntxWeightOnlyConfig(
+            group_size=args.qlinear_group_size,
+            bitwidth=4,
+            uintx_choose_qparams_algorithm="hqq",
+        )
+
+        def _fpa4w_filter(mod, fqn):
+            if not isinstance(mod, nn.Linear):
+                return False
+            n, k = mod.weight.shape
+            if k % args.qlinear_group_size != 0:
+                return False
+            if n < 4:
+                return False
+            return True
+
+        for i, layer in enumerate(model.layers):
+            layer.to(dtype=torch.bfloat16)
+            quantize_(layer, fpa4w_config, filter_fn=_fpa4w_filter)
+            print(f" Quantized layer {i + 1}/{config.num_hidden_layers} (fpa4w)", end="\r")
+        print()
+
+        # Quantize lm_head
+        print("Quantizing lm_head (fpa4w)...")
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        model.lm_head.to(dtype=torch.bfloat16)
+        wrapper = nn.ModuleDict({"lm_head": model.lm_head})
+        quantize_model_(wrapper, qlinear_config="fpa4w", qlinear_group_size=args.qlinear_group_size)
+        model.lm_head = wrapper.lm_head
+
+    # Quantize embedding
+    if args.qembedding:
+        from executorch.extension.llm.export.quantize import quantize_model_
+
+        print(f"Quantizing embeddings ({args.qembedding})...")
+        model.embed_tokens.to(dtype=torch.bfloat16)
+        quantize_model_(model, qembedding_config=args.qembedding)
+
+    model.norm.to(dtype=torch.bfloat16)
+
+    _materialize_buffers(model, config)
+    metal_source_transformations(model, config=config)
+
+
 def load_and_quantize(args):
     """Load model from checkpoint, optionally quantize.
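
For reference, a minimal sketch of how the _fpa4w_filter predicate above behaves on a few representative layer shapes. The module names and dimensions here are illustrative only (not taken from the Qwen3.5 MoE config), and the group size of 32 is an assumption standing in for args.qlinear_group_size:

import torch.nn as nn

group_size = 32  # stand-in for args.qlinear_group_size

def fpa4w_filter(mod, fqn):
    # Same predicate as _fpa4w_filter in the diff above.
    if not isinstance(mod, nn.Linear):
        return False
    n, k = mod.weight.shape
    return k % group_size == 0 and n >= 4

# nn.Linear(in_features, out_features) stores its weight with shape (out, in) = (n, k).
print(fpa4w_filter(nn.Linear(2048, 2048, bias=False), "layers.0.q_proj"))           # True
print(fpa4w_filter(nn.Linear(2048, 1, bias=False), "layers.0.shared_expert_gate"))  # False: n = 1 < 4
print(fpa4w_filter(nn.Linear(30, 64, bias=False), "toy.proj"))                      # False: k = 30 not divisible by 32
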
@@ -146,6 +217,11 @@ def load_and_quantize(args):
         )
         _prepare_and_quantize_mlx(model, config, args)
 
+    elif backend == "metal":
+        if args.prequantized:
+            return load_prequantized_model(args.prequantized, args.max_seq_len)
+        _prepare_and_quantize_metal(model, config, args)
+
     elif backend == "cuda":
         if args.prequantized:
             return load_prequantized_model(args.prequantized, args.max_seq_len)
@@ -497,6 +573,8 @@ def export_and_lower(model, config, args):
 
     if backend == "mlx":
         _export_mlx(model, config, args)
+    elif backend == "metal":
+        _export_metal(model, config, args)
     else:
         _export_cuda(model, config, args)
@@ -581,6 +659,98 @@ def _export_mlx(model, config, args):
     print("Done!")
 
 
+def _export_metal(model, config, args):
+    """Export model to .pte via torch.export + Metal backend."""
+    import torch._inductor.config as inductor_config
+
+    from executorch.backends.apple.metal.metal_backend import MetalBackend
+    from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner
+    from executorch.exir import (
+        EdgeCompileConfig,
+        ExecutorchBackendConfig,
+        to_edge_transform_and_lower,
+    )
+    from executorch.exir.passes import MemoryPlanningPass
+    from torch.export import Dim, export
+
+    inductor_config.coordinate_descent_tuning = False
+    inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
+
+    # --- Decode method (T=1, static shape) ---
+    print("Exporting decode method...")
+    decode_tokens = torch.tensor([[0]], dtype=torch.long)
+    decode_pos = torch.tensor([0], dtype=torch.long)
+    with torch.no_grad():
+        decode_ep = export(model, (decode_tokens, decode_pos), strict=True)
+    print("Decode export successful!")
+
+    # --- Prefill method (T>=2, dynamic shape) ---
+    print("Exporting prefill method...")
+    prefill_tokens = torch.tensor([[0, 1]], dtype=torch.long)
+    prefill_pos = torch.tensor([0, 1], dtype=torch.long)
+    seq_dim = Dim("seq_len", min=2, max=config.max_seq_len - 1)
+    prefill_dynamic_shapes = ({1: seq_dim}, {0: seq_dim})
+    with torch.no_grad():
+        prefill_ep = export(
+            model, (prefill_tokens, prefill_pos),
+            dynamic_shapes=prefill_dynamic_shapes, strict=True,
+        )
+    print("Prefill export successful!")
+
+    # Lower with Metal backend
+    print("Lowering to ExecuTorch with Metal...")
+    metadata = {
+        "get_max_seq_len": config.max_seq_len,
+        "get_vocab_size": config.vocab_size,
+        "get_n_layers": config.num_hidden_layers,
+        "use_kv_cache": True,
+        "use_sdpa_with_kv_cache": False,
+        "enable_dynamic_shape": True,
+    }
+    et_prog = to_edge_transform_and_lower(
+        {"decode": decode_ep, "prefill": prefill_ep},
+        partitioner={
+            "decode": [
+                MetalPartitioner(
+                    [MetalBackend.generate_method_name_compile_spec("decode")]
+                )
+            ],
+            "prefill": [
+                MetalPartitioner(
+                    [MetalBackend.generate_method_name_compile_spec("prefill")]
+                )
+            ],
+        },
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+            _skip_dim_order=True,
+        ),
+        constant_methods=metadata,
+    )
+    et_program = et_prog.to_executorch(
+        config=ExecutorchBackendConfig(
+            extract_delegate_segments=True,
+            do_quant_fusion_and_const_prop=True,
+            memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
+        ),
+    )
+
+    # Save .pte
+    os.makedirs(args.output_dir, exist_ok=True)
+    pte_path = os.path.join(args.output_dir, "model.pte")
+    print(f"Saving to {pte_path}...")
+    with open(pte_path, "wb") as f:
+        et_program.write_to_file(f)
+    size_mb = os.path.getsize(pte_path) / (1024 * 1024)
+    print(f"Saved {size_mb:.1f} MB")
+
+    if et_program._tensor_data:
+        et_program.write_tensor_data_to_file(args.output_dir)
+        print(f"Saved tensor data to {args.output_dir}/")
+
+    print("Done!")
+
+
 def _export_cuda(model, config, args):
     """Export model to .pte via torch.export + CUDA backend.
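
As a side note on the dynamic-shape spec used in _export_metal: the token tensor has shape [1, T] and the position tensor has shape [T], so the same symbolic seq_len is bound to dimension 1 of the former and dimension 0 of the latter. A minimal standalone sketch of that torch.export pattern, with a toy module standing in for the real model (the module and the max of 127 are illustrative assumptions):

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, tokens, pos):
        # tokens: [1, T], pos: [T]; both share one symbolic seq_len
        return tokens.squeeze(0) + pos

seq_dim = Dim("seq_len", min=2, max=127)
ep = export(
    Toy(),
    (torch.zeros(1, 2, dtype=torch.long), torch.arange(2)),
    dynamic_shapes=({1: seq_dim}, {0: seq_dim}),
    strict=True,
)
print(ep)  # the exported graph records seq_len as a symbolic dimension
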
@@ -710,7 +880,7 @@ def _export_cuda(model, config, args):
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Export Qwen3.5 MoE to ExecuTorch (CUDA or MLX)"
+        description="Export Qwen3.5 MoE to ExecuTorch"
     )
     parser.add_argument(
         "--model-dir",
@@ -729,8 +899,8 @@ def main():
     parser.add_argument(
         "--backend",
         default="cuda",
-        choices=["cuda", "mlx"],
-        help="Backend for export: cuda (default) or mlx.",
+        choices=["cuda", "mlx", "metal"],
+        help="Backend for export: cuda (default), mlx, or metal.",
     )
     parser.add_argument(
         "--qlinear",
@@ -805,6 +975,10 @@ def main():
         if args.turboquant:
            parser.error("--turboquant is not supported with --backend mlx")
 
+    if args.backend == "metal":
+        if args.turboquant:
+            parser.error("--turboquant is not supported with --backend metal")
+
     model, config = load_and_quantize(args)
 
     if args.backend == "cuda":
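
A minimal argparse sketch of the new guard in main(): combining --turboquant with --backend metal aborts with a parser error, mirroring the existing mlx check. Here --turboquant is assumed to be a simple store_true option; only its name and the --backend choices are taken from this diff:

import argparse

parser = argparse.ArgumentParser(description="Export Qwen3.5 MoE to ExecuTorch")
parser.add_argument("--backend", default="cuda", choices=["cuda", "mlx", "metal"])
parser.add_argument("--turboquant", action="store_true")  # assumed store_true; not shown in the diff

args = parser.parse_args(["--backend", "metal", "--turboquant"])
if args.backend == "metal" and args.turboquant:
    # Exits with: error: --turboquant is not supported with --backend metal
    parser.error("--turboquant is not supported with --backend metal")
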
