Skip to content

Commit 2f0cdab

Browse files
committed
perf(flydsl): MXFP4 fused-MoE stage2 optimization for EP4 prefill
Optimize the FlyDSL stage2 (down-projection) kernel for the production fp4xfp4 EP4 DeepSeek prefill shape (tile_m=64, tile_n=128, tile_k=256, M=49152) on MI355X. ## Summary of changes aiter/ops/flydsl/moe_kernels.py - Add new wrapper params (waves_per_eu, use_async_copy, cu_num_mul) and plumb them through compile_flydsl_moe_stage2 -> compile_mixed_moe_gemm2. - Register a production fp4xfp4 stage2 variant `flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3` with waves_per_eu=4, use_async_copy=True, cu_num_mul=3. - Drop the defensive `out.fill_(0)` added in PR #2863: when accumulate=True the standard fused_moe path already zeros moe_buf via `moe_buf_set_zero_kernel_2d` inside `moe_sorting_*_fwd`. The extra fill is a ~token_num*model_dim HBM write (~130 us per call at MI355X HBM bw on this shape). Callers using flydsl_moe_stage2 directly with accumulate=True remain responsible for zeroing `out` (documented). aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py - compile_mixed_moe_gemm2 signature: add waves_per_eu, use_async_copy, cu_num_mul. Drop unused legacy b_nt, xcd_swizzle on this path. - cu_num_mul: multiplies persistent-mode grid_y (CU count) by an integer factor; cu_num_mul>1 launches more persistent CTAs that each cover fewer M tiles, increasing in-flight parallelism. cu_num_mul=3 is the EP4-prefill sweet spot; cu_num_mul=4 regresses ~2.4% on the same shape. - use_async_copy: enable raw_ptr_buffer_load_lds for the X (activation) tile in the prologue so the X DMA overlaps with B/scale VMEM. Gated to the production tile/dtype combo so other configs are unaffected. - Asymmetric b_lo/b_hi split (b_lo loads 3/4 of K, b_hi loads 1/4): the smaller b_hi burst frees VMEM issue slots right before s_setprio(1) and the K-loop; the extra ku of b_lo is issued in the prologue where it overlaps with the X DMA. - Stage sorted topk weights into LDS once per M-tile and prefetch via vec4 LDS reads in the cshuffle epilogue, eliminating per-mi VMEM dwordx4 loads from the MFMA-heavy hot loop. - Emit AMDGPU `disable_xdl_arb_stall` + `s_setprio(1)` around the K-loop on gfx950 so the high-prio K-loop drains MFMAs back-to-back. - Issue scales BEFORE B-VMEM in the steady-state K-loop so the extract-time stalls in the first MFMA chunk are reduced. - Defer the lds_tid prologue (sorted_idx/sorted_w VMEM dwords + LDS stores) to AFTER the long-latency X DMA + B/scale loads, hiding the lds_tid scalar load latency behind the dominant VMEM phase. - Merge the lds_tid + lds_tw prologue store guards into a single scf.IfOp so MLIR can co-schedule the two VMEM buffer_loads. - For the buffer-atomic path: fast-path the cshuffle row token-index load (move the AND + multiply + validity-mask into the lds_tid preload site where it overlaps with B/scale/X VMEM loads). - Balanced persistent tile distribution (`tD4balPersist`) and balanced LDS-write hint accounting on the async path (`dswr_async`). aiter/fused_moe.py - `_flydsl_stage2_wrapper`: forward waves_per_eu, use_async_copy, cu_num_mul from the parsed kernel-name config to flydsl_moe_stage2. ## Performance Benchmark harness: /home/mingzliu/sgl_opt/ep4_mxfp4_opt/coopt_loop/scripts/bench.sh Workload: /home/mingzliu/sgl_opt/test_fused_moe_ep4_mxfp4.py (M=49152, model_dim=7168, inter_dim=2048, 64 local experts, topk=8, QuantType.per_1x32, fp4x2 weights & activations, bf16 out) Metric: 5-run median latency, 3 reps, MI355X (smci355-ccs-aus-n06-21). baseline (kernel=_atomic_persist): 3438.8 us (3-rep avg) this PR (kernel=_atomic_persist_async_w4_cumul3): 3173.6 us (3-rep avg) delta: -265 us (-7.71%)
1 parent 446ed17 commit 2f0cdab

3 files changed

Lines changed: 798 additions & 424 deletions

File tree

aiter/fused_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,9 @@ def _flydsl_stage2_wrapper(
795795
a2_scale=a2_scale,
796796
sorted_weights=sorted_weights,
797797
sort_block_m=parsed.get("sort_block_m", 0),
798+
waves_per_eu=parsed.get("waves_per_eu", None),
799+
use_async_copy=parsed.get("use_async_copy", False),
800+
cu_num_mul=parsed.get("cu_num_mul", 1),
798801
b_nt=parsed.get("b_nt", 0),
799802
persist=parsed.get("persist", None),
800803
bias=bias2,

0 commit comments

Comments
 (0)