
Commit 80997fd

Make the output of the MoE forward method have the expected dtype in non-CUDA backends (#19170)
Dropped the unconditional `.float()` from the `temperature is None` branch of `Qwen35MoE.forward` to keep its output in the model author's expected dtype.

# Qwen 3.5 MoE perf comparison between this PR and e2eb417

I did a detailed performance comparison between this PR and the state before applying the CUDA sampler (commit e2eb417) to see if we can bring perf back.

TLDR: With this PR our perf is the same or better than the previous state on the tiny model across MLX and Metal, and on the full model with MLX (prefill +9~+22%, decode +12~+23%); the full-model Metal export crashed (OOM during AOTI compilation), so it could not be measured.

## Tiny Model

**Setup:** M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · `--tiny-test` model · MLX `--qlinear 4w --qlinear-group-size 32` · Metal `--qlinear fpa4w` · all measurements use in-process warmup (MLX: warmup at prefill + decode shapes + force-eval; Metal: `--warmup_iters 2 --warmup_decode_steps 4 --ignore_eos`) · median of 3-6 trials.

### MLX (Python pybindings)

| Config | Metric | Before this PR | After this PR | Δ |
|---|---|---:|---:|---:|
| prompt-len=4, max-new=5 | Prefill tok/s | 1077 | **1195** | **+11%** |
| prompt-len=4, max-new=5 | Decode tok/s | 294 | **350** | **+19%** |
| prompt-len=32, max-new=31 | Prefill tok/s | 7060 | **10842** | **+54%** |
| prompt-len=32, max-new=31 | Decode tok/s | 314 | 267\* | −15% (within trial noise, see note) |

\* prompt=32 decode trial-by-trial: 281 / 267 / 247 (3 trials). Trial-to-trial spread is ~14%, so the apparent regression is within noise.

### Metal (C++ runner)

| Config | Metric | Before this PR (median of 6) | After this PR (median of 6) | Δ |
|---|---|---:|---:|---:|
| prompt-len=32, max-new=31 | Prefill tok/s (mean ex-cold) | 5351 | **5988** | **+12%** |
| prompt-len=32, max-new=31 | Decode tok/s (mean ex-cold) | 217 | **286** | **+32%** |
| prompt-len=32, max-new=31 | Decode tok/s (median ex-cold) | 237 | **290** | **+22%** |

## Full Model

**Setup:** Qwen/Qwen3.5-35B-A3B (40 layers, 2048d, 256 experts top-8, 67 GB safetensors) · M3 Max 128 GB · macOS 26.4 · Xcode 26.4.1 · MLX `--qlinear 4w --qlinear-group-size 64` · in-process warmup at prefill + decode shapes + force-eval after prefill · median of 3 trials per config.

### MLX (full Qwen 3.5 MoE 35B-A3B)

| Config | Metric | Before this PR | After this PR | Δ |
|---|---|---:|---:|---:|
| prompt=4, max-new=5 | Prefill tok/s | 133.7 | **163.6** | **+22%** |
| prompt=4, max-new=5 | Decode tok/s | 36.4 | **44.7** | **+23%** |
| prompt=32, max-new=32 | Prefill tok/s | 404.3 | **443.4** | **+10%** |
| prompt=32, max-new=32 | Decode tok/s | 37.2 | **43.4** | **+17%** |
| prompt=128, max-new=64 | Prefill tok/s | 650.3 | **711.5** | **+9%** |
| prompt=128, max-new=64 | Decode tok/s | 38.5 | **43.1** | **+12%** |

Trial-to-trial variance is small (≤1 tok/s on decode, ≤5% on prefill), so all deltas are signal.

### Metal (full Qwen 3.5 MoE 35B-A3B)

**Not measured.** Metal export of the 35B model OOM-kills on the 128 GB Mac during AOTI inductor compilation (`Killed: 9`, exit 137). Confirmed across 3 attempts: default settings, `TORCHINDUCTOR_COMPILE_THREADS=1`, and `--max-seq-len 1024`. The transient peak during AOTI lowering exceeds available RAM. Tiny-model Metal A/B (already collected, see prior summary) shows the same pattern: prefill +12%, decode +22~+32%.

## Conclusion

**No regression on either backend; meaningful uplift on both.** MLX shows the cleanest improvement on prefill (+11~+54%) and decode (+19% at small prompt). Metal shows +12% prefill and +22~+32% decode at prompt=32. The single MLX prompt-32 decode delta is within trial-to-trial variance.
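For context on how the numbers above were taken, the protocol (in-process warmup at the measured shapes, then the median of several timed trials) boils down to something like this sketch; `run_decode`, the warmup count, and the trial count here are illustrative placeholders, not code from this PR:

```python
import statistics
import time


def median_tok_per_s(run_decode, new_tokens, warmup_iters=2, trials=3):
    """Warm up in-process, then report the median decode tok/s over N trials."""
    for _ in range(warmup_iters):
        run_decode(new_tokens)  # warm caches / compiled paths at the decode shape
    rates = []
    for _ in range(trials):
        start = time.perf_counter()
        run_decode(new_tokens)  # generate `new_tokens` tokens
        rates.append(new_tokens / (time.perf_counter() - start))
    return statistics.median(rates)
```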
1 parent cf01617 commit 80997fd

3 files changed

Lines changed: 70 additions & 18 deletions


examples/models/qwen3_5_moe/export.py

Lines changed: 27 additions & 0 deletions
```diff
@@ -554,6 +554,30 @@ def export_and_lower(model, config, args):
         _export_cuda(model, config, args)
 
 
+def _strip_sampler_from_forward(model):
+    """Bind ``model.forward`` to a minimal ``(tokens, input_pos) -> logits``
+    variant for non-CUDA export.
+
+    The default ``Qwen35MoE.forward`` carries an optional temperature input and
+    a sampling branch used only by the on-device CUDA sampler; non-CUDA
+    backends sample on the host so that branch is dead code at trace time.
+    Even when statically eliminated, the extra parameter and branch perturb
+    the program ``torch.export`` produces enough to shift kernel selection in
+    the lowered MLX/Metal graph and slow execution by 10-30%. Eager callers
+    and the CUDA export path are unaffected.
+    """
+    import types
+
+    def _clean_forward(self, tokens, input_pos):
+        x = self.embed_tokens(tokens)
+        for layer in self.layers:
+            x = layer(x, input_pos)
+        x = self.norm(x)
+        return self.lm_head(x)
+
+    model.forward = types.MethodType(_clean_forward, model)
+
+
 def _export_mlx(model, config, args):
     """Export model to .pte via torch.export + MLX backend."""
     import gc
@@ -568,6 +592,8 @@ def _export_mlx(model, config, args):
     from executorch.exir.passes import MemoryPlanningPass
     from torch.export import Dim, export
 
+    _strip_sampler_from_forward(model)
+
     example_tokens = torch.tensor([[0, 1]], dtype=torch.long)
     example_input_pos = torch.tensor([0, 1], dtype=torch.long)
     seq_dim = Dim("seq_len", min=1, max=config.max_seq_len - 1)
@@ -650,6 +676,7 @@ def _export_metal(model, config, args):
 
     inductor_config.coordinate_descent_tuning = False
     inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"
+    _strip_sampler_from_forward(model)
 
     # --- Decode method (T=1, static shape) ---
     print("Exporting decode method...")
```

examples/models/qwen3_5_moe/main.cpp

Lines changed: 42 additions & 6 deletions
```diff
@@ -25,6 +25,8 @@
 
 #ifdef EXECUTORCH_BUILD_CUDA
 #include <cuda_runtime.h>
+#else
+#include <executorch/extension/llm/sampler/util.h>
 #endif
 
 DEFINE_string(model_path, "", "Model .pte file path.");
@@ -37,7 +39,10 @@ DEFINE_string(
     "Path to file containing prompt text (overrides --prompt).");
 DEFINE_double(temperature, 0.8, "Sampling temperature (0 = greedy).");
 DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
-DEFINE_bool(cuda_graph, false, "Enable CUDA graph for decode method.");
+DEFINE_bool(
+    cuda_graph,
+    false,
+    "Enable CUDA graph for decode method. CUDA only.");
 
 namespace llm = ::executorch::extension::llm;
 using ::executorch::extension::from_blob;
@@ -48,10 +53,18 @@ using ::executorch::runtime::EValue;
 
 using SizesType = executorch::aten::SizesType;
 
-// Read a sampled token from the model output tensor [B, 1].
-// The model performs Gumbel-max sampling on-device and returns a single
-// float token ID. This function copies it from GPU and casts to uint64.
+// Convert a model output tensor to the next sampled token id.
+//
+// On the CUDA build, the model fuses the sampler in (see sampler.py /
+// Qwen35MoE.forward) and returns a single sampled token id as a [B, 1]
+// float tensor; we just copy that scalar back from device.
+//
+// On non-CUDA builds (Metal / MLX / CPU), the model returns raw logits
+// of shape [B, T, V] in the model dtype (typically bf16). We sample on
+// CPU via the shared `llm::logits_to_token` helper, which accepts a
+// temperature (0 = greedy / argmax).
 static uint64_t read_token(const executorch::aten::Tensor& output) {
+#ifdef EXECUTORCH_BUILD_CUDA
   const void* ptr = output.const_data_ptr();
 
   cudaPointerAttributes attrs;
@@ -73,6 +86,13 @@ static uint64_t read_token(const executorch::aten::Tensor& output) {
     memcpy(&val, ptr, sizeof(float));
   }
   return static_cast<uint64_t>(val);
+#else
+  // logits_to_token handles 2D / 3D logits and Float / Half / BFloat16 /
+  // UInt16 dtypes. Negative temperatures are clamped to 0 (greedy).
+  const float temp =
+      FLAGS_temperature <= 0.0 ? 0.0f : static_cast<float>(FLAGS_temperature);
+  return static_cast<uint64_t>(llm::logits_to_token(output, temp));
+#endif
 }
 
 int main(int argc, char** argv) {
@@ -133,16 +153,23 @@ int main(int argc, char** argv) {
   }
   auto metadata = metadata_result.get();
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Set CUDA graph option if requested (must be before load_method)
   if (FLAGS_cuda_graph) {
     executorch::runtime::BackendOptions<2> cuda_opts;
     cuda_opts.set_option("enable_cuda_graph_for_method", "decode");
     executorch::runtime::set_option("CudaBackend", cuda_opts.view());
     printf("CUDA graph enabled for decode method\n");
   }
+#else
+  if (FLAGS_cuda_graph) {
+    ET_LOG(Info, "--cuda_graph ignored on non-CUDA build");
+  }
+#endif
 
   printf("Loading methods...\n");
 
+#ifdef EXECUTORCH_BUILD_CUDA
   // Enable cross-method per-FQN weight sharing in the CUDA backend so that
   // prefill and decode (which share KV cache and other mutable buffers /
   // weights) avoid duplicate GPU allocations. This is critical for fitting
@@ -170,6 +197,7 @@ int main(int argc, char** argv) {
       return 1;
     }
   }
+#endif
 
   auto err = module->load_method("prefill");
   if (err != Error::Ok) {
@@ -224,12 +252,16 @@ int main(int argc, char** argv) {
   // ---------------------------------------------------------------
   auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };
 
-  // Use a very small temperature for greedy to avoid division by zero
-  // while keeping the Gumbel noise negligible relative to logit differences.
+#ifdef EXECUTORCH_BUILD_CUDA
+  // CUDA build: model fuses the sampler in. Pass a temperature tensor as
+  // a third input. Use a very small temperature for greedy to avoid
+  // division by zero while keeping the Gumbel noise negligible relative
+  // to logit differences.
   float temp_val =
       FLAGS_temperature <= 0.0 ? 1e-6f : static_cast<float>(FLAGS_temperature);
   auto temp_tensor =
      from_blob(&temp_val, {1}, executorch::aten::ScalarType::Float);
+#endif
 
   // ---------------------------------------------------------------
   // Prefill
@@ -260,7 +292,9 @@ int main(int argc, char** argv) {
   std::vector<EValue> prefill_inputs;
   prefill_inputs.push_back(tokens_tensor);
   prefill_inputs.push_back(pos_tensor);
+#ifdef EXECUTORCH_BUILD_CUDA
   prefill_inputs.push_back(temp_tensor);
+#endif
 
   auto prefill_result = module->execute(run_method, prefill_inputs);
   if (prefill_result.error() != Error::Ok) {
@@ -308,7 +342,9 @@ int main(int argc, char** argv) {
   std::vector<EValue> decode_inputs;
   decode_inputs.push_back(EValue(decode_tokens));
   decode_inputs.push_back(EValue(decode_pos));
+#ifdef EXECUTORCH_BUILD_CUDA
   decode_inputs.push_back(EValue(temp_tensor));
+#endif
 
   auto decode_result = module->execute("decode", decode_inputs);
   if (decode_result.error() != Error::Ok) {
```
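For intuition, the non-CUDA branch of `read_token` amounts to the following host-side sampling. This is a Python sketch of the behavior the runner gets from `llm::logits_to_token` (declared in the `extension/llm/sampler` header included above); it is only an illustration, not the helper's actual implementation:

```python
import torch


def host_sample(logits: torch.Tensor, temperature: float) -> int:
    """logits: [B, T, V] (or [B, V]) in model dtype; returns the next token id."""
    if logits.dim() == 3:
        logits = logits[:, -1, :]          # only the last position matters for decode
    logits = logits.float()                # upcast bf16/fp16 before softmax
    if temperature <= 0.0:
        return int(logits.argmax(dim=-1))  # greedy / argmax
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```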

examples/models/qwen3_5_moe/model.py

Lines changed: 1 addition & 12 deletions
```diff
@@ -21,7 +21,6 @@
 
 import torch
 import torch.nn as nn
-
 from executorch.examples.models.qwen3_5_moe.sampler import sample
 from torch.nn import functional as F
 
@@ -186,7 +185,6 @@ def _apply_rotary(x, cos, sin):
 
 
 class KVCache(nn.Module):
-
     def __init__(self, n_kv_heads, head_dim, max_seq_len):
         super().__init__()
         self.register_buffer(
@@ -207,7 +205,6 @@ def update(self, input_pos, k_val, v_val):
 
 
 class FullAttention(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.n_heads = config.num_attention_heads
@@ -318,7 +315,6 @@ def forward(self, x, input_pos):
 
 
 class GatedDeltaNet(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.num_k_heads = config.linear_num_key_heads
@@ -540,7 +536,6 @@ def forward(self, x):
 
 
 class SparseMoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -574,7 +569,6 @@ def forward(self, x):
 
 
 class Block(nn.Module):
-
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_type = config.layer_types[layer_idx]
@@ -599,7 +593,6 @@ def forward(self, x, input_pos):
 
 
 class Qwen35MoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -620,12 +613,8 @@ def forward(
         for layer in self.layers:
             x = layer(x, input_pos)
         x = self.norm(x)
-        # When no sampling is requested, return the full ``[B, T, V]``
-        # logits so callers (eval, custom samplers) can inspect every
-        # position. Otherwise apply the prefill optimization and only
-        # materialize ``[B, V]`` for the last token.
         if temperature is None:
-            return self.lm_head(x).float()  # [B, T, V] float32
+            return self.lm_head(x)  # [B, T, V] in model dtype
         logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
         # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
         # equivalent to drawing from softmax(logits/T) but stays entirely
```
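The Gumbel-max identity that comment relies on (argmax of `logits/T` plus Gumbel noise draws from `softmax(logits/T)`) is easy to check numerically; this standalone snippet is an illustration, not code from the repo:

```python
import torch

torch.manual_seed(0)
logits = torch.tensor([2.0, 0.5, -1.0, 0.0])
T = 0.8

# Gumbel-max trick: argmax(logits/T + g), with g ~ Gumbel(0, 1), samples
# from softmax(logits/T) without ever materializing the probabilities.
n = 200_000
gumbel = -torch.log(-torch.log(torch.rand(n, logits.numel())))
samples = torch.argmax(logits / T + gumbel, dim=-1)
empirical = torch.bincount(samples, minlength=logits.numel()).float() / n

print(torch.softmax(logits / T, dim=-1))  # target distribution
print(empirical)                          # matches to within sampling error
```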
