Commit 2f0cdab
perf(flydsl): MXFP4 fused-MoE stage2 optimization for EP4 prefill
Optimize the FlyDSL stage2 (down-projection) kernel for the production
fp4xfp4 EP4 DeepSeek prefill shape (tile_m=64, tile_n=128, tile_k=256,
M=49152) on MI355X.
## Summary of changes
aiter/ops/flydsl/moe_kernels.py
- Add new wrapper params (waves_per_eu, use_async_copy, cu_num_mul) and
plumb them through compile_flydsl_moe_stage2 -> compile_mixed_moe_gemm2.
- Register a production fp4xfp4 stage2 variant
`flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3`
with waves_per_eu=4, use_async_copy=True, cu_num_mul=3.
- Drop the defensive `out.fill_(0)` added in PR #2863: when accumulate=True
the standard fused_moe path already zeros moe_buf via
`moe_buf_set_zero_kernel_2d` inside `moe_sorting_*_fwd`. The extra
fill is a ~token_num*model_dim HBM write (~130 us per call at MI355X
HBM bw on this shape). Callers using flydsl_moe_stage2 directly with
accumulate=True remain responsible for zeroing `out` (documented).
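The contract this bullet documents can be modeled with a toy accumulator (illustrative only; `stage2_accumulate` is a stand-in for the kernel's buffer-atomic adds, not the real API):

```python
def stage2_accumulate(out, partials):
    # Models accumulate=True: each (expert, topk) partial is ADDED into
    # `out`, so any stale contents of `out` leak into the result unless
    # the caller (or moe_buf_set_zero_kernel_2d in the fused path) has
    # zeroed it first.
    for p in partials:
        out = [o + x for o, x in zip(out, p)]
    return out

partials = [[1.0] * 4 for _ in range(3)]

bad = stage2_accumulate([7.0] * 4, partials)   # stale buffer: off by 7
good = stage2_accumulate([0.0] * 4, partials)  # caller zeroed `out`
# good == [3.0, 3.0, 3.0, 3.0]; bad == [10.0, 10.0, 10.0, 10.0]
```

This is why the standard fused_moe path can safely skip the extra `out.fill_(0)`, while direct callers with accumulate=True cannot.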
aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py
- compile_mixed_moe_gemm2 signature: add waves_per_eu, use_async_copy,
cu_num_mul. Drop unused legacy b_nt, xcd_swizzle on this path.
- cu_num_mul: multiplies persistent-mode grid_y (CU count) by an integer
factor; cu_num_mul>1 launches more persistent CTAs that each cover
fewer M tiles, increasing in-flight parallelism. cu_num_mul=3 is the
EP4-prefill sweet spot; cu_num_mul=4 regresses ~2.4% on the same shape.
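The launch math behind cu_num_mul can be sketched as follows (a minimal model, assuming 256 CUs on MI355X and a simple strided tile distribution; `persistent_grid` is an illustrative name, not the real function):

```python
def persistent_grid(num_m_tiles, cu_count, cu_num_mul=1):
    # Persistent mode: grid_y = CU count * cu_num_mul. Each persistent
    # CTA strides through the M tiles, so a larger multiplier means
    # fewer tiles per CTA and more in-flight parallelism.
    grid_y = cu_count * cu_num_mul
    tiles_per_cta = -(-num_m_tiles // grid_y)  # ceil division
    return grid_y, tiles_per_cta

# M=49152, tile_m=64 -> 768 M tiles on a 256-CU MI355X
num_m_tiles = 49152 // 64
print(persistent_grid(num_m_tiles, 256, 1))  # (256, 3)
print(persistent_grid(num_m_tiles, 256, 3))  # (768, 1)
```

At cu_num_mul=3 the 768 tiles map one-to-one onto 768 CTAs; cu_num_mul=4 would launch 1024 CTAs for the same 768 tiles, which is consistent with the observed ~2.4% regression from idle CTAs.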
- use_async_copy: enable raw_ptr_buffer_load_lds for the X (activation)
tile in the prologue so the X DMA overlaps with B/scale VMEM. Gated
to the production tile/dtype combo so other configs are unaffected.
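The gating can be expressed as a small predicate (hypothetical helper; the real gate lives inside compile_mixed_moe_gemm2 and the exact dtype spelling is an assumption):

```python
def async_copy_enabled(tile, a_dtype, b_dtype, use_async_copy):
    # Async buffer_load_lds for the X tile is only turned on for the
    # production fp4xfp4 tile shape; every other config keeps its
    # existing, unmodified code path.
    return (use_async_copy
            and tile == (64, 128, 256)
            and a_dtype == "fp4x2"
            and b_dtype == "fp4x2")

print(async_copy_enabled((64, 128, 256), "fp4x2", "fp4x2", True))   # True
print(async_copy_enabled((32, 64, 128), "fp4x2", "fp4x2", True))    # False
```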
- Asymmetric b_lo/b_hi split (b_lo loads 3/4 of K, b_hi loads 1/4): the
smaller b_hi burst frees VMEM issue slots right before s_setprio(1)
and the K-loop; the extra ku of b_lo is issued in the prologue where
it overlaps with the X DMA.
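The split arithmetic is simple enough to show directly (illustrative; the 3/4 : 1/4 ratio comes from the bullet above, the helper name is made up):

```python
def split_b_prologue(tile_k):
    # Asymmetric prologue split: b_lo covers 3/4 of the K tile (issued
    # early, overlapping the X DMA), b_hi the remaining 1/4 (a smaller
    # burst right before s_setprio(1) and the K-loop).
    b_lo_k = tile_k * 3 // 4
    b_hi_k = tile_k - b_lo_k
    return b_lo_k, b_hi_k

print(split_b_prologue(256))  # (192, 64) for the tile_k=256 production shape
```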
- Stage sorted topk weights into LDS once per M-tile and prefetch via
vec4 LDS reads in the cshuffle epilogue, eliminating per-mi VMEM
dwordx4 loads from the MFMA-heavy hot loop.
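The weight-staging change can be modeled with a tiny generator (a sketch; the Python list stands in for LDS, and names are illustrative):

```python
def epilogue_weights(sorted_w, tile_rows, vec=4):
    # Stage the sorted topk weights for one M tile into a local buffer
    # once (modeling the LDS staging pass), then serve the cshuffle
    # epilogue in vec4 chunks instead of one VMEM dwordx4 load per mi.
    lds = list(sorted_w[:tile_rows])       # one staging pass per M tile
    for i in range(0, tile_rows, vec):     # vec4 LDS reads in the epilogue
        yield lds[i:i + vec]

chunks = list(epilogue_weights(list(range(64)), 64))
# 16 chunks of 4 weights each, all served from the staged copy
```

The point is the load count: one staged pass plus cheap local reads, versus tile_rows global loads issued from inside the MFMA-heavy hot loop.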
- Emit AMDGPU `disable_xdl_arb_stall` + `s_setprio(1)` around the
K-loop on gfx950 so the high-prio K-loop drains MFMAs back-to-back.
- Issue scales BEFORE B-VMEM in the steady-state K-loop, reducing the
  extract-time stalls in the first MFMA chunk.
- Defer the lds_tid prologue (sorted_idx/sorted_w VMEM dwords + LDS
stores) to AFTER the long-latency X DMA + B/scale loads, hiding the
lds_tid scalar load latency behind the dominant VMEM phase.
- Merge the lds_tid + lds_tw prologue store guards into a single
scf.IfOp so MLIR can co-schedule the two VMEM buffer_loads.
- For the buffer-atomic path: fast-path the cshuffle row token-index
load (move the AND + multiply + validity-mask into the lds_tid
preload site where it overlaps with B/scale/X VMEM loads).
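The hoisted row-index math looks roughly like this (a sketch; the 24-bit token-index field and mask are assumptions about the sorted-id packing, and the helper name is made up):

```python
def preload_token_index(sorted_id, model_dim, token_num, id_mask=0xFFFFFF):
    # Done at the lds_tid preload site, overlapping the B/scale/X VMEM
    # loads, so the cshuffle epilogue consumes a ready offset instead of
    # redoing the AND + multiply + validity mask per row.
    tid = sorted_id & id_mask            # strip the packed topk field
    valid = tid < token_num              # out-of-range rows are padding
    row_offset = tid * model_dim if valid else 0
    return row_offset, valid

print(preload_token_index((2 << 24) | 5, 7168, 49152))  # (35840, True)
```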
- Balanced persistent tile distribution (`tD4balPersist`) and balanced
LDS-write hint accounting on the async path (`dswr_async`).
aiter/fused_moe.py
- `_flydsl_stage2_wrapper`: forward waves_per_eu, use_async_copy,
cu_num_mul from the parsed kernel-name config to flydsl_moe_stage2.
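The kernel-name suffix convention used above can be illustrated with a small parser (hypothetical helper; the real parsing lives in `_flydsl_stage2_wrapper`, and the `async`/`w<N>`/`cumul<N>` token meanings are taken from the variant name registered in this PR):

```python
import re

def parse_stage2_variant(name):
    # 'async' -> use_async_copy, '_w<N>' -> waves_per_eu,
    # '_cumul<N>' -> cu_num_mul; absent tokens keep their defaults.
    cfg = {"use_async_copy": False, "waves_per_eu": None, "cu_num_mul": 1}
    if "_async" in name:
        cfg["use_async_copy"] = True
    m = re.search(r"_w(\d+)", name)
    if m:
        cfg["waves_per_eu"] = int(m.group(1))
    m = re.search(r"_cumul(\d+)", name)
    if m:
        cfg["cu_num_mul"] = int(m.group(1))
    return cfg

cfg = parse_stage2_variant(
    "flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3"
)
# cfg -> use_async_copy=True, waves_per_eu=4, cu_num_mul=3
```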
## Performance
Benchmark harness:
/home/mingzliu/sgl_opt/ep4_mxfp4_opt/coopt_loop/scripts/bench.sh
Workload:
/home/mingzliu/sgl_opt/test_fused_moe_ep4_mxfp4.py
(M=49152, model_dim=7168, inter_dim=2048, 64 local experts, topk=8,
QuantType.per_1x32, fp4x2 weights & activations, bf16 out)
Metric: 5-run median latency, 3 reps, MI355X (smci355-ccs-aus-n06-21).
baseline (kernel=_atomic_persist): 3438.8 us (3-rep avg)
this PR (kernel=_atomic_persist_async_w4_cumul3): 3173.6 us (3-rep avg)
delta: -265 us (-7.71%)

parent 446ed17
3 files changed, 798 additions and 424 deletions