@@ -591,6 +591,128 @@ New tests needed:
591591- Per-block norm consistency with input vectors
592592- Encode/decode benchmarks (see Performance analysis)
593593
594+ ## Future work: GPU decode and fused distance computation
595+
596+ ### Motivation
597+
598+ For ANN search workloads, the dominant operation is computing distances between
599+ a query vector and millions of database vectors. On GPU, the goal is to perform
600+ this computation directly on the compressed representation, avoiding the cost
601+ of materializing full decompressed vectors in HBM. BlockTurboQuant's 64-dim
602+ block structure maps naturally to GPU tile sizes and tensor core operations.
603+
604+ ### Decode as GEMM
605+
606+ The decode path for a single 64-dim block is:
607+
608+ ```
609+ decoded_block = norm × R⁻¹ × codebook_lookup(codes)
610+ ```
611+
612+ For a batch of N vectors sharing the same block's rotation matrix R⁻¹:
613+
614+ ```
615+ decoded_batch = diag(norms) × R⁻¹ × codebook_lookup_batch(codes)
616+ ↑ 64×N matrix
617+ ↑ 64×64 × 64×N = GEMM
618+ ```
619+
620+ The codebook lookup produces a 64×N matrix (one column per vector, each entry
621+ is ` centroids[code] ` ), and the inverse rotation is a 64×64 matrix multiply —
622+ a GEMM that maps directly to tensor cores.
623+
624+ ** Partial decompression pipeline on GPU:**
625+
626+ 1 . ** Decompress rotation matrix** (once per block, shared across all vectors):
627+ - If stored as bitpacked SRHT signs: FastLanes SIMD unpack on CUDA cores,
628+ then expand to 64×64 matrix in shared memory
629+ - If stored as dense 64×64 matrix: direct load to shared memory (16 KB)
630+
631+ 2 . ** Decompress norms** (per vector, per block):
632+ - If cascade-compressed with ALP or Pco: decompress via FastLanes on CUDA
633+ cores into a register tile
634+ - If uncompressed: direct load
635+
636+ 3 . ** Codebook gather** (per vector, per block):
637+ - Stage the codebook in shared memory (16 entries × 4 bytes = 64 bytes at
638+ b=4 — trivially small)
639+ - Gather: stream code bytes from HBM, look up centroid values in shared
640+ memory, assemble 64×N tile
641+
642+ 4 . ** Fused GEMM + scale** :
643+ - R⁻¹ × gathered tile (64×64 × 64×N) on tensor cores
644+ - Column-wise multiply by norms (element-wise scale)
645+
646+ Steps 3-4 can be fused into a single kernel, following the double-buffered
647+ streaming pattern from Flash-KMeans [ 5] : prefetch the next batch of code bytes
648+ from HBM while computing the current batch's GEMM on tensor cores. This avoids
649+ materializing the full decompressed vectors in HBM — the decoded output is
650+ either consumed immediately by a distance computation or written once to the
651+ output buffer.
652+
653+ ### Fused distance computation (no decode)
654+
655+ For distance computation without full decompression, the operation per block is:
656+
657+ ```
658+ dot_contribution_k = ‖query_k‖ × ‖data_k‖ × Σ_j dist_table[q_code[j]][d_code[j]]
659+ ```
660+
661+ On GPU, this becomes:
662+
663+ 1 . ** Stage distance table in shared memory** : `dist_table[ i] [ j ] = centroids[ i] ×
664+ centroids[ j] `, 16×16 = 1 KB at b=4. Fits trivially in shared memory.
665+
666+ 2 . ** Stream code bytes from HBM** : For each 64-vector × 64-dim tile (matching
667+ the PDX layout), gather from the distance table and accumulate in registers.
668+ This is a gather-reduce pattern — no GEMM, just table lookups and FP adds.
669+
670+ 3 . ** Norm weighting** : After accumulating the unit-norm dot product for all
671+ 64 dimensions in a block, multiply by the query and data block norms.
672+ Norms for 64 vectors fit in a single register tile.
673+
674+ 4 . ** Cross-block accumulation** : Sum the weighted dot products across all k
675+ blocks to get the final distance estimate.
676+
677+ The memory access pattern follows Flash-KMeans [ 5] : stream data tiles from HBM
678+ with double-buffered prefetch, accumulate on-chip, write only the final result.
679+ The key difference is that Flash-KMeans streams full float vectors while we
680+ stream quantized code bytes — 4-8× less HBM bandwidth per vector.
681+
682+ ### Int8 tensor core path (b=9)
683+
684+ At b=9, the MSE component uses 8-bit codes. These are indices into a 256-entry
685+ codebook, not raw int8 values — so direct int8 tensor core GEMM does not apply
686+ without transformation. However, if the codebook is approximately linear
687+ (centroids roughly evenly spaced), the codes could be treated as approximate
688+ int8 values with a linear rescaling, enabling direct int8 GEMM for the inner
689+ product computation. This sacrifices some quantization optimality (linear
690+ quantization vs. Max-Lloyd optimal) but enables tensor core throughput.
691+
692+ Whether this tradeoff is worthwhile depends on the application: for ANN ranking
693+ (where relative ordering matters more than absolute accuracy), linear int8 may
694+ be sufficient. For reconstruction (where MSE matters), Max-Lloyd centroids are
695+ preferred and the gather-from-codebook path should be used.
696+
697+ ### Interaction with Vortex file format
698+
699+ The GPU decode pipeline reads compressed data from Vortex files:
700+
701+ 1 . ** File reader** loads compressed segments from storage (S3, local SSD)
702+ 2 . ** Host-side cascade decompression** (BitPacked → codes, ALP → norms) or
703+ direct GPU transfer of already-decompressed segments
704+ 3 . ** GPU kernel** performs fused decode or fused distance computation
705+
706+ The BlockTurboQuant encoding's child arrays (codes, norms, rotation signs) are
707+ individually compressed by the cascading compressor. For GPU decode, we need
708+ either:
709+ - Host-side decompression of the cascade, then GPU transfer of the raw children
710+ - Direct GPU decompression of FastLanes/ALP (if GPU decompression kernels exist)
711+
712+ The 64-dim block structure ensures that rotation matrices (64×64 dense or 192
713+ bits SRHT signs) fit comfortably in GPU shared memory, enabling the fused
714+ decode kernel without spilling to HBM.
715+
594716## References
595717
596718[ 1] Zandieh, A., Daliri, M., Hadian, M. and Mirrokni, V. "TurboQuant: Online
@@ -605,3 +727,6 @@ Transform." Advances in Adaptive Data Analysis, 3(1-2):115-126, 2011.
605727
606728[ 4] Kuffo, L., Krippner, E. and Boncz, P. "PDX: A Data Layout for Vector
607729Similarity Search." Proceedings of SIGMOD '25. arXiv:2503.04422, March 2025.
730+
731+ [ 5] Yang, S. et al. "Flash-KMeans: Fast and Memory-Efficient Exact K-Means."
732+ arXiv:2603.09229, March 2026.
0 commit comments