Commit d8454cf

Qwen 3.5 MoE: bypass on-device sampler conditional in non-CUDA export
Qwen35MoE.forward currently routes through an Optional[Tensor] temperature parameter and an if/else that picks between the on-device fused Gumbel-max sampler (CUDA) and raw logits (non-CUDA). The sampling branch is dead code for MLX and Metal exports, since those backends sample on the host.

Even though torch.export statically eliminates the branch when temperature defaults to None, the parameter, default value, and unused else-branch leak into the exported program: extra placeholder nodes, different graph hashes, and shifted kernel selection in the lowered MLX/Metal graph. On the tiny test model this slows MLX prefill ~9-37% and decode ~5-19%, and shows up as ~10-25% noise on Metal.

Bind model.forward to a minimal (tokens, input_pos) -> logits variant inside _export_mlx and _export_metal before torch.export, so the captured program matches what the backend kernels are tuned for. Eager-mode callers and the CUDA export path are unaffected.
1 parent b32eae7 commit d8454cf
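The rebinding pattern the commit message describes can be sketched in plain Python. The toy class below is illustrative only (its "compute" is made-up list arithmetic, not the real Qwen35MoE module); what it demonstrates is the mechanics: `types.MethodType` rebinds `forward` on one instance, so the optional `temperature` parameter and its sampling branch disappear from that instance's signature while other instances keep the full forward.

```python
import types

# Toy stand-in for the real Qwen35MoE module. Names and arithmetic are
# invented for this sketch; only the structure mirrors the commit: an
# optional temperature argument guarding a sampling branch.
class ToyModel:
    def forward(self, tokens, input_pos, temperature=None):
        logits = [2.0 * t for t in tokens]  # pretend transformer compute
        if temperature is not None:
            # stand-in for the on-device fused sampler branch (CUDA only)
            return [l / temperature for l in logits]
        return logits

def strip_sampler_from_forward(model):
    # Minimal (tokens, input_pos) -> logits variant: no temperature input,
    # no sampling branch for torch.export to carry into the program.
    def _clean_forward(self, tokens, input_pos):
        return [2.0 * t for t in tokens]
    # types.MethodType binds the function to this one instance only, so a
    # model headed for the CUDA export path keeps the full forward.
    model.forward = types.MethodType(_clean_forward, model)

mlx_model, cuda_model = ToyModel(), ToyModel()
strip_sampler_from_forward(mlx_model)

print(mlx_model.forward([1, 2], [0, 1]))        # [2.0, 4.0]
print(cuda_model.forward([1, 2], [0, 1], 2.0))  # [1.0, 2.0]
```

Because the rebinding is per-instance and happens inside the export helper, eager users who never call `_export_mlx`/`_export_metal` on their model see no change.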

1 file changed: 32 additions & 0 deletions

examples/models/qwen3_5_moe/export.py
@@ -554,6 +554,30 @@ def export_and_lower(model, config, args):
     _export_cuda(model, config, args)
 
 
+def _strip_sampler_from_forward(model):
+    """Bind ``model.forward`` to a minimal ``(tokens, input_pos) -> logits``
+    variant for non-CUDA export.
+
+    The default ``Qwen35MoE.forward`` carries an optional temperature input and
+    a sampling branch used only by the on-device CUDA sampler; non-CUDA
+    backends sample on the host so that branch is dead code at trace time.
+    Even when statically eliminated, the extra parameter and branch perturb
+    the program ``torch.export`` produces enough to shift kernel selection in
+    the lowered MLX/Metal graph and slow execution by 10-30%. Eager callers
+    and the CUDA export path are unaffected.
+    """
+    import types
+
+    def _clean_forward(self, tokens, input_pos):
+        x = self.embed_tokens(tokens)
+        for layer in self.layers:
+            x = layer(x, input_pos)
+        x = self.norm(x)
+        return self.lm_head(x)
+
+    model.forward = types.MethodType(_clean_forward, model)
+
+
 def _export_mlx(model, config, args):
     """Export model to .pte via torch.export + MLX backend."""
     import gc
@@ -568,6 +592,10 @@ def _export_mlx(model, config, args):
     from executorch.exir.passes import MemoryPlanningPass
     from torch.export import Dim, export
 
+    # Use the minimal forward variant for non-CUDA export. See
+    # _strip_sampler_from_forward docstring for why this matters.
+    _strip_sampler_from_forward(model)
+
     example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     example_input_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
@@ -651,6 +679,10 @@ def _export_metal(model, config, args):
     inductor_config.coordinate_descent_tuning = False
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
+    # Use the minimal forward variant for non-CUDA export. See
+    # _strip_sampler_from_forward docstring for why this matters.
+    _strip_sampler_from_forward(model)
+
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)
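With the sampling branch stripped, the exported program returns raw logits and the host does the sampling. The commit message names the Gumbel-max trick for the on-device CUDA sampler; a minimal host-side sketch of that trick in plain Python (illustrative only, not the ExecuTorch runtime API) looks like this:

```python
import math
import random

def gumbel_max_sample(logits, temperature, rng=random.random):
    """Draw an index from softmax(logits / temperature) via the Gumbel-max
    trick: argmax(logits / T + g) with g ~ Gumbel(0, 1) i.i.d. per logit."""
    if temperature <= 0:
        # Greedy decoding: argmax of the raw logits.
        return max(range(len(logits)), key=logits.__getitem__)
    best_i, best_v = 0, float("-inf")
    for i, logit in enumerate(logits):
        g = -math.log(-math.log(rng()))  # Gumbel(0, 1) noise
        v = logit / temperature + g
        if v > best_v:
            best_i, best_v = i, v
    return best_i

print(gumbel_max_sample([0.1, 3.0, -1.2], 0.0))  # greedy -> 1
print(gumbel_max_sample([0.1, 3.0, -1.2], 0.8))  # stochastic index in {0, 1, 2}
```

Adding i.i.d. Gumbel noise to scaled logits and taking the argmax is distributionally equivalent to sampling from the softmax, which is why it fuses well into a single on-device kernel; on the host, any softmax sampler would do the same job.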
