[MoE] adapt to triton_kernels matmul_ogs -> matmul rename #763

Open

Liang-jianhao97 wants to merge 3 commits into main from jianlian/triton-kernels-matmul-rename

Conversation

Liang-jianhao97 commented May 12, 2026

Motivation

Fix the DeepSeek-V4 accuracy error tracked in https://github.com/ROCm/triton-internal/issues/1823.

Technical Details

Upstream triton_kernels merged the matmul_ogs module into matmul and the matmul_ogs_details package into matmul_details. The PrecisionConfig dataclass was also reshaped: weight_scale is now b_mx_scale, and setting it requires b_microblock_size to be provided explicitly (enforced by an assert in the new matmul()).

  • fused_moe_triton: try importing FnSpecs / FusedActivation / PrecisionConfig / matmul from triton_kernels.matmul first, then fall back to the old triton_kernels.matmul_ogs path. Alias matmul as matmul_ogs so existing call sites stay unchanged (see the import sketch after this list).
  • moe (Mxfp4MoEMethod.process_weights_after_loading): same dual-path import for FlexCtx / PrecisionConfig; detect the kwarg name via dataclasses.fields so the old weight_scale=... path keeps working while the new API takes b_mx_scale=... and b_microblock_size=....
  • Drop the _amd_smem_safe_tile workaround that pinned block_m / block_n on gfx950: the underlying LDS spill is no longer reproducible against current triton / triton_kernels.
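A minimal sketch of the dual-path import from the first bullet, assuming only the module layouts named above (the actual change lives in fused_moe_triton.py):

```python
try:
    # New layout: matmul_ogs was merged into triton_kernels.matmul.
    # Alias matmul as matmul_ogs so existing call sites stay unchanged.
    from triton_kernels.matmul import (
        FnSpecs,
        FusedActivation,
        PrecisionConfig,
        matmul as matmul_ogs,
    )
except ImportError:
    # Old layout: fall back to the pre-rename module.
    from triton_kernels.matmul_ogs import (
        FnSpecs,
        FusedActivation,
        PrecisionConfig,
        matmul_ogs,
    )
```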

Test Plan

ATOM Test

Test Result

PASS

Submission Checklist

jianlian and others added 3 commits May 12, 2026 04:41
Upstream triton_kernels merged the `matmul_ogs` module into `matmul`
and the `matmul_ogs_details` package into `matmul_details`. The
`PrecisionConfig` dataclass was also reshaped: `weight_scale` is now
`b_mx_scale`, and setting it requires `b_microblock_size` to be
provided explicitly (enforced by an assert in the new `matmul()`).

- fused_moe_triton: try importing `FnSpecs / FusedActivation /
  PrecisionConfig / matmul` from `triton_kernels.matmul` first, fall
  back to the old `triton_kernels.matmul_ogs` path. Alias `matmul as
  matmul_ogs` so existing call sites stay unchanged.
- moe (Mxfp4MoEMethod.process_weights_after_loading): same dual-path
  import for `FlexCtx / PrecisionConfig`; detect the kwarg name via
  `dataclasses.fields` so the old `weight_scale=` path keeps working
  while the new API takes `b_mx_scale=` + `b_microblock_size=`.
- Drop the `_amd_smem_safe_tile` workaround that pinned
  block_m / block_n on gfx950: the underlying LDS-spill is no longer
  reproducible against current triton / triton_kernels.
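Illustrative only: a hypothetical helper showing the dataclasses.fields-based kwarg detection described in the second bullet. The real logic sits inline in Mxfp4MoEMethod.process_weights_after_loading; the helper name and the bare two-kwarg construction are assumptions, while b_mx_scale / b_microblock_size / weight_scale are the field names from this commit message.

```python
import dataclasses


def make_precision_config(precision_config_cls, weight_scale, microblock_size):
    """Build a PrecisionConfig across both triton_kernels APIs (hypothetical helper)."""
    field_names = {f.name for f in dataclasses.fields(precision_config_cls)}
    if "b_mx_scale" in field_names:
        # New API: weight_scale was renamed to b_mx_scale, and the new matmul()
        # asserts that b_microblock_size is provided alongside it.
        return precision_config_cls(
            b_mx_scale=weight_scale, b_microblock_size=microblock_size
        )
    # Old API: the scale kwarg is still called weight_scale.
    return precision_config_cls(weight_scale=weight_scale)
```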

Co-authored-by: Cursor <cursoragent@cursor.com>
Pre Checkin's `psf/black@stable` flagged the `import triton_kernels.swiglu`
+ nested `try:` block inside the top-level guard. Add the required blank
line so Black treats the inner try as a new statement group.
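A rough sketch of the guarded import structure this refers to. The surrounding fallback code is hypothetical; only the blank line between the swiglu import and the nested try is the formatting change described here.

```python
try:
    import triton_kernels.swiglu  # noqa: F401

    try:  # blank line above is what psf/black@stable requires here
        from triton_kernels.matmul import matmul as matmul_ogs
    except ImportError:
        from triton_kernels.matmul_ogs import matmul_ogs
except ImportError:
    matmul_ogs = None  # hypothetical fallback when triton_kernels is unavailable
```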

Co-authored-by: Cursor <cursoragent@cursor.com>
Temporarily swap the DeepSeek-V4-Pro accuracy entry from the AITER MoE
env (AITER_BF16_FP8_MOE_BOUND=0 + ATOM_MOE_GU_ITLV=1) to the triton MoE
env (ATOM_USE_TRITON_MOE=1) so this PR's CI exercises the triton path
adapted in fused_moe_triton.py / moe.py. Revert before merge.

Co-authored-by: Cursor <cursoragent@cursor.com>