Vortex quantization and topk kernel adaption #1

Open
zxr-creator wants to merge 21 commits into Infini-AI-Lab:v1 from zxr-creator:v1

Conversation

@zxr-creator

Add INT8 Quantization Support for Vortex

This PR adds INT8 quantization support to the Vortex sparse attention framework to reduce memory usage and enable low-precision execution.

Main Changes

Implement INT8 quantization with adjustments to improve memory utilization.

Add preliminary FP8 quantization support.

Update reduce_pp_kernel parameters to support quantized data.

Adapt the Top-K kernel from SGLang for Vortex.

Add a runtime parameter to switch between two Top-K kernels (naive and sglang).

Add RTX PRO 6000 compatibility and fix several Vortex kernel issues.
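The runtime switch between the two Top-K kernels can be sketched as below. This is a hypothetical illustration, not the PR's actual code: the names `topk_naive`, `topk_heap`, `select_topk`, and the `VTX_TOPK_KERNEL` environment variable are all made up, and `heapq.nlargest` stands in for the SGLang bucket-sort kernel.

```python
import heapq
import os

def topk_naive(scores, k):
    # Full sort of indices by score; simple but O(n log n).
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]

def topk_heap(scores, k):
    # Heap-based selection, standing in here for the SGLang kernel.
    return heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])

def select_topk(scores, k, kernel=None):
    # Pick the kernel from an explicit argument or a runtime parameter.
    kernel = kernel or os.environ.get("VTX_TOPK_KERNEL", "naive")
    return topk_heap(scores, k) if kernel == "sglang" else topk_naive(scores, k)
```

Both paths return the same indices for distinct scores, so the switch only trades off kernel latency, not results.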

zxr-creator and others added 21 commits February 22, 2026 23:59
Key changes:
1. Memory Pool (`vtx_graph_memory_pool.py`):
   - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations.
   - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout.
   - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical.
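A minimal sketch of that buffer layout, with hypothetical names (`Int8PagePool`, `slot`) rather than the PR's actual classes: an int8 page pool, a parallel float32 scale per token slot, and a full-precision shadow buffer for the sparse indexer.

```python
class Int8PagePool:
    """Illustrative paged-KV layout: int8 data, parallel fp32 scales,
    and a shadow buffer (standing in for k_bf16) for indexer metadata."""

    def __init__(self, num_pages, page_size, head_dim):
        n = num_pages * page_size
        self.page_size = page_size
        self.k_int8 = [[0] * head_dim for _ in range(n)]     # quantized keys
        self.k_scale = [1.0] * n                             # one fp32 scale per token
        self.k_shadow = [[0.0] * head_dim for _ in range(n)] # centroids etc. stay full precision

    def slot(self, page_id, offset):
        # Map (page, in-page offset) to a flat token slot, mirroring the paged layout.
        return page_id * self.page_size + offset
```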

2. Quantize-on-Write (`set_kv.py`):
   - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`).
   - Wired the new launcher into the cache update flow.
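The per-token absmax scheme described above can be written out in plain Python; this is a sketch of the arithmetic the Triton kernel performs per token row, not the kernel itself, and the function names are illustrative.

```python
def quantize_token_int8(values):
    """Per-token absmax INT8 quantization: scale = max(|x|) / 127,
    values rounded and clamped to [-127, 127]."""
    amax = max(abs(v) for v in values)
    scale = amax / 127.0 if amax > 0.0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale

def dequantize_token(q, scale):
    # Inverse mapping; error per element is at most one quantization step.
    return [qi * scale for qi in q]
```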

3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`):
   - Bypassed FlashInfer for INT8 decoding.
   - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers.
   - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`).
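Why fused dequantization needs no temporary buffer: with a per-token scalar scale, `dot(q, scale * k_int8) == scale * dot(q, k_int8)`, so the scale can be applied after the integer dot product. A scalar sketch (function name hypothetical):

```python
def attn_score_int8(q_vec, k_row_int8, k_scale):
    """Score against one int8 key row with the dequant fused into the
    accumulation; no dequantized copy of the row is materialized."""
    acc = 0.0
    for qv, kq in zip(q_vec, k_row_int8):
        acc += qv * kq
    return acc * k_scale
```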

4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`):
   - Implemented an OOM-safe `bf16` fallback for prefill.
   - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer.
   - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs.
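The prefill fallback above can be sketched as follows, assuming the per-token scale layout described earlier; the names and the dict-based page store are illustrative, not the PR's actual data structures.

```python
def dequant_accessed_pages(pages_int8, scales, accessed_ids):
    """OOM-safe sketch: dequantize only the pages the current batch
    touches into a small compacted buffer, never the whole cache."""
    compact = {}
    for pid in accessed_ids:
        page, row_scales = pages_int8[pid], scales[pid]
        # One scale per token row within the page.
        compact[pid] = [[q * row_scales[r] for q in row]
                        for r, row in enumerate(page)]
    return compact
```

The prefill wrapper is then planned over the keys of `compact` (the subset indices) instead of the full page table.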
Merge the wrapper modification with the v1
…gs and pages; fix the previous quantization implementation, with launch_graph dtype set to the quant type
… (naive sparse attention, flash sparse attention, flashmoba)
- Introduced a comprehensive benchmarking suite for TopK kernel variants, measuring kernel-level latency.
- Added scripts for offline calibration of TopK mapping modes, including:
  - 0: None — original fp16 bit-pattern bucketing
  - 1: LUT CDF — LUT-based CDF equalization (calibrated)
  - 2: Quantile — piecewise-linear quantile mapping (calibrated)
  - 3: Power — y = sign(x) * |x|^p
  - 4: Log — y = sign(x) * log(|x| + 1)
  - 5: Index Cache — reuse previous layer's indices
  - 6: Asinh — y = asinh(beta * x)
  - 7: Log1p — y = sign(x) * log1p(alpha * |x|)
  - 8: Trunc8 — bf16 upper-8-bit bucketing
- Added various remap functions for the bucket sort in the sglang topk kernel, with evaluation and visualization scripts.
- Implemented analysis tools for TopK distribution profiling.
- Removed outdated GPU architecture flags from setup.py.
- Added new mapping modes (Erf, Tanh, Subtract) to analyze_topk_distribution.py and bench_topk.py.
- Updated functions to handle new modes and added support for noscale parameters in autotune and benchmark scripts.
- Enhanced the TopK kernel with additional profiling metrics and improved handling of kernel arguments.
- Updated example scripts to reflect new modes and parameters for distribution analysis.
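The calibration-free remap curves from the mode list (Power, Log, Asinh, Log1p) can be sketched in plain Python; function names and default parameter values here are illustrative. Each curve is odd and strictly increasing, so remapping before the bucket sort reshapes the value distribution across buckets without changing which indices land in the top-k.

```python
import math

def remap_power(x, p=0.5):
    # Mode 3: y = sign(x) * |x|^p
    return math.copysign(abs(x) ** p, x)

def remap_log(x):
    # Mode 4: y = sign(x) * log(|x| + 1)
    return math.copysign(math.log(abs(x) + 1.0), x)

def remap_asinh(x, beta=1.0):
    # Mode 6: y = asinh(beta * x)
    return math.asinh(beta * x)

def remap_log1p(x, alpha=1.0):
    # Mode 7: y = sign(x) * log1p(alpha * |x|)
    return math.copysign(math.log1p(alpha * abs(x)), x)
```

Monotonicity is the correctness requirement: any strictly increasing remap preserves the ordering, hence the top-k result.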
…cripts to reflect changes in histogram calibration and TopK mapping parameters.
- Added ExpStretch and TopkWindow modes to analyze_topk_distribution.py and bench_topk.py.
- Introduced topk_output_sglang_ori function for original sglang kernel in vortex_torch_C.
- Updated autotune and benchmark scripts to include new modes and original kernel.
- Modified example scripts to reflect changes in histogram calibration and TopK mapping parameters.
- Added new  file for enhanced profiling capabilities.
- Updated  to include the new profiling source file.
- Modified  to expand the sweep grid for parameter tuning.
- Refactored  to improve handling of hyperparameters and added subprocess profiling for large TopK values.
- Enhanced  and  to support new parameters for profiling.
- Updated example scripts to reflect changes in TopK parameters and profiling options.
… and usability

- Updated autotune_topk_mapping.py to optimize hyperparameter tuning based on kernel latency.
- Simplified the sweep grid and improved documentation for usage.
- Enhanced bench_topk.py to expose public helpers and added CLI modes for benchmarking.
- Introduced new remap functions and improved kernel integration for profiling.
- Added watchdog timeout option in calibrate_topk.py for SGLang scheduler.
- Removed outdated greedy_layer_search.py as part of code cleanup.
…kernel in vortex_torch_C.

- Updated setup.py to include the new source file for the original kernel.
- Enhanced autotune_topk_mapping.py and bench_topk.py to support new mapping modes and original kernel integration.
- Expanded the sweep grid in autotune_topk_mapping.py for improved hyperparameter tuning.
- Added a new command-line argument in calibrate_topk.py for maximum total tokens to manage KV pool size.
- Removed outdated remap_function_bench.sh script as part of code cleanup.