[Gluon] fused_mxfp4_quant for gfx1250 #3093
Open
amd-jrosas wants to merge 28 commits into main
…1250, replaced aiter kernels with torch in reference functions
This reverts commit 149446f.
* add test gluon kernel * add test gluon kernel * add gluon kernel * gluon ut pass * comment and formatting * comment * move gluon kernel to new file and create tdm pipeline version * update * update * update * non-tdm async copy UT passed, tdm asycn copy UT not passing * update * tmp * update * update * gfx950 compatible * update * always max size_per_thread x threads_per_warp along the fastest dim * redesign tdm desc and offsets * update * TDM random gather UA3D UT passed
Add 2d unified attention kernel in gluon, with async. and TDM support, supports gfx950 and gfx1250
* update * update UT * fix async_copy bug, add kv_cache_shuffle torch * update * add key_cache shuffling, fix gluon async_copy bug * tmp * update * tmp * translate kernel to new style * update baseline * add V_BLOCKED_LAYOUT * instr_shape update * add load_shared_relaxed to async kernel * update * add make dll * update * gluon key shuffle * update * add value cache shuffle support, but only for block_size > 64 * change name * change preshuffle load write logic * shuffle for async kernel * update * update * update * update * shuffle for tdm tmp * gfx1250 async shuffle fix * updates * tdm gather shuffle ut pass * skip ut if LDS requirement exceeds 320 kB * update * update test scripts * clean up * revert chip_info hack * update * update * clean up
Port changes from #2282. Adds default Triton kernel tuning configs for gfx1250 covering all GEMM variants and MHA (fwd + bwd).
…#2316) * Add gfx1250 arch enablement: fp8 support + test refactoring - Add gfx1250 to CDNA_ARCHS and FP8_ARCHS in flash attention utils - Refactor fp8 dtype selection in tests to use aiter.utility.dtypes.fp8 instead of hardcoded per-arch checks, enabling gfx1250 support - Note: gemm_config_utils left unchanged to preserve failure on missing gfx1250 configs (no fallback to gfx950) Port of #2284 with intentional deviation on gemm_config_utils. * Address comments * Fix pytest mark skipif * Address comments
Motivation
Create a Gluon version of fused_mxfp4_quant for gfx1250.
Technical Details
Adds a new Gluon kernel that uses TDM functions and updated logic for gfx1250. The Gluon kernel is written separately and lives in the _gluon_kernels/quant/ folder. The existing Triton API was updated to default to the Gluon kernel when gfx1250 is detected. The test file was also updated to verify the Gluon kernel.
Test Plan
Verified with the existing test_fused_mxfp4_quant.py; tests pass across various shapes and feature sets.
Test Result
All tests were successful with various features set to true/false. More details below.
Main features
Shape Details
*The 132 value had a single FP4 element mismatch of 0.25. It occurs with shuffle=True and dtype=float16 and appears to be a rounding-boundary artifact.
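To illustrate how a rounding boundary can produce a one-element mismatch of this size, here is a small sketch in plain Python. It is an assumption-laden illustration, not the kernel's rounding logic: the grid is the set of FP4 (E2M1) magnitudes, `quantize_e2m1` and its `ties_up` flag are hypothetical, and the input and scale are chosen only so that the scaled value lands exactly on a midpoint between two grid points.

```python
# Magnitudes representable in FP4 E2M1 (sign handled separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def quantize_e2m1(x: float, scale: float, ties_up: bool) -> float:
    """Round |x| / scale to the nearest E2M1 magnitude and dequantize.

    `ties_up` is a simplified, hypothetical tie rule: on an exact midpoint,
    pick the larger neighbor if True, the smaller one if False.
    """
    mag = min(abs(x) / scale, E2M1_GRID[-1])  # saturate at the largest magnitude
    best = E2M1_GRID[0]
    for g in E2M1_GRID[1:]:
        d_best = abs(mag - best)
        d_g = abs(mag - g)
        if d_g < d_best or (d_g == d_best and ties_up):
            best = g
    sign = -1.0 if x < 0 else 1.0
    return sign * best * scale


# 0.625 / 0.5 = 1.25, exactly halfway between the grid points 1.0 and 1.5.
low = quantize_e2m1(0.625, scale=0.5, ties_up=False)
high = quantize_e2m1(0.625, scale=0.5, ties_up=True)
```

In this sketch the two tie rules give dequantized values that differ by one grid step times the scale, so two implementations that disagree only at exact midpoints, or that see slightly different inputs because of float16 precision, could differ in a single element.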
Submission Checklist