[MoE] adapt to triton_kernels matmul_ogs -> matmul rename#763
Open
Liang-jianhao97 wants to merge 3 commits into
Open
[MoE] adapt to triton_kernels matmul_ogs -> matmul rename#763Liang-jianhao97 wants to merge 3 commits into
Liang-jianhao97 wants to merge 3 commits into
Conversation
Upstream triton_kernels merged the `matmul_ogs` module into `matmul` and the `matmul_ogs_details` package into `matmul_details`. The `PrecisionConfig` dataclass was also reshaped: `weight_scale` is now `b_mx_scale`, and setting it requires `b_microblock_size` to be provided explicitly (enforced by an assert in the new `matmul()`). - fused_moe_triton: try importing `FnSpecs / FusedActivation / PrecisionConfig / matmul` from `triton_kernels.matmul` first, fall back to the old `triton_kernels.matmul_ogs` path. Alias `matmul as matmul_ogs` so existing call sites stay unchanged. - moe (Mxfp4MoEMethod.process_weights_after_loading): same dual-path import for `FlexCtx / PrecisionConfig`; detect the kwarg name via `dataclasses.fields` so the old `weight_scale=` path keeps working while the new API takes `b_mx_scale=` + `b_microblock_size=`. - Drop the `_amd_smem_safe_tile` workaround that pinned block_m / block_n on gfx950: the underlying LDS-spill is no longer reproducible against current triton / triton_kernels. Co-authored-by: Cursor <cursoragent@cursor.com>
Pre Checkin's `psf/black@stable` flagged the `import triton_kernels.swiglu` + nested `try:` block inside the top-level guard. Add the required blank line so Black treats the inner try as a new statement group. Co-authored-by: Cursor <cursoragent@cursor.com>
Temporarily swap the DeepSeek-V4-Pro accuracy entry from the AITER MoE env (AITER_BF16_FP8_MOE_BOUND=0 + ATOM_MOE_GU_ITLV=1) to the triton MoE env (ATOM_USE_TRITON_MOE=1) so this PR's CI exercises the triton path adapted in fused_moe_triton.py / moe.py. Revert before merge. Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Fix DeepseekV4 accuracy error. https://github.com/ROCm/triton-internal/issues/1823?reload=1?reload=1
Technical Details
Upstream triton_kernels merged the
matmul_ogsmodule intomatmuland thematmul_ogs_detailspackage intomatmul_details. ThePrecisionConfigdataclass was also reshaped:weight_scaleis nowb_mx_scale, and setting it requiresb_microblock_sizeto be provided explicitly (enforced by an assert in the newmatmul()).FnSpecs / FusedActivation / PrecisionConfig / matmulfromtriton_kernels.matmulfirst, fall back to the oldtriton_kernels.matmul_ogspath. Aliasmatmul as matmul_ogsso existing call sites stay unchanged.FlexCtx / PrecisionConfig; detect the kwarg name viadataclasses.fieldsso the oldweight_scale=...path keeps working while the new API takesb_mx_scale=...andb_microblock_size=...._amd_smem_safe_tileworkaround that pinned block_m / block_n on gfx950: the underlying LDS-spill is no longer reproducible against current triton / triton_kernels.Test Plan
ATOM Test
Test Result
PASS
Submission Checklist