Add dynamic MXFP4 quantization kernel and tests #1299
Draft
suryajasper wants to merge 1 commit into
Implement a Wave kernel for dynamic MXFP4 (E2M1) quantization of f32 activations with E8M0 block scales, following the algorithm from ROCm/aiter's dynamic_mxfp4_quant.

The pipeline is split into three stages:

- Scale computation (compute_mxfp4_scales): per-32-element block amax via IEEE 754 bit manipulation, producing E8M0 scale bytes and a broadcasted quant_scale tensor. Runs as pure PyTorch on GPU since the 32-element reduction is incompatible with AMD wave64 (which requires >=64 elements for wave-level reduce).
- FP4 encoding (Wave kernel): comparison-based E2M1 encoding using seven threshold comparisons and select, entirely elementwise.
- Nibble packing (pack_mxfp4_codes): pairs of i8 codes into uint8.

Also adds f32_to_mxfp4 PyTorch reference in mxfp_utils.py for validation.

Tests:

- Lit test: FileCheck verifying key MLIR ops (absf, cmpf, select, fptosi) in the compiled kernel.
- Unit pytest: 4 parametrized shapes comparing Wave kernel output against f32_to_mxfp4 reference with dequantize round-trip.
- E2E pytest: Wave quant -> MXFP4 GEMM validated against the full PyTorch reference pipeline (f32_to_mxfp4 + torchScaledGemmMXFP4).

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
Made-with: Cursor
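For readers unfamiliar with the format, the three stages can be sketched in plain NumPy. This is an illustrative sketch, not the code in this PR: every function name ending in `_sketch` is hypothetical, NumPy stands in for PyTorch and Wave, the scale uses the plain floor(log2(amax)) - 2 convention with no extra rounding of amax, ties at a threshold round toward zero, and the even-indexed code is assumed to go in the low nibble. The actual kernel and the aiter reference may differ on each of these points. NaN and infinity inputs are not handled.

```python
import numpy as np

BLOCK = 32  # elements sharing one E8M0 scale
# E2M1 magnitudes by 3-bit code, and the seven midpoints between neighbours.
E2M1_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
THRESHOLDS = np.array([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], dtype=np.float32)


def compute_scales_sketch(x):
    """Return (E8M0 scale bytes, per-element quant_scale) for x of shape (M, K)."""
    m, k = x.shape
    amax = np.abs(x).reshape(m, k // BLOCK, BLOCK).max(axis=-1)
    # Biased f32 exponent of amax, read straight from the IEEE 754 bits.
    exp_biased = (amax.astype(np.float32).view(np.uint32) >> 23) & 0xFF
    # Subtract E2M1's largest exponent (2), then clamp to the valid E8M0 range.
    scale_byte = np.clip(exp_biased.astype(np.int32) - 2, 0, 254).astype(np.uint8)
    quant_scale = np.exp2(127.0 - scale_byte.astype(np.float32)).astype(np.float32)
    return scale_byte, np.repeat(quant_scale, BLOCK, axis=-1)


def encode_e2m1_sketch(x_scaled):
    """Map pre-scaled f32 values to 4-bit codes: sign in bit 3, magnitude in bits 0-2."""
    mag = np.abs(x_scaled)
    # Each threshold exceeded bumps the code by one, which is equivalent to a
    # chain of seven compare-and-select steps. Values above 6 saturate to code 7.
    code = (mag[..., None] > THRESHOLDS).sum(axis=-1).astype(np.uint8)
    sign = np.signbit(x_scaled).astype(np.uint8) << 3
    return code | sign


def pack_nibbles_sketch(codes):
    """Pack adjacent code pairs into one byte (assumed order: even index low)."""
    return (codes[..., 0::2] | (codes[..., 1::2] << 4)).astype(np.uint8)


def dequantize_sketch(packed, scale_byte):
    """Invert the packing and scaling to recover f32 values of shape (M, K)."""
    lo, hi = packed & 0xF, packed >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(packed.shape[0], -1)
    vals = E2M1_VALUES[codes & 0x7] * np.where(codes & 0x8, -1.0, 1.0)
    scale = np.exp2(scale_byte.astype(np.float32) - 127.0)
    return (vals * np.repeat(scale, BLOCK, axis=-1)).astype(np.float32)


# One row holding a single 32-element block of exactly representable values.
x = np.array([[6.0, -3.0, 1.0, 0.0] * 8], dtype=np.float32)
scale_byte, quant_scale = compute_scales_sketch(x)
packed = pack_nibbles_sketch(encode_e2m1_sketch(x * quant_scale))
roundtrip = dequantize_sketch(packed, scale_byte)
```

In this example the block amax is 6.0, so the scale is 2^0 and the inputs survive the round trip unchanged. Only the encoding step is elementwise; the scale step needs a reduction over each 32-element block, which is the part the description keeps outside the Wave kernel.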