Skip to content

Commit 3b7bbfd

Browse files
authored
perf(qwen35moe): fuse router (argsort_top_k) + gate_up swiglu (#472)
Match llama's qwen3moe graph: argsort_top_k lets ggml-cuda's topk-moe fusion collapse softmax->topk->get_rows->norm into ~1 kernel, and ggml_swiglu on the combined gate_up buffer drops the 2 ggml_cont copies per layer (x30 MoE layers). Same selection -> bit-identical output. Env-gated for A/B: DFLASH_NO_MOE_ROUTER_FUSE / DFLASH_NO_MOE_SWIGLU_FUSE. Perf-neutral at all-hot (113.1 vs 113.3 tok/s, noise) — these router/ swiglu ops are <3% each; the residual ~3% decode gap vs llama is the shared MoE GEMV (mul_mat_q, 58% at 16.7% occ) launch-bound floor, not missed fusions. Lands graph-node parity; removes "missed fusion" as a gap explanation.
1 parent 737cd47 commit 3b7bbfd

1 file changed

Lines changed: 31 additions & 11 deletions

File tree

server/src/qwen35moe/qwen35moe_ffn.cpp

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "qwen35_ops.h"
44

5+
#include <cstdlib>
56
#include <cmath>
67

78
namespace dflash::common {
@@ -27,7 +28,16 @@ Qwen35MoeRouterOutputs build_qwen35moe_router(
2728
break;
2829
}
2930

30-
ggml_tensor * selected = ggml_top_k(ctx, probs, n_used);
31+
// ggml_argsort_top_k emits GGML_OP_ARGSORT (+view), which ggml-cuda's
32+
// topk-moe fusion (ggml_cuda_topk_moe_fusion) recognizes and fuses the whole
33+
// softmax->topk->get_rows->norm router into ~1 kernel. ggml_top_k emits
34+
// GGML_OP_TOP_K, which the fusion does NOT match -> 6-7 separate kernels/layer
35+
// x30 MoE layers (the launch-bound decode gap vs llama, which uses argsort_top_k).
36+
// Same top-k selection -> bit-identical. DFLASH_NO_MOE_ROUTER_FUSE=1 = old path.
37+
static const bool router_fuse = (std::getenv("DFLASH_NO_MOE_ROUTER_FUSE") == nullptr);
38+
ggml_tensor * selected = router_fuse
39+
? ggml_argsort_top_k(ctx, probs, n_used)
40+
: ggml_top_k(ctx, probs, n_used);
3141

3242
ggml_tensor * probs_3d = ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens);
3343
ggml_tensor * weights = ggml_get_rows(ctx, probs_3d, selected);
@@ -72,19 +82,29 @@ ggml_tensor * build_qwen35moe_ffn(
7282

7383
ggml_tensor * cur_3d = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
7484
ggml_tensor * gu = nullptr;
85+
// Combined gate_up: split + swiglu in ONE op. ggml_swiglu reads gate from
86+
// [0,nc) and up from [nc,2nc) of the same buffer, so the two ggml_cont copies
87+
// that materialised the strided halves are eliminated — 2 extra copy kernels
88+
// per layer x 30 MoE layers that llama's qwen3moe graph never emits.
89+
// DFLASH_NO_MOE_SWIGLU_FUSE=1 restores the view+cont+split path (bit-id gate).
90+
static const bool moe_swiglu_fuse = (std::getenv("DFLASH_NO_MOE_SWIGLU_FUSE") == nullptr);
7591
if (L.ffn_gate_up_exps) {
7692
ggml_tensor * gate_up_e = apply_scale2(
7793
ctx, ggml_mul_mat_id(ctx, L.ffn_gate_up_exps, cur_3d, selected), L.ffn_gate_up_exps_s);
78-
ggml_tensor * gate_e = ggml_view_3d(ctx, gate_up_e,
79-
n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2],
80-
gate_up_e->nb[1], gate_up_e->nb[2], 0);
81-
ggml_tensor * up_e = ggml_view_3d(ctx, gate_up_e,
82-
n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2],
83-
gate_up_e->nb[1], gate_up_e->nb[2],
84-
(size_t)n_ff_exp * ggml_element_size(gate_up_e));
85-
gate_e = ggml_cont(ctx, gate_e);
86-
up_e = ggml_cont(ctx, up_e);
87-
gu = ggml_swiglu_split(ctx, gate_e, up_e);
94+
if (moe_swiglu_fuse) {
95+
gu = ggml_swiglu(ctx, gate_up_e); // silu(gate) * up, no views/conts
96+
} else {
97+
ggml_tensor * gate_e = ggml_view_3d(ctx, gate_up_e,
98+
n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2],
99+
gate_up_e->nb[1], gate_up_e->nb[2], 0);
100+
ggml_tensor * up_e = ggml_view_3d(ctx, gate_up_e,
101+
n_ff_exp, gate_up_e->ne[1], gate_up_e->ne[2],
102+
gate_up_e->nb[1], gate_up_e->nb[2],
103+
(size_t)n_ff_exp * ggml_element_size(gate_up_e));
104+
gate_e = ggml_cont(ctx, gate_e);
105+
up_e = ggml_cont(ctx, up_e);
106+
gu = ggml_swiglu_split(ctx, gate_e, up_e);
107+
}
88108
} else {
89109
ggml_tensor * gate_e = apply_scale2(
90110
ctx, ggml_mul_mat_id(ctx, L.ffn_gate_exps, cur_3d, selected), L.ffn_gate_exps_s);

0 commit comments

Comments
 (0)