Skip to content

[MXFP4 KV Cache] MXFP4 KV Cache: support blockfp4-hadamard-quant#25555

Open
DehuaTang wants to merge 2 commits into
sgl-project:mainfrom
DehuaTang:feature/blockfp4-hadamard-quant
Open

[MXFP4 KV Cache] MXFP4 KV Cache: support blockfp4-hadamard-quant#25555
DehuaTang wants to merge 2 commits into
sgl-project:mainfrom
DehuaTang:feature/blockfp4-hadamard-quant

Conversation

@DehuaTang
Copy link
Copy Markdown

@DehuaTang DehuaTang commented May 18, 2026

Motivation

Current MXFP4 KV cache support is still limited compared with NVFP4. However, MXFP4 can be useful across a broader range of hardware, so this PR adds an initial accuracy-improvement path for BlockFP4/
MXFP4-style KV cache quantization.

The implementation is inspired by the Quest/MXFP4 quantization flow in https://github.com/IST-DASLab/qutlass. This PR focuses on the first functional version in Python/Torch. Follow-up PRs can add Triton acceleration kernels and E2E accuracy.

Modifications

  • Add a Hadamard-rotated BlockFP4 KV cache quantization path with block size 32.
  • Align scale handling and FP4 round-to-nearest-even behavior with the MXFP4 reference flow.
  • Update BlockFP4 unit tests, including CUDA roundtrip coverage.

Test Plan

  • pytest test/registered/unit/layers/quantization/test_fp4_kv_cache_quant_method.py -v

CI States

Latest PR Test (Base): ❌ Missing run-ci label — add it to run CI tests.
Latest PR Test (Extra): ❌ Blockedrun-ci is required first.

@github-actions github-actions Bot added the quant LLM Quantization label May 18, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request updates the FP4 KV cache quantization implementation to support variable block sizes, increasing the default size to 32, and introduces an optional Hadamard rotation feature. The quantization logic was also refined to handle midpoint rounding and ensure device compatibility for constants. Feedback suggests updating the class docstring to reflect the new block size, restoring or clarifying the removal of the @torch.compile decorator on the quantization function, and replacing a magic number used in the Hadamard scaling logic with a named constant for better maintainability.


name = "blockfp4"
SCALE_BLOCK_SIZE = 16
SCALE_BLOCK_SIZE = 32
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The default SCALE_BLOCK_SIZE has been updated to 32, but the class docstring (line 309) still mentions block_size=16. Please update the docstring to reflect this change and maintain consistency with the MXFP4 specification.

Comment on lines +71 to +75
def batched_quantize(
tensor: torch.Tensor,
block_size: Optional[int] = None,
hadamard_rotate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The @torch.compile decorator was removed from batched_quantize but remains on batched_dequantize (line 135). If torch.compile is compatible with the new logic and the Hadamard JIT kernel, it should be restored to maintain performance during the quantization step. If there are compatibility issues, consider removing it from batched_dequantize as well for consistency.

    @torch.compile
    def batched_quantize(
        tensor: torch.Tensor,
        block_size: Optional[int] = None,
        hadamard_rotate: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:

reshaped = hadamard_transform(reshaped, scale=block_size**-0.5)
block_std = reshaped.std(dim=-1, correction=0, keepdim=True)
scale_exp = torch.floor(
torch.log2(block_std * (2.92247856 / E2M1_MAX) + 1e-8)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The magic number 2.92247856 should be defined as a named constant (e.g., GAUSSIAN_SCALE_FACTOR) at the module level. This value likely represents the scaling factor for a Gaussian distribution (e.g., norm.ppf(0.99825)) to capture the range for FP4 quantization after Hadamard rotation. Adding a comment explaining its origin would improve maintainability.

@DehuaTang DehuaTang changed the title [MXFP4] MXFP4 KV Cache: support blockfp4-hadamard-quant [MXFP4 KV Cache] MXFP4 KV Cache: support blockfp4-hadamard-quant May 18, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

quant LLM Quantization

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant