Commit b41d055
Qwen 3.5 MoE: bypass on-device sampler conditional in non-CUDA export
Qwen35MoE.forward currently routes through an Optional[Tensor] temperature parameter and an if/else that picks between the on-device fused Gumbel-max sampler (CUDA) and raw logits (non-CUDA). The sampling branch is dead code for MLX and Metal exports, since those backends sample on the host.

Even though torch.export statically eliminates the branch when temperature defaults to None, the parameter, its default value, and the unused else-branch leak into the exported program: extra placeholder nodes, different graph hashes, and shifted kernel selection in the lowered MLX/Metal graph. On the tiny test model this slows MLX prefill by ~9-37% and decode by ~5-19%, and shows up as ~10-25% noise on Metal.

Bind model.forward to a minimal (tokens, input_pos) -> logits variant inside _export_mlx and _export_metal before torch.export, so the captured program matches what the backend kernels are tuned for. Eager-mode callers and the CUDA export path are unaffected.
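A torch-free sketch of the signature difference described above, using toy stand-in functions (not the real Qwen35MoE.forward): the optional temperature parameter and its default are part of the callable's signature, which is exactly what torch.export captures as extra placeholder inputs.

```python
import inspect

# Toy stand-ins (not the real model) for the two forward shapes.
def forward_with_sampler(tokens, input_pos, temperature=None):
    """Shape of the current forward: one extra optional input."""

def forward_logits_only(tokens, input_pos):
    """Shape of the minimal variant bound before torch.export."""

print(inspect.signature(forward_with_sampler))  # (tokens, input_pos, temperature=None)
print(inspect.signature(forward_logits_only))   # (tokens, input_pos)
```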
1 parent b32eae7 commit b41d055

1 file changed: examples/models/qwen3_5_moe/export.py
Lines changed: 46 additions & 0 deletions
@@ -554,6 +554,44 @@ def export_and_lower(model, config, args):
     _export_cuda(model, config, args)
 
 
+def _strip_sampler_from_forward(model):
+    """Replace ``Qwen35MoE.forward`` with a minimal variant that only returns
+    raw logits, for backends that do not use the on-device fused sampler.
+
+    The default ``Qwen35MoE.forward`` accepts an optional sampling temperature
+    and dispatches between two paths: when a temperature is provided it slices
+    to the last position, casts to float32, and runs an on-device Gumbel-max
+    sampler that returns a sampled token id; when no temperature is provided
+    it returns the full ``[B, T, V]`` logits tensor. The on-device sampler is
+    only used by the CUDA path (the runner passes a temperature input);
+    non-CUDA backends sample on the host side after reading logits, so the
+    sampling branch is dead code at trace time.
+
+    The presence of the ``Optional[Tensor]`` parameter, the matching default
+    value in the exported signature, and the unused else-branch all leak into
+    the program ``torch.export`` produces -- adding placeholder nodes,
+    perturbing graph hashes used by downstream pattern matching, and shifting
+    kernel selection in the lowered graph. Empirically this slows MLX prefill
+    and decode by 10-30% even though the conditional is statically dead.
+
+    This helper rebinds ``model.forward`` to a clean two-argument variant
+    ``(tokens, input_pos) -> [B, T, V]`` logits before tracing, so the program
+    captured by ``torch.export`` for non-CUDA backends matches the shape the
+    backend kernels are tuned for. The original ``forward`` is unaffected for
+    other callers (e.g. eager inference, the CUDA export path).
+    """
+    import types
+
+    def _clean_forward(self, tokens, input_pos):
+        x = self.embed_tokens(tokens)
+        for layer in self.layers:
+            x = layer(x, input_pos)
+        x = self.norm(x)
+        return self.lm_head(x)
+
+    model.forward = types.MethodType(_clean_forward, model)
+
+
 def _export_mlx(model, config, args):
     """Export model to .pte via torch.export + MLX backend."""
     import gc
@@ -568,6 +606,10 @@ def _export_mlx(model, config, args):
     from executorch.exir.passes import MemoryPlanningPass
     from torch.export import Dim, export
 
+    # Use the minimal forward variant for non-CUDA export. See
+    # _strip_sampler_from_forward docstring for why this matters.
+    _strip_sampler_from_forward(model)
+
     example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     example_input_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
@@ -651,6 +693,10 @@ def _export_metal(model, config, args):
     inductor_config.coordinate_descent_tuning = False
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
 
+    # Use the minimal forward variant for non-CUDA export. See
+    # _strip_sampler_from_forward docstring for why this matters.
+    _strip_sampler_from_forward(model)
+
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method...")
     decode_tokens = torch.tensor([[0]], dtype=torch.long)

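The core trick in the patch is rebinding `forward` on one instance via `types.MethodType`, which shadows the class method for that instance only. A minimal, torch-free sketch with toy classes (TinyModel and its attributes are illustrative stand-ins, not the real Qwen35MoE):

```python
import types


class TinyModel:
    """Toy stand-in for the model being exported (not the real Qwen35MoE)."""

    def forward(self, tokens, input_pos, temperature=None):
        # Default path: carries the extra optional parameter that would
        # leak into a traced program's signature.
        return ("full", tokens, input_pos, temperature)


def strip_sampler(model):
    """Same pattern as the patch: bind a minimal forward onto this one
    instance; the class attribute (and every other instance) is untouched."""

    def _clean_forward(self, tokens, input_pos):
        return ("logits-only", tokens, input_pos)

    model.forward = types.MethodType(_clean_forward, model)


m = TinyModel()
strip_sampler(m)
print(m.forward(1, 2))                  # ('logits-only', 1, 2)
print(TinyModel().forward(1, 2))        # ('full', 1, 2, None) -- others unaffected
```

Because attribute lookup finds the instance attribute before the class attribute, the stripped `forward` wins for the exported instance while eager-mode callers on other instances still get the original three-argument version.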