[MXFP4 KV Cache] MXFP4 KV Cache: support blockfp4-hadamard-quant #25555
DehuaTang wants to merge 2 commits into
Code Review
This pull request updates the FP4 KV cache quantization implementation to support variable block sizes, increasing the default size to 32, and introduces an optional Hadamard rotation feature. The quantization logic was also refined to handle midpoint rounding and ensure device compatibility for constants. Feedback suggests updating the class docstring to reflect the new block size, restoring or clarifying the removal of the @torch.compile decorator on the quantization function, and replacing a magic number used in the Hadamard scaling logic with a named constant for better maintainability.
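To make the variable block size concrete, here is a minimal pure-Python sketch of block-wise FP4 (E2M1) quantization with one power-of-two scale per block. It is not the PR's implementation: `quantize_block`, `quantize`, and `dequantize` are hypothetical helpers, the scale is derived from the block's absolute maximum (an assumption), and the tie handling shown is only what `min()` happens to do, not the midpoint rounding the PR implements.

```python
import math

# E2M1 (FP4) magnitudes: 1 sign bit, 2 exponent bits, 1 mantissa bit.
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_MAX = 6.0


def quantize_block(block):
    """Quantize one block to (power-of-two scale exponent, E2M1 values)."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # Assumption: smallest power-of-two scale with amax / scale <= E2M1_MAX.
    scale_exp = math.ceil(math.log2(amax / E2M1_MAX))
    scale = 2.0 ** scale_exp
    values = []
    for x in block:
        mag = min(abs(x) / scale, E2M1_MAX)
        # Nearest grid point; on an exact tie min() keeps the smaller value.
        q = min(E2M1_GRID, key=lambda g: abs(g - mag))
        values.append(math.copysign(q, x))
    return scale_exp, values


def quantize(values, block_size=32):
    """Split a flat list into blocks of block_size and quantize each one."""
    if len(values) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    return [
        quantize_block(values[i:i + block_size])
        for i in range(0, len(values), block_size)
    ]


def dequantize(blocks):
    """Rebuild approximate values from (scale exponent, E2M1 values) pairs."""
    return [q * 2.0 ** exp for exp, qs in blocks for q in qs]
```

With 64 input values, `quantize(values)` yields two blocks at the new default of 32, while `quantize(values, block_size=16)` yields four, so a larger block stores fewer scales at the cost of sharing each scale across more elements.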
name = "blockfp4"
- SCALE_BLOCK_SIZE = 16
+ SCALE_BLOCK_SIZE = 32
def batched_quantize(
    tensor: torch.Tensor,
    block_size: Optional[int] = None,
    hadamard_rotate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
The @torch.compile decorator was removed from batched_quantize but remains on batched_dequantize (line 135). If torch.compile is compatible with the new logic and the Hadamard JIT kernel, it should be restored to maintain performance during the quantization step. If there are compatibility issues, consider removing it from batched_dequantize as well for consistency.
@torch.compile
def batched_quantize(
    tensor: torch.Tensor,
    block_size: Optional[int] = None,
    hadamard_rotate: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:

reshaped = hadamard_transform(reshaped, scale=block_size**-0.5)
block_std = reshaped.std(dim=-1, correction=0, keepdim=True)
scale_exp = torch.floor(
    torch.log2(block_std * (2.92247856 / E2M1_MAX) + 1e-8)
The magic number 2.92247856 should be defined as a named constant (e.g., GAUSSIAN_SCALE_FACTOR) at the module level. This value likely represents the scaling factor for a Gaussian distribution (e.g., norm.ppf(0.99825)) to capture the range for FP4 quantization after Hadamard rotation. Adding a comment explaining its origin would improve maintainability.
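The suggestion can be sketched in pure Python as follows. This is an illustration under stated assumptions, not the PR's code: `GAUSSIAN_SCALE_FACTOR` is the hypothetical name proposed in the comment, `hadamard_transform` here is a plain-Python stand-in for the function called in the diff, `block_scale_exponent` is a hypothetical helper, and `E2M1_MAX` is assumed to be 6.0, the largest E2M1 magnitude.

```python
import math
from statistics import NormalDist

E2M1_MAX = 6.0  # assumption: largest magnitude representable in E2M1
# Hypothetical constant name from the review; the value is copied from the diff.
GAUSSIAN_SCALE_FACTOR = 2.92247856


def hadamard_transform(block, scale=1.0):
    """Fast Walsh-Hadamard transform of a power-of-two-length list (stand-in)."""
    n = len(block)
    if n == 0 or n & (n - 1):
        raise ValueError("block length must be a power of two")
    out = list(block)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    return [v * scale for v in out]


def block_scale_exponent(block):
    """Rotate one block, then derive a power-of-two scale exponent from its std."""
    n = len(block)
    # scale = n ** -0.5 makes the rotation orthonormal, so energy is preserved.
    rotated = hadamard_transform(block, scale=n ** -0.5)
    mean = sum(rotated) / n
    # Population standard deviation, matching correction=0 in the diff.
    std = math.sqrt(sum((v - mean) ** 2 for v in rotated) / n)
    return math.floor(math.log2(std * (GAUSSIAN_SCALE_FACTOR / E2M1_MAX) + 1e-8))


# Check of the reviewer's guess: which Gaussian quantile is the constant?
quantile = NormalDist().cdf(GAUSSIAN_SCALE_FACTOR)
```

Here `quantile` comes out close to 0.998, which is consistent with the reviewer's guess about the constant's origin but does not confirm it. A comment next to the named constant could record whichever derivation the author actually used.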
Motivation
Current MXFP4 KV cache support is still limited compared with NVFP4. However, MXFP4 can be useful across a broader range of hardware, so this PR adds an initial accuracy-improvement path for BlockFP4/MXFP4-style KV cache quantization.
The implementation is inspired by the Quest/MXFP4 quantization flow in https://github.com/IST-DASLab/qutlass. This PR focuses on a first functional version in Python/Torch. Follow-up PRs can add Triton kernels for acceleration and E2E accuracy evaluation.
Modifications
Test Plan
pytest test/registered/unit/layers/quantization/test_fp4_kv_cache_quant_method.py -v
CI States
Latest PR Test (Base): ❌ Missing run-ci label (add it to run CI tests).
Latest PR Test (Extra): ❌ Blocked (run-ci is required first).