Add dynamic MXFP4 quantization kernel and tests #1299

Draft
suryajasper wants to merge 1 commit into iree-org:main from suryajasper:mxfp4-dynamic-quantization

Conversation

@suryajasper
Contributor

Implement a Wave kernel for dynamic MXFP4 (E2M1) quantization of f32 activations with E8M0 block scales, following the algorithm from ROCm/aiter's dynamic_mxfp4_quant.

The pipeline is split into three stages:

  • Scale computation (compute_mxfp4_scales): per-32-element block amax via IEEE 754 bit manipulation, producing E8M0 scale bytes and a broadcasted quant_scale tensor. Runs as pure PyTorch on GPU since the 32-element reduction is incompatible with AMD wave64 (which requires >=64 elements for wave-level reduce).
  • FP4 encoding (Wave kernel): comparison-based E2M1 encoding using seven threshold comparisons and select, entirely elementwise.
  • Nibble packing (pack_mxfp4_codes): packs each pair of 4-bit codes (held as i8) into a single uint8 byte.
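
To make the three stages concrete, here is a minimal pure-Python sketch of the same data flow. It is not the code in this PR: the helper names (`block_scale`, `encode_e2m1`, `pack_codes`, `quantize`) are hypothetical, the scale rule is the plain OCP MX one (floor of log2(amax) minus the E2M1 maximum exponent, with no extra rounding of amax), ties are assumed to round to the even code, and the first element of each pair is assumed to land in the low nibble. The actual kernel and aiter's dynamic_mxfp4_quant may differ on any of these points.

```python
import struct

BLOCK = 32        # elements sharing one E8M0 scale
E2M1_EMAX = 2     # largest E2M1 exponent: 6.0 = 1.5 * 2**2

# Midpoints between adjacent E2M1 magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6.
# The flag marks thresholds where a tie rounds up (assumed: ties go to the even code).
THRESHOLDS = [(0.25, False), (0.75, True), (1.25, False), (1.75, True),
              (2.5, False), (3.5, True), (5.0, False)]


def f32_bits(x):
    """Reinterpret a float as its IEEE 754 binary32 bit pattern."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def block_scale(block):
    """Return (E8M0 byte, quant_scale) for one block, using the f32 exponent field."""
    amax = max(abs(v) for v in block)
    exp_field = (f32_bits(amax) >> 23) & 0xFF       # biased f32 exponent of amax
    unbiased = exp_field - 127 - E2M1_EMAX          # floor(log2(amax)) - emax
    unbiased = max(-127, min(127, unbiased))        # keep the biased byte in 0..254
    return unbiased + 127, 2.0 ** (-unbiased)


def encode_e2m1(x):
    """Encode one pre-scaled value as a 4-bit E2M1 code via seven comparisons."""
    mag = abs(x)
    code = 0
    for threshold, tie_up in THRESHOLDS:
        code += (mag >= threshold) if tie_up else (mag > threshold)
    return (8 if x < 0 else 0) | code               # bit 3 is the sign


def pack_codes(codes):
    """Pack pairs of 4-bit codes into bytes (assumed: first code in the low nibble)."""
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def quantize(values):
    """Quantize a flat list of floats; returns (E8M0 scale bytes, packed FP4 bytes)."""
    assert len(values) % BLOCK == 0
    scales, codes = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        e8m0, quant_scale = block_scale(block)
        scales.append(e8m0)
        codes.extend(encode_e2m1(v * quant_scale) for v in block)
    return bytes(scales), pack_codes(codes)


scales, packed = quantize([6.0, -3.0, 0.5, 0.0] + [0.0] * 28)
print(scales.hex(), packed[:2].hex())
```

Only the middle function is elementwise, which is why that stage maps cleanly onto a Wave kernel while the per-block maximum is handled separately.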

Also adds f32_to_mxfp4 PyTorch reference in mxfp_utils.py for validation.

Tests:

  • Lit test: FileCheck verifying key MLIR ops (absf, cmpf, select, fptosi) in the compiled kernel.
  • Unit pytest: 4 parametrized shapes comparing Wave kernel output against f32_to_mxfp4 reference with dequantize round-trip.
  • E2E pytest: Wave quant -> MXFP4 GEMM validated against the full PyTorch reference pipeline (f32_to_mxfp4 + torchScaledGemmMXFP4).
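
For readers unfamiliar with the dequantize round-trip used in the unit test, the check amounts to decoding both outputs back to floats and comparing them. The sketch below shows only the decode side in plain Python; `decode_e2m1` and `dequantize` are hypothetical names rather than functions from this PR, the low nibble is assumed to hold the first element of each pair, and the E8M0 NaN encoding (0xFF) is not handled.

```python
BLOCK = 32
# Magnitudes representable in E2M1, indexed by the low three bits of the code.
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code):
    """Decode a 4-bit E2M1 code; bit 3 is the sign."""
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


def dequantize(packed, scales):
    """Expand packed FP4 bytes and per-block E8M0 scale bytes back to floats."""
    out = []
    for i, byte in enumerate(packed):
        # Each byte holds two elements, so byte i starts at element 2 * i.
        scale = 2.0 ** (scales[(2 * i) // BLOCK] - 127)
        out.append(decode_e2m1(byte & 0xF) * scale)   # low nibble first (assumed)
        out.append(decode_e2m1(byte >> 4) * scale)
    return out


packed = bytes([0xD7, 0x01] + [0] * 14)   # one 32-element block
scales = bytes([128])                     # block scale 2**(128 - 127) = 2
print(dequantize(packed, scales)[:4])
```

Comparing dequantized values rather than raw bytes keeps such a test meaningful even when two implementations disagree on the sign bit of a zero.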

Signed-off-by: Surya Jasper <45545431+suryajasper@users.noreply.github.com>
Made-with: Cursor