[Kernel][gfx1250] Add FlyDSL MXScale FP8/A8W4 GEMM#3106
Open
Conversation
Contributor
🏷️ CI GuideRuns automatically on every PR:
Extended tests (opt-in via labels):
|
Contributor
There was a problem hiding this comment.
Pull request overview
Adds FlyDSL-backed MXScale dense GEMM support for gfx1250 (MXFP8 and A8W4) to AITER, including a public Python wrapper, host-side layout helpers, kernel code, and an AOT CSV parsing/compilation hook.
Changes:
- Introduces
flydsl_mxscale_gemmplus format-named wrappersgemm_mxfp8/gemm_mxa8w4, including kernel-name encode/parse utilities. - Adds host-side padding + preshuffle utilities for B and E8M0 scales, and vendors the gfx1250 MXScale GEMM kernel implementation.
- Adds a dedicated pytest suite and extends the FlyDSL GEMM AOT pipeline to recognize
flydsl_mxscale_*kernels.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| aiter/ops/flydsl/test_flydsl_mxscale_gemm.py | New unit + correctness tests for MXScale GEMM (gated to gfx1250 + flydsl). |
| aiter/ops/flydsl/mxscale_layout.py | Host-side padding and preshuffle helpers for MXScale A/B and E8M0 scales. |
| aiter/ops/flydsl/mxscale_gemm.py | Public wrapper API, kernel-name encode/parse, and runtime launch path. |
| aiter/ops/flydsl/kernels/pipeline_utils.py | Shared pipeline helper utilities used by the gfx1250 kernel. |
| aiter/ops/flydsl/kernels/gemm_fp8fp4_gfx1250.py | Vendored unified MXFP4/MXFP8/A8W4 gfx1250 kernel with MXScale support. |
| aiter/ops/flydsl/kernels/gemm_common_gfx1250.py | Shared gfx1250 GEMM helpers (LDS/pipeline/epilogue utilities). |
| aiter/aot/flydsl/gemm.py | Extends AOT CSV parsing/dispatch to recognize MXScale kernels and compile them. |
| aiter/init.py | Exposes the new optional FlyDSL-backed public entry points at top-level. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
4c9c06b to
53d5b74
Compare
Vendor the FlyDSL gfx1250 mxscale GEMM kernel into aiter and expose two
data-format-named public entries:
- aiter.gemm_mxfp8 MXFP8 (E4M3 + E8M0 1x32)
- aiter.gemm_mxa8w4 A8W4 (FP8 act, FP4 weight, E8M0 1x32)
aiter.flydsl_mxscale_gemm remains exported as the low-level entry that
pins to the FlyDSL backend and exposes all codegen knobs.
Files added:
- aiter/ops/flydsl/kernels/{gemm_fp8fp4_gfx1250,gemm_common_gfx1250,
pipeline_utils}.py vendored from FlyDSL main; only the two
`from kernels.X` imports are rewritten to
relative form.
- aiter/ops/flydsl/mxscale_layout.py host helper for pad +
E8M0(127) scale fill + B 16x16 preshuffle +
WMMA-friendly E8M0 scale preshuffle.
- aiter/ops/flydsl/mxscale_gemm.py public wrappers, kernelName
encode/parse, format-named entries, runtime
arch guard, lazy flydsl import.
- aiter/ops/flydsl/test_flydsl_mxscale_gemm.py unit tests, gated
on CUDA + flydsl + gfx1250.
Wires into AOT:
- aiter/aot/flydsl/gemm.py adds a `flydsl_mxscale_*` parser branch
and `_compile_mxscale_to_cache` for CSV-driven AOT precompilation;
mxscale-kind jobs hard-pin gfx1250 regardless of cu_num.
Public surface:
- aiter/__init__.py exports gemm_mxfp8, gemm_mxa8w4, and
flydsl_mxscale_gemm when flydsl is importable.
This path is intentionally independent from gemm_a8w8_blockscale and
gemm_a8w8_bpreshuffle: the OCP MX scale (E8M0 1x32) is not
interchangeable with the existing per-1x128/128x128 FP32 or PTPC FP32
scale layouts. Future Gluon/CK MXFP8 backends can land behind the
format-named entries without changing the call sites.
…n device/FP4 padding
cdfb6aa to
8eb18c5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Add MXFP8 and A8W4 dense GEMM support for gfx1250 to AITER.
Technical Details
compile_mxscale_gemmkernel intoaiter/ops/flydsl/kernels/.aiter.gemm_mxfp8— MXFP8 (E4M3 + E8M0 1×32)aiter.gemm_mxa8w4— A8W4 (FP8 act, FP4 weight, E8M0 1×32)aiter.flydsl_mxscale_gemm— low-level entry with all codegen knobs.aiter/aot/flydsl/gemm.py) for CSV-driven precompilation.Test Plan
Test Result
51 tests passed on gfx1250.
Submission Checklist