
perf(flydsl): MXFP4 fused-MoE stage2 optimization for EP prefill#3117

Open
inkcherry wants to merge 4 commits intoROCm:mainfrom
inkcherry:coopt/fusemoe-cleanup-v3-pr-ready

Conversation

@inkcherry
Contributor

@inkcherry inkcherry commented May 11, 2026

For the InferenceMax case.
cc @Duyi-Wang @GLRocks

Motivation

FlyDSL stage2 (down-projection) is on the hot path for DeepSeek-R1/V3 EP4 prefill on MI355X. The current _atomic_persist kernel doesn't overlap the prologue VMEM with the K-loop, and flydsl_moe_stage2() redundantly memsets `out` even though moe_sorting_*_fwd already zeros moe_buf (costing ~130us of HBM bandwidth per call).

Modifications

  • Add a production fp4×fp4 stage2 variant _t64x128x256_atomic_persist_async_w4_cumul3 (waves_per_eu=4, use_async_copy=True, cu_num_mul=3), and plumb the new params through flydsl_moe_stage2 -> compile_mixed_moe_gemm2.
  • Drop the defensive out.fill_(0) from flydsl_moe_stage2(); the standard fused_moe path zeros moe_buf via moe_buf_set_zero_kernel_2d. Caller contract documented in the wrapper.
  • Kernel-side passes (gated to the production tile + dtype): async X DMA in prologue, cu_num_mul=3 persistent grid expansion, asymmetric b_lo/b_hi split (3/4 + 1/4), LDS-staged sorted_weights, disable_xdl_arb_stall + carried s_setprio(1), scales-before-B in the K-loop, deferred lds_tid prologue + merged lds_tid/lds_tw guards, buffer-atomic fast path (prologue precomputes per-row byte offset + OOB mask), balanced contiguous M-tile dispatch.
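The cu_num_mul grid expansion and the balanced contiguous M-tile dispatch from the list above can be sketched in plain Python; the helper name, the even-split policy, and the CU count in the example are assumptions for illustration, not the kernel's literal dispatch code:

```python
# Hypothetical sketch: split M tiles over cu_count * cu_num_mul persistent
# CTAs in contiguous, near-equal ranges (the first `rem` CTAs get one
# extra tile). cu_num_mul > 1 launches more persistent CTAs that each
# cover fewer M tiles, increasing in-flight parallelism.

def persistent_tile_ranges(tiles_m: int, cu_count: int, cu_num_mul: int):
    """Return per-CTA [start, end) ranges over `tiles_m` contiguous M tiles."""
    n_ctas = cu_count * cu_num_mul
    base, rem = divmod(tiles_m, n_ctas)
    ranges, start = [], 0
    for cta in range(n_ctas):
        count = base + (1 if cta < rem else 0)
        ranges.append((start, start + count))
        start += count
    return ranges

# e.g. M=49152 with tile_m=64 gives 768 M tiles; assuming 256 CUs and
# cu_num_mul=3 there are 768 persistent CTAs, exactly one tile each.
```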

Perf Tests

Hardware: single AMD MI355X.
Workload: test_fused_moe_ep4_mxfp4.py, M=49152, model_dim=7168, inter_dim=2048, E=64, topk=8, per_1x32 fp4x2×fp4x2 → bf16.
Metric: 5-run median × 3 reps, end-to-end fused_moe latency.

| Source | rep1 | rep2 | rep3 | avg (us) | Δ |
| --- | --- | --- | --- | --- | --- |
| upstream main | 3709.9 | 3715.1 | 3710.4 | 3711.8 | |
| this PR | 3371.2 | 3377.0 | 3406.5 | 3384.9 | −8.81% |

Correctness: test_flydsl_moe_a4w4.py 12/12 PASS across tile_m∈{32,64} × token∈{16,64,256} × {stage2, e2e}.

Please help review, @lalala-sh. This improves the InferenceMax E2E scenario. Thanks!

Optimize the FlyDSL stage2 (down-projection) kernel for the production
fp4xfp4 EP4 DeepSeek prefill shape (tile_m=64, tile_n=128, tile_k=256,
M=49152) on MI355X.

## Summary of changes

aiter/ops/flydsl/moe_kernels.py
  - Add new wrapper params (waves_per_eu, use_async_copy, cu_num_mul) and
    plumb them through compile_flydsl_moe_stage2 -> compile_mixed_moe_gemm2.
  - Register a production fp4xfp4 stage2 variant
    `flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic_persist_async_w4_cumul3`
    with waves_per_eu=4, use_async_copy=True, cu_num_mul=3.
  - Drop the defensive `out.fill_(0)` added in PR ROCm#2863: when accumulate=True
    the standard fused_moe path already zeros moe_buf via
    `moe_buf_set_zero_kernel_2d` inside `moe_sorting_*_fwd`. The extra
    fill is a ~token_num*model_dim HBM write (~130 us per call at MI355X
    HBM bw on this shape). Callers using flydsl_moe_stage2 directly with
    accumulate=True remain responsible for zeroing `out` (documented).
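The caller contract above matters because atomic/accumulate mode adds partial expert results into `out` rather than overwriting it. A minimal plain-Python sketch of the failure mode (the function name and shapes are illustrative, not aiter's API):

```python
# Models the buffer-atomic adds stage2 performs per expert tile:
# each partial down-projection is accumulated into `out` with +=,
# so any stale data already in `out` leaks into the result.

def stage2_accumulate(out, partials):
    for p in partials:
        for i, v in enumerate(p):
            out[i] += v
    return out

partials = [[1.0] * 8, [2.0] * 8]
clean = stage2_accumulate([0.0] * 8, partials)  # correctly zeroed buffer
stale = stage2_accumulate([7.0] * 8, partials)  # caller forgot to zero
# clean[i] == 3.0 everywhere; stale[i] == 10.0 -> garbage output
```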

aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py
  - compile_mixed_moe_gemm2 signature: add waves_per_eu, use_async_copy,
    cu_num_mul. Drop unused legacy b_nt, xcd_swizzle on this path.
  - cu_num_mul: multiplies persistent-mode grid_y (CU count) by an integer
    factor; cu_num_mul>1 launches more persistent CTAs that each cover
    fewer M tiles, increasing in-flight parallelism. cu_num_mul=3 is the
    EP4-prefill sweet spot; cu_num_mul=4 regresses ~2.4% on the same shape.
  - use_async_copy: enable raw_ptr_buffer_load_lds for the X (activation)
    tile in the prologue so the X DMA overlaps with B/scale VMEM. Gated
    to the production tile/dtype combo so other configs are unaffected.
  - Asymmetric b_lo/b_hi split (b_lo loads 3/4 of K, b_hi loads 1/4): the
    smaller b_hi burst frees VMEM issue slots right before s_setprio(1)
    and the K-loop; the extra ku of b_lo is issued in the prologue where
    it overlaps with the X DMA.
  - Stage sorted topk weights into LDS once per M-tile and prefetch via
    vec4 LDS reads in the cshuffle epilogue, eliminating per-mi VMEM
    dwordx4 loads from the MFMA-heavy hot loop.
  - Emit AMDGPU `disable_xdl_arb_stall` + `s_setprio(1)` around the
    K-loop on gfx950 so the high-prio K-loop drains MFMAs back-to-back.
  - Issue scales BEFORE B-VMEM in the steady-state K-loop so the
    extract-time stalls in the first MFMA chunk are reduced.
  - Defer the lds_tid prologue (sorted_idx/sorted_w VMEM dwords + LDS
    stores) to AFTER the long-latency X DMA + B/scale loads, hiding the
    lds_tid scalar load latency behind the dominant VMEM phase.
  - Merge the lds_tid + lds_tw prologue store guards into a single
    scf.IfOp so MLIR can co-schedule the two VMEM buffer_loads.
  - For the buffer-atomic path: fast-path the cshuffle row token-index
    load (move the AND + multiply + validity-mask into the lds_tid
    preload site where it overlaps with B/scale/X VMEM loads).
  - Balanced persistent tile distribution (`tD4balPersist`) and balanced
    LDS-write hint accounting on the async path (`dswr_async`).
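The asymmetric b_lo/b_hi split described in the list above is simple index arithmetic over the K tile; a minimal sketch (the helper name and element-count framing are illustrative, not the kernel's code):

```python
# Hypothetical sketch of the 3/4 + 1/4 B-load split: b_lo covers three
# quarters of the K tile and is issued in the prologue (overlapping the
# async X DMA); the smaller b_hi burst is issued just before
# s_setprio(1) and the K-loop, freeing VMEM issue slots there.

def b_split(tile_k: int):
    assert tile_k % 4 == 0, "tile_k must split evenly into quarters"
    b_lo = (3 * tile_k) // 4
    b_hi = tile_k - b_lo
    return b_lo, b_hi

# tile_k=256 -> b_lo=192, b_hi=64
```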

aiter/fused_moe.py
  - `_flydsl_stage2_wrapper`: forward waves_per_eu, use_async_copy,
    cu_num_mul from the parsed kernel-name config to flydsl_moe_stage2.
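One way the wrapper can recover these knobs from a registered kernel-name suffix like `_async_w4_cumul3` is a small parse step. The regexes and defaults below are assumptions for illustration; the actual parsing lives in aiter/fused_moe.py:

```python
# Hypothetical parser for the stage2 kernel-name suffix conventions
# used in this PR (_async, _w<N>, _cumul<N>). Defaults mirror the
# "let the backend pick / no expansion" behavior described above.
import re

def parse_stage2_suffix(kernel_name: str) -> dict:
    cfg = {"use_async_copy": False, "waves_per_eu": None, "cu_num_mul": 1}
    if "_async" in kernel_name:
        cfg["use_async_copy"] = True
    m = re.search(r"_w(\d+)", kernel_name)   # waves_per_eu hint, e.g. _w4
    if m:
        cfg["waves_per_eu"] = int(m.group(1))
    m = re.search(r"_cumul(\d+)", kernel_name)  # persistent grid multiplier
    if m:
        cfg["cu_num_mul"] = int(m.group(1))
    return cfg
```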

## Performance

Benchmark harness:
  /home/mingzliu/sgl_opt/ep4_mxfp4_opt/coopt_loop/scripts/bench.sh
Workload:
  /home/mingzliu/sgl_opt/test_fused_moe_ep4_mxfp4.py
  (M=49152, model_dim=7168, inter_dim=2048, 64 local experts, topk=8,
   QuantType.per_1x32, fp4x2 weights & activations, bf16 out)
Metric: 5-run median latency, 3 reps, MI355X (smci355-ccs-aus-n06-21).

  baseline (kernel=_atomic_persist):                3438.8 us (3-rep avg)
  this PR  (kernel=_atomic_persist_async_w4_cumul3): 3173.6 us (3-rep avg)
  delta:                                             -265 us  (-7.71%)
@inkcherry inkcherry requested review from a team and Copilot May 11, 2026 05:43
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; the main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3117 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

Optimizes the FlyDSL MoE stage2 (down-projection) path for MXFP4×MXFP4 EP prefill by introducing a production-tuned persistent+async kernel variant and reducing redundant memory traffic, with new tunable compilation parameters plumbed through the fused MoE wrapper.

Changes:

  • Add a new “production” fp4×fp4 stage2 kernel config (*_persist_async_w4_cumul3) and plumb waves_per_eu, use_async_copy, and cu_num_mul through stage2 compilation/dispatch.
  • Remove the defensive out.fill_(0) in flydsl_moe_stage2() for atomic mode (caller must provide a zeroed buffer when passing out).
  • Substantial stage2 kernel scheduling/memory pipeline changes (async X DMA, persistent grid expansion, LDS staging, buffer-atomic fast path, etc.) to improve overlap and latency.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| aiter/ops/flydsl/moe_kernels.py | Registers the new stage2 kernel config and plumbs new stage2 compilation knobs; updates stage2 output buffer initialization behavior. |
| aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py | Implements the stage2 kernel-side optimizations and adds new compilation parameters (waves_per_eu, async copy, CU multiplier). |
| aiter/fused_moe.py | Passes newly added stage2 tuning parameters from parsed kernel configs into flydsl_moe_stage2(). |


Comment on lines +1098 to +1103
# NOTE: when ``accumulate=True`` (atomic mode), the caller is responsible
# for ensuring ``out`` is zero-initialized. In the standard ``fused_moe``
# dispatch path this is handled by ``moe_sorting_*_fwd`` which already
# zeros ``moe_buf`` via ``moe_buf_set_zero_kernel_2d``, so an extra
# ``out.fill_(0)`` here would be a redundant ~``token_num * model_dim``
# HBM write (~130us per call at MI355X HBM bw on EP4 prefill shape).
Comment on lines 4167 to 4217
def hot_loop_scheduler():
    # - MFMA group size per "slot": num_acc_n
    # - Total MFMA per tile: (2*K32 per K64) * k_unroll * m_repeat * num_acc_n
    # - We emit (mfma_group + dsrd + mfma_group) per scheduler iteration.
    mfma_group = num_acc_n
    mfma_total = (k_unroll * 2) * m_repeat * mfma_group
    mfma_per_iter = 2 * mfma_group
    sche_iters = 0 if mfma_per_iter == 0 else (mfma_total // mfma_per_iter)

    rocdl.sched_dsrd(2)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    if num_acc_n < 4:
        rocdl.sched_dsrd(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_dsrd(1)
    rocdl.sched_mfma(1)
    if tile_m == 16:
        rocdl.sched_vmem(1)
    rocdl.sched_mfma(1)

    # DS-write hints (match total A LDS-store micro-ops/thread):
    # async path's raw_ptr_buffer_load_lds writes LDS inside the
    # buffer-load, so there is no register-mediated ds_write.
    # Legacy dswr_tail=num_x_loads overcounts; zero it for async.
    if use_async_copy:
        dswr_tail = 0
    else:
        dswr_tail = num_x_loads
    if dswr_tail > sche_iters:
        dswr_tail = sche_iters
    dswr_start = sche_iters - dswr_tail

    for sche_i in range_constexpr(sche_iters):
        rocdl.sched_vmem(1)
        rocdl.sched_mfma(mfma_group)
        rocdl.sched_dsrd(1)
        rocdl.sched_mfma(mfma_group)
        if dswr_tail > 0 and sche_i >= dswr_start - 1:
            rocdl.sched_dswr(1)

    rocdl.sched_barrier(0)


waves_per_eu:
- None (default): let the backend pick.
- 1..4: hint the ROCm backend with ``rocdl.waves_per_eu`` to cap or
Comment thread aiter/ops/flydsl/moe_kernels.py Outdated
Comment on lines +189 to +206
if base_name == "flydsl_moe2_afp4_wfp4_bf16_t64x128x256_atomic":
    # Production fp4xfp4 stage2 variant for the EP4
    # DeepSeek prefill shape on MI355X. Adds:
    #   - use_async_copy=True (async X DMA in prologue
    #     overlaps with B/scale VMEM)
    #   - cu_num_mul=3 (persistent grid 3x CU count
    #     to fill in-flight slack from small per-WG
    #     M tile counts; cu_num_mul=4 regresses ~2.4%
    #     on the same shape)
    #   - waves_per_eu=4 (best on EP4 prefill at
    #     cu_num_mul=3; wpe=5/6 underperform here)
    kernels[f"{base_name}_persist_async_w4_cumul3"] = {
        **base_params,
        "persist": True,
        "use_async_copy": True,
        "waves_per_eu": 4,
        "cu_num_mul": 3,
    }
Comment on lines +4353 to +4360
# IMPORTANT: for odd number of K tiles, leave **1** tail tile; for even, leave **2**.
# Otherwise the 2-tile tail below would double-count the last tile when num_tiles is odd
# (e.g. inter_dim=192, tile_k=64 -> 3 tiles).
num_k_tiles_py = int(inter_dim) // int(tile_k)
num_k_tiles_py = _num_k_tiles_per_batch
odd_k_tiles = (num_k_tiles_py % 2) == 1
tail_tiles = 1 if odd_k_tiles else 2
k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k)
if const_expr(k_main2_py < 0):
if k_main2_py < 0:
@inkcherry inkcherry force-pushed the coopt/fusemoe-cleanup-v3-pr-ready branch from 2f0cdab to 70fd554 Compare May 11, 2026 06:21
The strict flydsl in current aiter-main rejects bare `if X:` / `elif X:`
over Python-level booleans inside a kernel emit body -- it lowers them
into `scf.if` so any Python variable assigned in either branch is
trapped inside the scf.if scope and the outer reference raises
`NameError: name '<X>' is not defined`.

`compile_mixed_moe_gemm1` already wraps every such condition with
`const_expr(...)`. The stage2 emit body in the parent commit had ~98
bare-if conditions over compile-time identifiers (`accumulate`,
`_persistent`, `doweight_stage2`, `_b_split_enabled`, `is_f4_a`,
`use_async_copy`, `_r139_xdma_first`, `_r216_defer_tid`,
`_use_buf_atomic_pre`, ...).  Wrap them with `const_expr(...)` so the
condition is evaluated at Python trace time and the variable binding
escapes to the enclosing scope, matching the stage1 convention.

No semantic change: only adds `const_expr(...)` around if/elif heads
inside compile_mixed_moe_gemm2; DSL/runtime conditions (those whose
sub-expressions reference `arith.`, `buffer_ops.`, `rocdl.`, ...) are
left as bare `if` so they remain runtime `scf.if`s.

Repro (without this commit) on aiter-main runtime:

  File ".../mixed_moe_gemm_2stage.py", line 3096, in moe_gemm2
      shape_lds = fx.make_shape(tile_m, _eff_lds_stride)
  NameError: name '_eff_lds_stride' is not defined

After this commit the kernel compiles, the end-to-end fused_moe perf
is unchanged from the parent commit on the legacy flydsl, and is
restored on aiter-main flydsl.
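The scoping pitfall this commit fixes can be modeled in plain Python. The toy tracer below is an assumption for illustration (flydsl's real lowering differs): a bare `if` becomes an scf.if-like child scope, so bindings made inside it never reach the enclosing frame, while `const_expr` resolves the branch at trace time so the assignment runs as ordinary Python:

```python
# Toy model of the bare-if vs const_expr behavior described above.

class ScfIf:
    """Runs a branch body in its own namespace, like a lowered scf.if."""
    def __init__(self, cond):
        self.cond = cond
        self.scope = {}  # bindings are trapped here, not in the caller

    def run(self, body):
        if self.cond:
            body(self.scope)  # writes go into the child scope only

def const_expr(cond):
    # Trace-time evaluation: returns a plain Python bool, so the normal
    # Python `if` executes the body in the current frame.
    return bool(cond)

def emit_bare_if(use_async_copy):
    branch = ScfIf(use_async_copy)
    branch.run(lambda scope: scope.__setitem__("_eff_lds_stride", 256))
    return _eff_lds_stride  # NameError: binding trapped in branch.scope

def emit_const_expr(use_async_copy):
    if const_expr(use_async_copy):
        _eff_lds_stride = 256
    else:
        _eff_lds_stride = 288
    return _eff_lds_stride  # fine: assignment happened in this frame
```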
@inkcherry inkcherry force-pushed the coopt/fusemoe-cleanup-v3-pr-ready branch from 70fd554 to 83ae8ad Compare May 11, 2026 06:25
inkcherry added 2 commits May 11, 2026 09:04
Re-flow a handful of long if-conditions, error message strings, and
multi-arg call sites in `aiter/ops/flydsl/moe_kernels.py` and
`aiter/ops/flydsl/kernels/mixed_moe_gemm_2stage.py` so they pass the
`psf/black@stable` check in `.github/workflows/pre-checks.yaml`.

No semantic change -- pure whitespace/line-wrap.  `ruff check` on the
three changed files is also clean.
**base_params,
"persist": True,
}
if (
Contributor


remove the hardcode...

f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})"
)

# issue async X DMA early in prologue for max overlap with B/scale VMEM.
Contributor


Does this work only for this tile config?

inter_dim_pad: int = 0,
persist_m: int = 4,
sort_block_m: int = 0,
b_nt: int = 2,
Contributor


Why did you remove the b_nt and xcd_swizzle flags?

Comment thread aiter/fused_moe.py
a2_scale=a2_scale,
sorted_weights=sorted_weights,
sort_block_m=parsed.get("sort_block_m", 0),
waves_per_eu=parsed.get("waves_per_eu", None),
Contributor


Where are these three params actually used? I don't see the PR passing them anywhere; it seems they always take the default values.

