Commit 35cfc33

GPUs + Flash-KMeans
Signed-off-by: Will Manning <will@willmanning.io>
1 parent fac7e16 commit 35cfc33

1 file changed

Lines changed: 125 additions & 0 deletions

proposed/0033-block-turboquant.md

@@ -591,6 +591,128 @@ New tests needed:
- Per-block norm consistency with input vectors
- Encode/decode benchmarks (see Performance analysis)

## Future work: GPU decode and fused distance computation

### Motivation

For ANN search workloads, the dominant operation is computing distances between
a query vector and millions of database vectors. On GPU, the goal is to perform
this computation directly on the compressed representation, avoiding the cost
of materializing full decompressed vectors in HBM. BlockTurboQuant's 64-dim
block structure maps naturally to GPU tile sizes and tensor core operations.

### Decode as GEMM

The decode path for a single 64-dim block is:

```
decoded_block = norm × R⁻¹ × codebook_lookup(codes)
```

For a batch of N vectors sharing the same block's rotation matrix R⁻¹:

```
decoded_batch = R⁻¹ × codebook_lookup_batch(codes) × diag(norms)
                ↑ 64×64 × 64×N = GEMM                ↑ N×N column scale
```

The codebook lookup produces a 64×N matrix (one column per vector, each entry
is `centroids[code]`), and the inverse rotation is a 64×64 matrix multiply,
a GEMM that maps directly to tensor cores. Because the norms are per vector
(one per column), `diag(norms)` is N×N and multiplies from the right.
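
The batched decode above can be sketched on the CPU in plain Python. This is only an illustration of the data flow, not the proposed kernel: the function names `matmul` and `decode_batch` are hypothetical, matrices are lists of rows, and the block dimension is whatever size `r_inv` has rather than a fixed 64.

```python
def matmul(a, b):
    """Multiply an m×k matrix by a k×n matrix (both given as lists of rows)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]


def decode_batch(r_inv, centroids, codes, norms):
    """Decode N vectors of one block (hypothetical helper, illustration only).

    r_inv:     d×d inverse rotation shared by all vectors of the block
    centroids: codebook, indexed by code
    codes:     N lists of d codes, one list per vector
    norms:     N per-vector block norms
    Returns a d×N matrix whose column v is decoded vector v.
    """
    d, n = len(r_inv), len(codes)
    # Codebook gather: build the d×N tile, one column per vector.
    gathered = [[centroids[codes[v][j]] for v in range(n)] for j in range(d)]
    # Inverse rotation as one GEMM: (d×d) × (d×N).
    rotated = matmul(r_inv, gathered)
    # Right-multiplying by diag(norms) scales column v by norms[v].
    return [[rotated[i][v] * norms[v] for v in range(n)] for i in range(d)]
```

On a GPU the `matmul` call is the part that would map to tensor cores, while the gather and the column scale are element-wise work around it.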

**Partial decompression pipeline on GPU:**

1. **Decompress rotation matrix** (once per block, shared across all vectors):
   - If stored as bitpacked SRHT signs: FastLanes SIMD unpack on CUDA cores,
     then expand to 64×64 matrix in shared memory
   - If stored as dense 64×64 matrix: direct load to shared memory (16 KB at
     4 bytes per entry)

2. **Decompress norms** (per vector, per block):
   - If cascade-compressed with ALP or Pco: decompress via FastLanes on CUDA
     cores into a register tile
   - If uncompressed: direct load

3. **Codebook gather** (per vector, per block):
   - Stage the codebook in shared memory (16 entries × 4 bytes = 64 bytes at
     b=4, trivially small)
   - Gather: stream code bytes from HBM, look up centroid values in shared
     memory, assemble 64×N tile

4. **Fused GEMM + scale**:
   - R⁻¹ × gathered tile (64×64 × 64×N) on tensor cores
   - Column-wise multiply by norms (element-wise scale)

Steps 3-4 can be fused into a single kernel, following the double-buffered
streaming pattern from Flash-KMeans [5]: prefetch the next batch of code bytes
from HBM while computing the current batch's GEMM on tensor cores. This avoids
materializing the full decompressed vectors in HBM: the decoded output is
either consumed immediately by a distance computation or written once to the
output buffer.
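
The control flow of the fused steps 3-4 can be sketched sequentially in Python. This is an assumption-laden illustration rather than a kernel design: `stream_batches` and `fused_decode_dot` are hypothetical names, the "prefetch" is an ordinary sequential fetch standing in for an asynchronous HBM load, and each decoded vector is consumed by a dot product with the query instead of being stored.

```python
def stream_batches(codes, batch):
    """Yield consecutive batches of per-vector code lists (stand-in for HBM tiles)."""
    for start in range(0, len(codes), batch):
        yield codes[start:start + batch]


def fused_decode_dot(r_inv, centroids, codes, norms, query, batch=2):
    """Return dot(query, decoded_v) for every vector without keeping decoded vectors.

    Hypothetical helper: decodes one block per vector and consumes it immediately.
    """
    d = len(r_inv)
    out = []
    batches = stream_batches(codes, batch)
    nxt = next(batches, None)  # fetch the first tile
    while nxt is not None:
        # Swap buffers: the next tile is fetched before the current one is processed.
        cur, nxt = nxt, next(batches, None)
        for vec_codes in cur:
            gathered = [centroids[c] for c in vec_codes]          # codebook gather
            decoded = [sum(r_inv[i][j] * gathered[j] for j in range(d))
                       for i in range(d)]                         # inverse rotation
            v = len(out)                                          # index of this vector
            out.append(norms[v] * sum(q * x for q, x in zip(query, decoded)))
    return out
```

Only one scalar per vector leaves the loop, which is the property the fused kernel is after. The overlap of fetch and compute is not reproduced here and would need asynchronous copies on real hardware.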

### Fused distance computation (no decode)

For distance computation without full decompression, the operation per block is:

```
dot_contribution_k = ‖query_k‖ × ‖data_k‖ × Σ_j dist_table[q_code[j]][d_code[j]]
```

On GPU, this becomes:

1. **Stage distance table in shared memory**:
   `dist_table[i][j] = centroids[i] × centroids[j]`, 16×16 entries × 4 bytes =
   1 KB at b=4. Fits trivially in shared memory.

2. **Stream code bytes from HBM**: For each 64-vector × 64-dim tile (matching
   the PDX layout), gather from the distance table and accumulate in registers.
   This is a gather-reduce pattern: no GEMM, just table lookups and FP adds.

3. **Norm weighting**: After accumulating the unit-norm dot product for all
   64 dimensions in a block, multiply by the query and data block norms.
   Norms for 64 vectors fit in a single register tile.

4. **Cross-block accumulation**: Sum the weighted dot products across all k
   blocks to get the final distance estimate.
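
The four steps above can be sketched in plain Python. The helper names (`build_dist_table`, `block_dot`, `estimate_dot`) are hypothetical, and the sketch assumes the query has been quantized with the same codebook and the same per-block rotation as the data, so that per-coordinate centroid products sum to the block's unit-norm dot product.

```python
def build_dist_table(centroids):
    """dist_table[i][j] = centroids[i] * centroids[j] (step 1)."""
    return [[ci * cj for cj in centroids] for ci in centroids]


def block_dot(dist_table, q_codes, d_codes, q_norm, d_norm):
    """Gather-reduce over one block (step 2), then weight by both norms (step 3)."""
    acc = sum(dist_table[q][d] for q, d in zip(q_codes, d_codes))
    return q_norm * d_norm * acc


def estimate_dot(dist_table, q_blocks, d_blocks):
    """Sum block contributions (step 4). Each block is a (codes, norm) pair."""
    return sum(
        block_dot(dist_table, q_codes, d_codes, q_norm, d_norm)
        for (q_codes, q_norm), (d_codes, d_norm) in zip(q_blocks, d_blocks)
    )
```

No centroid value is ever multiplied at query time; every product is a table lookup, which is why the table size (1 KB at b=4) matters more than arithmetic throughput on this path.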

The memory access pattern follows Flash-KMeans [5]: stream data tiles from HBM
with double-buffered prefetch, accumulate on-chip, write only the final result.
The key difference is that Flash-KMeans streams full float vectors while we
stream quantized code bytes, using roughly 4-8× less HBM bandwidth per vector
(8-bit and 4-bit codes versus 32-bit floats, before counting per-block norms).
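
A quick way to check the bandwidth figure is to count bytes per vector. The sketch below is a back-of-the-envelope calculation under stated assumptions: a hypothetical 768-dimensional vector, a float32 baseline, and one float32 norm per 64-dim block. The function names are illustrative only.

```python
def compressed_bytes(dim, code_bits, block_dim=64, norm_bytes=4):
    """Bytes streamed per vector: packed codes plus one norm per block (assumed layout)."""
    code_bytes = dim * code_bits // 8
    n_blocks = dim // block_dim
    return code_bytes + n_blocks * norm_bytes


def bandwidth_ratio(dim, code_bits, **kwargs):
    """float32 bytes per vector divided by compressed bytes per vector."""
    return (dim * 4) / compressed_bytes(dim, code_bits, **kwargs)


# Codes only: 8x at 4-bit codes, 4x at 8-bit codes.
codes_only_4bit = bandwidth_ratio(768, 4, norm_bytes=0)
codes_only_8bit = bandwidth_ratio(768, 8, norm_bytes=0)
# Including per-block float32 norms the ratios drop slightly.
with_norms_4bit = bandwidth_ratio(768, 4)
with_norms_8bit = bandwidth_ratio(768, 8)
```

Under these assumptions the per-block norms add 48 bytes to a 768-dim vector, so the ratio at 4-bit codes is closer to 7× than 8×.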

### Int8 tensor core path (b=9)

At b=9, the MSE component uses 8-bit codes. These are indices into a 256-entry
codebook, not raw int8 values, so direct int8 tensor core GEMM does not apply
without transformation. However, if the codebook is approximately linear
(centroids roughly evenly spaced), the codes could be treated as approximate
int8 values with a linear rescaling, enabling direct int8 GEMM for the inner
product computation. This sacrifices some quantization optimality (linear
quantization vs. Max-Lloyd optimal) but enables tensor core throughput.

Whether this tradeoff is worthwhile depends on the application: for ANN ranking
(where relative ordering matters more than absolute accuracy), linear int8 may
be sufficient. For reconstruction (where MSE matters), Max-Lloyd centroids are
preferred and the gather-from-codebook path should be used.
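
To make the linear rescaling concrete: if `centroids[c] ≈ scale * c + offset`, the inner product of two coded vectors expands into an integer dot product of the codes plus correction terms built from code sums. The sketch below assumes an endpoint fit for the line and uses hypothetical names (`fit_linear`, `max_linear_error`, `int_dot_rescaled`); a real path would also need to decide how unsigned codes map onto signed int8 inputs.

```python
def fit_linear(centroids):
    """Endpoint fit so that centroids[c] ≈ scale * c + offset (one possible choice)."""
    scale = (centroids[-1] - centroids[0]) / (len(centroids) - 1)
    return scale, centroids[0]


def max_linear_error(centroids):
    """Largest gap between a centroid and its linear approximation."""
    scale, offset = fit_linear(centroids)
    return max(abs(c - (scale * i + offset)) for i, c in enumerate(centroids))


def int_dot_rescaled(a_codes, b_codes, scale, offset):
    """Inner product of the linearized values, using only integer sums over codes.

    sum_j (s*a_j + o)(s*b_j + o) = s^2 * sum(a_j*b_j) + s*o*(sum(a) + sum(b)) + d*o^2
    """
    int_dot = sum(a * b for a, b in zip(a_codes, b_codes))  # stand-in for the int GEMM
    d = len(a_codes)
    return (scale * scale * int_dot
            + scale * offset * (sum(a_codes) + sum(b_codes))
            + d * offset * offset)
```

`max_linear_error` gives a rough measure of how far a given codebook is from evenly spaced, which is one input to the tradeoff discussed above.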

### Interaction with Vortex file format

The GPU decode pipeline reads compressed data from Vortex files:

1. **File reader** loads compressed segments from storage (S3, local SSD)
2. **Host-side cascade decompression** (BitPacked → codes, ALP → norms) or
   direct GPU transfer of already-decompressed segments
3. **GPU kernel** performs fused decode or fused distance computation

The BlockTurboQuant encoding's child arrays (codes, norms, rotation signs) are
individually compressed by the cascading compressor. For GPU decode, we need
either:
- Host-side decompression of the cascade, then GPU transfer of the raw children
- Direct GPU decompression of FastLanes/ALP (if GPU decompression kernels exist)

The 64-dim block structure ensures that rotation matrices (64×64 dense or 192
bits of SRHT signs) fit comfortably in GPU shared memory, enabling the fused
decode kernel without spilling to HBM.
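
The shared memory claim can be checked with a small budget calculation. This sketch assumes 4-byte entries throughout and compares against a 48 KB per-block budget, which is an assumed figure for illustration; actual limits depend on the GPU and launch configuration. `shared_memory_budget` is a hypothetical helper.

```python
def shared_memory_budget(code_bits=4, block_dim=64, elem_bytes=4):
    """Bytes needed for the structures staged in shared memory (assumed 4-byte entries)."""
    entries = 1 << code_bits
    return {
        "rotation_dense": block_dim * block_dim * elem_bytes,  # expanded 64×64 matrix
        "codebook": entries * elem_bytes,                      # centroid values
        "dist_table": entries * entries * elem_bytes,          # centroid pair products
    }


ASSUMED_BUDGET = 48 * 1024  # assumption for comparison, not a hardware guarantee

b4 = shared_memory_budget(code_bits=4)
b4_total = sum(b4.values())
fits_at_b4 = b4_total <= ASSUMED_BUDGET

b8 = shared_memory_budget(code_bits=8)
```

At b=4 the total is dominated by the 16 KB rotation matrix. Under the same assumptions, a full pair table for a 256-entry codebook would be 256 KB, so the table-lookup distance path would need a different strategy at 8-bit codes.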

## References

[1] Zandieh, A., Daliri, M., Hadian, M. and Mirrokni, V. "TurboQuant: Online
@@ -605,3 +727,6 @@ Transform." Advances in Adaptive Data Analysis, 3(1-2):115-126, 2011.

[4] Kuffo, L., Krippner, E. and Boncz, P. "PDX: A Data Layout for Vector
Similarity Search." Proceedings of SIGMOD '25. arXiv:2503.04422, March 2025.

[5] Yang, S. et al. "Flash-KMeans: Fast and Memory-Efficient Exact K-Means."
arXiv:2603.09229, March 2026.
