# perf(flydsl): MXFP4 fused-MoE stage2 optimization for EP prefill #3117
## Conversation
Optimize the FlyDSL stage2 (down-projection) kernel for the production
fp4xfp4 EP4 DeepSeek prefill shape (tile_m=64, tile_n=128, tile_k=256,
M=49152) on MI355X.
## Summary of changes
aiter/ops/flydsl/moe_kernels.py
- Add new wrapper params (waves_per_eu, use_async_copy, cu_num_mul) and
plumb them through compile_flydsl_moe_stage2 -> compile_mixed_moe_gemm2.
- Register a production fp4xfp4 stage2 variant
`flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3`
with waves_per_eu=4, use_async_copy=True, cu_num_mul=3.
- Drop the defensive `out.fill_(0)` added in PR ROCm#2863: when accumulate=True
the standard fused_moe path already zeros moe_buf via
`moe_buf_set_zero_kernel_2d` inside `moe_sorting_*_fwd`. The extra
fill is a ~token_num*model_dim HBM write (~130 us per call at MI355X
HBM bw on this shape). Callers using flydsl_moe_stage2 directly with
accumulate=True remain responsible for zeroing `out` (documented).
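
The caller contract above can be illustrated with a small, pure-Python sketch; `stage2_accumulate_sketch` is our stand-in for the kernel's atomic accumulation, not the actual aiter code:

```python
def stage2_accumulate_sketch(partials, row_idx, out):
    # Stand-in for the kernel's atomic adds (accumulate=True): each
    # sorted row's partial result is added to its destination token row.
    for row, part in zip(row_idx, partials):
        out[row] = [a + b for a, b in zip(out[row], part)]
    return out

# Caller contract when using flydsl_moe_stage2 directly with
# accumulate=True: ``out`` must arrive zeroed, otherwise stale values
# leak into the sum. On the standard fused_moe path this zeroing is
# done by moe_buf_set_zero_kernel_2d inside moe_sorting_*_fwd.
model_dim = 4
out = [[0.0] * model_dim for _ in range(3)]        # caller-side zeroing
partials = [[1.0] * model_dim for _ in range(4)]
row_idx = [0, 1, 1, 2]                              # token 1 gets two partials
stage2_accumulate_sketch(partials, row_idx, out)
```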
aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py
- compile_mixed_moe_gemm2 signature: add waves_per_eu, use_async_copy,
cu_num_mul. Drop unused legacy b_nt, xcd_swizzle on this path.
- cu_num_mul: multiplies persistent-mode grid_y (CU count) by an integer
factor; cu_num_mul>1 launches more persistent CTAs that each cover
fewer M tiles, increasing in-flight parallelism. cu_num_mul=3 is the
EP4-prefill sweet spot; cu_num_mul=4 regresses ~2.4% on the same shape.
- use_async_copy: enable raw_ptr_buffer_load_lds for the X (activation)
tile in the prologue so the X DMA overlaps with B/scale VMEM. Gated
to the production tile/dtype combo so other configs are unaffected.
- Asymmetric b_lo/b_hi split (b_lo loads 3/4 of K, b_hi loads 1/4): the
smaller b_hi burst frees VMEM issue slots right before s_setprio(1)
and the K-loop; the extra ku of b_lo is issued in the prologue where
it overlaps with the X DMA.
- Stage sorted topk weights into LDS once per M-tile and prefetch via
vec4 LDS reads in the cshuffle epilogue, eliminating per-mi VMEM
dwordx4 loads from the MFMA-heavy hot loop.
- Emit AMDGPU `disable_xdl_arb_stall` + `s_setprio(1)` around the
K-loop on gfx950 so the high-prio K-loop drains MFMAs back-to-back.
- Issue scales BEFORE B-VMEM in the steady-state K-loop so the
extract-time stalls in the first MFMA chunk are reduced.
- Defer the lds_tid prologue (sorted_idx/sorted_w VMEM dwords + LDS
stores) to AFTER the long-latency X DMA + B/scale loads, hiding the
lds_tid scalar load latency behind the dominant VMEM phase.
- Merge the lds_tid + lds_tw prologue store guards into a single
scf.IfOp so MLIR can co-schedule the two VMEM buffer_loads.
- For the buffer-atomic path: fast-path the cshuffle row token-index
load (move the AND + multiply + validity-mask into the lds_tid
preload site where it overlaps with B/scale/X VMEM loads).
- Balanced persistent tile distribution (`tD4balPersist`) and balanced
LDS-write hint accounting on the async path (`dswr_async`).
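
A rough Python-level sketch of the persistent-grid math and the balanced contiguous tile dispatch described above (function names and the 256-CU figure are our assumptions, not the kernel's actual code):

```python
def persistent_grid_y(cu_count: int, cu_num_mul: int) -> int:
    # Persistent mode normally launches one CTA per CU; cu_num_mul
    # scales that so each CTA walks fewer M tiles.
    return cu_count * cu_num_mul

def balanced_contiguous_split(num_tiles: int, num_ctas: int, cta_id: int):
    # Balanced contiguous dispatch in the spirit of tD4balPersist: the
    # first ``rem`` CTAs take one extra tile, and every CTA owns one
    # contiguous run of tiles.
    base, rem = divmod(num_tiles, num_ctas)
    count = base + (1 if cta_id < rem else 0)
    start = cta_id * base + min(cta_id, rem)
    return start, count

# On the PR's shape: M=49152, tile_m=64 -> 768 M tiles. Assuming 256 CUs
# on MI355X, cu_num_mul=1 gives 3 tiles per CTA; cu_num_mul=3 gives 1.
m_tiles = 49152 // 64               # 768
grid1 = persistent_grid_y(256, 1)   # 256 CTAs
grid3 = persistent_grid_y(256, 3)   # 768 CTAs
```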
aiter/fused_moe.py
- `_flydsl_stage2_wrapper`: forward waves_per_eu, use_async_copy,
cu_num_mul from the parsed kernel-name config to flydsl_moe_stage2.
## Performance
Benchmark harness:
/home/mingzliu/sgl_opt/ep4_mxfp4_opt/coopt_loop/scripts/bench.sh
Workload:
/home/mingzliu/sgl_opt/test_fused_moe_ep4_mxfp4.py
(M=49152, model_dim=7168, inter_dim=2048, 64 local experts, topk=8,
QuantType.per_1x32, fp4x2 weights & activations, bf16 out)
Metric: 5-run median latency, 3 reps, MI355X (smci355-ccs-aus-n06-21).
baseline (kernel=_atomic_persist): 3438.8 us (3-rep avg)
this PR (kernel=_atomic_persist_async_w4_cumul3): 3173.6 us (3-rep avg)
delta: -265 us (-7.71%)
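
The reported delta follows directly from the two medians:

```python
baseline_us = 3438.8   # _atomic_persist, 3-rep avg
optimized_us = 3173.6  # _atomic_persist_async_w4_cumul3, 3-rep avg
delta_us = optimized_us - baseline_us        # -265.2 us
delta_pct = 100.0 * delta_us / baseline_us   # about -7.71 %
```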
Pull request overview
Optimizes the FlyDSL MoE stage2 (down-projection) path for MXFP4×MXFP4 EP prefill by introducing a production-tuned persistent+async kernel variant and reducing redundant memory traffic, with new tunable compilation parameters plumbed through the fused MoE wrapper.
Changes:
- Add a new "production" fp4×fp4 stage2 kernel config (`*_persist_async_w4_cumul3`) and plumb `waves_per_eu`, `use_async_copy`, and `cu_num_mul` through stage2 compilation/dispatch.
- Remove the defensive `out.fill_(0)` in `flydsl_moe_stage2()` for atomic mode (the caller must provide a zeroed buffer when passing `out`).
- Substantial stage2 kernel scheduling/memory-pipeline changes (async X DMA, persistent grid expansion, LDS staging, buffer-atomic fast path, etc.) to improve overlap and latency.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| `aiter/ops/flydsl/moe_kernels.py` | Registers the new stage2 kernel config and plumbs new stage2 compilation knobs; updates stage2 output buffer initialization behavior. |
| `aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py` | Implements the stage2 kernel-side optimizations and adds new compilation parameters (waves_per_eu, async copy, CU multiplier). |
| `aiter/fused_moe.py` | Passes newly added stage2 tuning parameters from parsed kernel configs into `flydsl_moe_stage2()`. |
```python
# NOTE: when ``accumulate=True`` (atomic mode), the caller is responsible
# for ensuring ``out`` is zero-initialized. In the standard ``fused_moe``
# dispatch path this is handled by ``moe_sorting_*_fwd`` which already
# zeros ``moe_buf`` via ``moe_buf_set_zero_kernel_2d``, so an extra
# ``out.fill_(0)`` here would be a redundant ~``token_num * model_dim``
# HBM write (~130us per call at MI355X HBM bw on EP4 prefill shape).
```
```python
def hot_loop_scheduler():
    # - MFMA group size per "slot": num_acc_n
    # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n
    # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration.
    mfma_group = num_acc_n
    mfma_total = (k_unroll * 2) * m_repeat * mfma_group
    mfma_per_iter = 2 * mfma_group
    sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter)

    rocdl.sched_dsrd(2)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    if num_acc_n < 4:
        rocdl.sched_dsrd(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_dsrd(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_mfma(1)

    # DS-write hints (match total A LDS-store micro-ops/thread):
    # async path's raw_ptr_buffer_load_lds writes LDS inside the
    # buffer-load, so there is no register-mediated ds_write.
    # Legacy dswr_tail=num_x_loads overcounts; zero it for async.
    if use_async_copy:
        dswr_tail = 0
    else:
        dswr_tail = num_x_loads
    if dswr_tail > sche_iters:
        dswr_tail = sche_iters
    dswr_start = sche_iters - dswr_tail

    for sche_i in range_constexpr(sche_iters):
        rocdl.sched_vmem(1)
        rocdl.sched_mfma(mfma_group)
        rocdl.sched_dsrd(1)
        rocdl.sched_mfma(mfma_group)
        if dswr_tail > 0 and sche_i >= dswr_start - 1:
            rocdl.sched_dswr(1)

    rocdl.sched_barrier(0)
```
```python
    waves_per_eu:
    - None (default): let the backend pick.
    - 1..4: hint the ROCm backend with ``rocdl.waves_per_eu`` to cap or
```
```python
if base_name == "flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic":
    # Production fp4xfp4 stage2 variant for the EP4
    # DeepSeek prefill shape on MI355X. Adds:
    # - use_async_copy=True (async X DMA in prologue
    #   overlaps with B/scale VMEM)
    # - cu_num_mul=3 (persistent grid 3x CU count
    #   to fill in-flight slack from small per-WG
    #   M tile counts; cu_num_mul=4 regresses ~2.4%
    #   on the same shape)
    # - waves_per_eu=4 (best on EP4 prefill at
    #   cu_num_mul=3; wpe=5/6 underperform here)
    kernels[f"{base_name}_persist_async_w4_cumul3"] = {
        **base_params,
        "persist": True,
        "use_async_copy": True,
        "waves_per_eu": 4,
        "cu_num_mul": 3,
    }
```
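
As an illustration only (the real lookup in aiter goes through the registered `kernels` dict, and `parse_variant_suffix` is our hypothetical helper), the suffix tokens of such a variant name can be read off like this:

```python
def parse_variant_suffix(name: str) -> dict:
    # Illustrative decoding of the variant-name suffix: each token maps
    # to one of the new compilation knobs described in this PR.
    params = {}
    if "_persist" in name:
        params["persist"] = True
    if "_async" in name:
        params["use_async_copy"] = True
    if "_w4" in name:
        params["waves_per_eu"] = 4
    if "_cumul3" in name:
        params["cu_num_mul"] = 3
    return params

cfg = parse_variant_suffix(
    "flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3"
)
```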
```
# IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**.
# Otherwise the 2-tile tail below would double-count the last tile when num_tiles is odd
# (e.g. inter_dim=192, tile_k=64 -> 3 tiles).
num_k_tiles_py = int(inter_dim) // int(tile_k)
num_k_tiles_py = _num_k_tiles_per_batch
odd_k_tiles = (num_k_tiles_py % 2) == 1
tail_tiles = 1 if odd_k_tiles else 2
k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k)
if const_expr(k_main2_py < 0):
if k_main2_py < 0:
```
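
The tail handling quoted above can be checked in isolation; `k_tail_split` is our name for this sketch, not a function in the kernel:

```python
def k_tail_split(inter_dim: int, tile_k: int):
    # Odd tile counts leave 1 tail tile, even counts leave 2, so the
    # 2-tile tail loop never double-counts the last tile.
    num_k_tiles = inter_dim // tile_k
    tail_tiles = 1 if num_k_tiles % 2 == 1 else 2
    k_main = max((num_k_tiles - tail_tiles) * tile_k, 0)
    return num_k_tiles, tail_tiles, k_main

# inter_dim=192, tile_k=64 -> 3 tiles, 1 tail tile (the odd case from
# the comment above); inter_dim=2048, tile_k=256 -> 8 tiles, 2 tail tiles.
```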
(force-pushed 2f0cdab → 70fd554)
The strict flydsl in current aiter-main rejects bare `if X:` / `elif X:`
over Python-level booleans inside a kernel emit body -- it lowers them
into `scf.if` so any Python variable assigned in either branch is
trapped inside the scf.if scope and the outer reference raises
`NameError: name '<X>' is not defined`.
`compile_mixed_moe_gemm1` already wraps every such condition with
`const_expr(...)`. The stage2 emit body in the parent commit had ~98
bare-if conditions over compile-time identifiers (`accumulate`,
`_persistent`, `doweight_stage2`, `_b_split_enabled`, `is_f4_a`,
`use_async_copy`, `_r139_xdma_first`, `_r216_defer_tid`,
`_use_buf_atomic_pre`, ...). Wrap them with `const_expr(...)` so the
condition is evaluated at Python trace time and the variable binding
escapes to the enclosing scope, matching the stage1 convention.
No semantic change: only adds `const_expr(...)` around if/elif heads
inside compile_mixed_moe_gemm2; DSL/runtime conditions (those whose
sub-expressions reference `arith.`, `buffer_ops.`, `rocdl.`, ...) are
left as bare `if` so they remain runtime `scf.if`s.
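
The scoping issue can be illustrated with a minimal stand-in for `const_expr` (ours, not the DSL's actual implementation, which operates on the trace rather than just coercing to `bool`):

```python
def const_expr(x):
    # Minimal stand-in: force the condition to a plain Python bool so
    # the ``if`` executes at Python trace time instead of lowering to
    # an scf.if region.
    return bool(x)

def emit_body(use_async_copy):
    # With a bare ``if`` the strict DSL lowers this branch to scf.if,
    # trapping _eff_lds_stride inside the scf.if scope; const_expr keeps
    # it a Python-level binding visible after the branch.
    if const_expr(use_async_copy):
        _eff_lds_stride = 128  # illustrative values only
    else:
        _eff_lds_stride = 64
    return _eff_lds_stride
```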
Repro (without this commit) on the aiter-main runtime:

```
File ".../mixed_moe_gemm_2stage.py", line 3096, in moe_gemm2
  shape_lds = fx.make_shape(tile_m, _eff_lds_stride)
NameError: name '_eff_lds_stride' is not defined
```
After this commit the kernel compiles, the end-to-end fused_moe perf
is unchanged from the parent commit on the legacy flydsl, and is
restored on aiter-main flydsl.
(force-pushed 70fd554 → 83ae8ad)
Re-flow a handful of long if-conditions, error message strings, and multi-arg call sites in `aiter/ops/flydsl/moe_kernels.py` and `aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py` so they pass the `psf/black@stable` check in `.github/workflows/pre-checks.yaml`. No semantic change -- pure whitespace/line-wrap. `ruff check` on the three changed files is also clean.
```
        **base_params,
        "persist": True,
    }
    if (
```

```
f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})"
)
```

```
# issue async X DMA early in prologue for max overlap with B/scale VMEM.
```
Does this work only for this tile config?
```
inter_dim_pad: int = 0,
persist_m: int = 4,
sort_block_m: int = 0,
b_nt: int = 2,
```
Why did you remove the `b_nt` and `xcd_swizzle` flags?
```
a2_scale=a2_scale,
sorted_weights=sorted_weights,
sort_block_m=parsed.get("sort_block_m", 0),
waves_per_eu=parsed.get("waves_per_eu", None),
```
Where are these three params used? I don't see the PR passing them; it seems the defaults are always used.
This is for the InferenceMax case. cc @Duyi-Wang @GLRocks
## Motivation

FlyDSL stage2 (down-projection) is on the hot path for DeepSeek-R1/V3 EP4 prefill on MI355X. The current `_atomic_persist` kernel doesn't overlap prologue VMEM with the K-loop, and `flydsl_moe_stage2()` redundantly memsets `out` even though `moe_sorting_*_fwd` already zeros `moe_buf` (~130us of HBM writes per call).

## Modifications

- Register `_t64x128x256_atomic_persist_async_w4_cumul3` (waves_per_eu=4, use_async_copy=True, cu_num_mul=3) and plumb the new params through `flydsl_moe_stage2` → `compile_mixed_moe_gemm2`.
- Remove `out.fill_(0)` from `flydsl_moe_stage2()`; the standard `fused_moe` path zeros `moe_buf` via `moe_buf_set_zero_kernel_2d`. The caller contract is documented in the wrapper.
- Kernel-side: `cu_num_mul=3` persistent grid expansion, asymmetric `b_lo`/`b_hi` split (3/4 + 1/4), LDS-staged `sorted_weights`, `disable_xdl_arb_stall` + carried `s_setprio(1)`, scales-before-B in the K-loop, deferred `lds_tid` prologue + merged `lds_tid`/`lds_tw` guards, buffer-atomic fast path (prologue precomputes per-row byte offset + OOB mask), balanced contiguous M-tile dispatch.

## Perf Tests
- Hardware: single AMD MI355X.
- Workload: `test_fused_moe_ep4_mxfp4.py` (M=49152, model_dim=7168, inter_dim=2048, E=64, topk=8, per_1x32 fp4x2 × fp4x2 → bf16).
- Metric: 5-run median × 3 reps, end-to-end `fused_moe` latency vs. `main`.
- Correctness: `test_flydsl_moe_a4w4.py` 12/12 PASS across tile_m ∈ {32, 64} × token ∈ {16, 64, 256} × {stage2, e2e}.

Please help review @lalala-sh. This has improved the InferenceMax E2E scenario. Thanks!