Skip to content

Commit bebed3a

Browse files
committed
feat[vortex-cuda]: GPU FSST decompression kernel
This commit implements on-GPU decompression of the existing FSST encoding. The kernel reaches ~42% of the peak throughput measured by the `throughput_cuda` benchmark on a DGX Spark. CPU work is still required to compute the output offsets. The core performance win is buffering up to 24 bytes of decompressed data in three u64 registers and emitting the widest aligned stores possible, up to u128 (st.global.v2.u64). The 256-entry symbol table (≤ 2 KB) is read directly from global memory: staging it into shared memory measured ~3% slower at 10M rows and ~15% slower at 1M rows. The hypothesis is that L1 already holds the table after a few iterations and the explicit shared copy adds bank-conflict latency on the warp-divergent symbols[code] reads; the gap is wider at 1M because the kernel is less bandwidth-bound there. Further optimizations would require an encoding change: splits-style intra-string parallelism (one GPU thread per ~32-compressed-byte chunk instead of one thread per string) was prototyped on top of this kernel and measured an additional +30% kernel throughput at 1M clickbench URLs, +26% at 5M, and +12% at 10M. Four kernel variants are generated, one per unsigned width of codes_offsets (u8/u16/u32/u64); signed integer ptypes are reinterpreted as their unsigned equivalents on the Rust side, so the bit pattern is preserved without copying. Signed-off-by: Alfonso Subiotto Marques <alfonso.subiotto@polarsignals.com>
1 parent 44a6367 commit bebed3a

7 files changed

Lines changed: 614 additions & 0 deletions

File tree

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-cuda/Cargo.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ rstest = { workspace = true }
5050
tokio = { workspace = true, features = ["rt", "macros"] }
5151
vortex-array = { workspace = true, features = ["_test-harness"] }
5252
vortex-cuda = { path = ".", features = ["_test-harness"] }
53+
vortex-fsst = { workspace = true, features = ["_test-harness"] }
5354

5455
[build-dependencies]
5556
bindgen = { workspace = true }
@@ -94,3 +95,7 @@ harness = false
9495
[[bench]]
9596
name = "throughput_cuda"
9697
harness = false
98+
99+
[[bench]]
100+
name = "fsst_cuda"
101+
harness = false

vortex-cuda/benches/fsst_cuda.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
//! CUDA benchmarks for FSST decompression.
5+
6+
#![expect(clippy::unwrap_used)]
7+
#![expect(clippy::cast_possible_truncation)]
8+
9+
#[allow(dead_code)]
10+
mod bench_config;
11+
mod timed_launch_strategy;
12+
13+
use std::sync::Arc;
14+
use std::sync::atomic::Ordering;
15+
use std::time::Duration;
16+
17+
use criterion::BenchmarkId;
18+
use criterion::Criterion;
19+
use criterion::Throughput;
20+
use futures::executor::block_on;
21+
use vortex::array::IntoArray;
22+
use vortex::array::arrays::PrimitiveArray;
23+
use vortex::array::match_each_integer_ptype;
24+
use vortex::encodings::fsst::FSSTArrayExt;
25+
use vortex::error::VortexExpect;
26+
use vortex::session::VortexSession;
27+
use vortex_cuda::CudaSession;
28+
use vortex_cuda::executor::CudaArrayExt;
29+
use vortex_cuda_macros::cuda_available;
30+
use vortex_cuda_macros::cuda_not_available;
31+
use vortex_fsst::test_utils::make_fsst_clickbench_urls;
32+
33+
use crate::timed_launch_strategy::TimedLaunchStrategy;
34+
35+
const BENCH_SIZES: &[(usize, &str)] = &[(1_000_000, "1M"), (5_000_000, "5M"), (10_000_000, "10M")];
36+
37+
fn benchmark_fsst_cuda_decompress(c: &mut Criterion) {
38+
let mut group = c.benchmark_group("cuda");
39+
40+
for &(n, len_str) in BENCH_SIZES {
41+
let mut setup_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
42+
.vortex_expect("failed to create execution context");
43+
let fsst = make_fsst_clickbench_urls(n, setup_ctx.execution_ctx());
44+
45+
let lens = fsst
46+
.uncompressed_lengths()
47+
.clone()
48+
.execute::<PrimitiveArray>(setup_ctx.execution_ctx())
49+
.vortex_expect("canonicalize uncompressed_lengths");
50+
let total_size: usize = match_each_integer_ptype!(lens.ptype(), |P| {
51+
lens.as_slice::<P>().iter().map(|x| *x as usize).sum()
52+
});
53+
let uncompressed_size = total_size as u64;
54+
55+
let fsst_array = fsst.into_array();
56+
57+
group.throughput(Throughput::Bytes(uncompressed_size));
58+
group.bench_with_input(
59+
BenchmarkId::new("cuda/fsst/decompress", len_str),
60+
&fsst_array,
61+
|b, fsst_array| {
62+
b.iter_custom(|iters| {
63+
let timed = TimedLaunchStrategy::default();
64+
let timer = timed.timer();
65+
66+
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())
67+
.vortex_expect("failed to create execution context")
68+
.with_launch_strategy(Arc::new(timed));
69+
70+
for _ in 0..iters {
71+
block_on(fsst_array.clone().execute_cuda(&mut cuda_ctx)).unwrap();
72+
}
73+
Duration::from_nanos(timer.load(Ordering::Relaxed))
74+
});
75+
},
76+
);
77+
}
78+
79+
group.finish();
80+
}
81+
82+
// Criterion harness wiring: run the decompression benchmark above under the
// shared CUDA bench configuration.
criterion::criterion_group! {
    name = benches;
    config = bench_config::cuda_bench_config();
    targets = benchmark_fsst_cuda_decompress
}

// Only emit a real criterion `main` when CUDA is available; otherwise the
// benchmark binary compiles to a no-op so the workspace still builds on
// machines without a GPU.
#[cuda_available]
criterion::criterion_main!(benches);

#[cuda_not_available]
fn main() {}

vortex-cuda/kernels/src/fsst.cu

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
#include "config.cuh"
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
#include <stdint.h>
8+
9+
// FSST decompression. A thread decodes one string at a time.
10+
//
11+
// Per-thread `Scratch` holds 24 bytes across three u64 lanes (`low`, `mid`,
12+
// `high`) plus a `cursor` byte counter. Byte i lives at bit (8 * (i mod 8))
13+
// of:
14+
// low for i in 0..8
15+
// mid for i in 8..16
16+
// high for i in 16..24
17+
//
18+
// lsb msb
19+
// low: [ b0 | b1 | b2 | b3 | b4 | b5 | b6 | b7 ]
20+
// mid: [ b8 | b9 |b10 |b11 |b12 |b13 |b14 |b15 ]
21+
// high: [b16 |b17 |b18 |b19 |b20 |b21 |b22 |b23 ]
22+
//
23+
// `Scratch::drain` picks the largest aligned store the gates allow
24+
// (alignment of out_pos, cursor, remaining out_end room). Bytes leave from
25+
// the low end (`low` byte 0); the kept bytes slide N positions toward that
26+
// low end across all three lanes i.e. each u64 right-shifts by N*8 and
27+
// pulls the next lane's low bits up to fill the vacated high bits.
28+
// `Scratch::push` inserts a length-`len` masked symbol at byte offset
29+
// `cursor`, spanning at most two of the three lanes.
30+
//
31+
// width gate ptx
32+
// ------ ------------------------------------------ ----------------
33+
// 16 B out_pos % 16 == 0, cursor ≥ 16, room ≥ 16 st.global.v2.u64
34+
// 8 B out_pos % 8 == 0, cursor ≥ 8, room ≥ 8 st.global.u64
35+
// 4 B out_pos % 4 == 0, cursor ≥ 4, room ≥ 4 st.global.u32
36+
// 2 B out_pos % 2 == 0, cursor ≥ 2, room ≥ 2 st.global.u16
37+
// 1 B (always) st.global.u8
38+
//
39+
// The narrow widths cover the prologue alignment-up (out_pos not yet
40+
// 16-aligned) and the epilogue tail (< 16 bytes left, no room for u128).
41+
// In steady state out_pos stays 16-aligned and u128 fires repeatedly.
42+
//
43+
// The 256-entry symbol table (≤ 2 KB) is read directly from global memory.
44+
// Staging it into shared memory measured ~3% slower at 10M rows and ~15%
45+
// slower at 1M rows (benchmarked on clickbench URLs). The hypothesis is that L1
46+
// already holds the table after a few iterations and the explicit shared copy
47+
// adds bank-conflict latency on the warp-divergent `symbols[code]` reads; the
48+
// gap is wider at 1M because the kernel is less bandwidth-bound there, so
49+
// per-load latency shows up more.
50+
//
51+
// Decoded symbols are masked to their valid byte length so the table's high
52+
// bits never leak. The main loop drains to `scratch.cursor ≤ 16`, keeping
53+
// the next add (≤ 8 bytes) within the 24-byte capacity.
54+
//
55+
// `codes_offsets` is templated over the four unsigned integer widths
56+
// (u8/u16/u32/u64). `output_offsets` is uint64_t.
57+
58+
// 24-byte scratch buffer split across three u64 lanes. `cursor` is the
// number of bytes currently buffered and the next-push offset.
struct Scratch {
    uint64_t low = 0;
    uint64_t mid = 0;
    uint64_t high = 0;
    uint32_t cursor = 0;

    // Insert a length-`len` masked symbol at byte offset `cursor`. The
    // symbol spans at most two of the three lanes. Caller must ensure
    // cursor <= 16 and cursor + len <= 24.
    __device__ inline void push(uint64_t sym, uint32_t len) {
        if (cursor < 8) {
            low |= sym << (8u * cursor);
            if (cursor + len > 8) {
                mid |= sym >> (8u * (8u - cursor));
            }
        } else if (cursor < 16) {
            mid |= sym << (8u * (cursor - 8u));
            if (cursor + len > 16) {
                high |= sym >> (8u * (16u - cursor));
            }
        } else {
            // cursor == 16 under the caller invariant (the main loop drains
            // to cursor <= 16 before every push). A two-way split would
            // compute `sym << 64` here, which is undefined behavior in C++
            // for a 64-bit operand (PTX shl happens to clamp, but the
            // compiler is not required to preserve that). Writing the lane
            // directly is well-defined and cheaper.
            high |= sym << (8u * (cursor - 16u));
        }
        cursor += len;
    }

    // Emit one variable-width aligned store from the low end and slide the
    // kept bytes toward the low end across all three lanes. Width gates are
    // (in order): alignment of out_pos, bytes buffered, room before out_end.
    __device__ inline void drain(uint8_t *__restrict out, uint64_t &out_pos, uint64_t out_end) {
        if (cursor >= 16 && (out_pos & 15u) == 0 && out_pos + 16 <= out_end) {
            // 16-byte vector store (st.global.v2.u64): low+mid leave
            // together, high slides down to low.
            *reinterpret_cast<ulonglong2 *>(out + out_pos) = make_ulonglong2(low, mid);
            low = high;
            mid = 0;
            high = 0;
            out_pos += 16;
            cursor -= 16;
        } else if (cursor >= 8 && (out_pos & 7u) == 0 && out_pos + 8 <= out_end) {
            *reinterpret_cast<uint64_t *>(out + out_pos) = low;
            low = mid;
            mid = high;
            high = 0;
            out_pos += 8;
            cursor -= 8;
        } else if (cursor >= 4 && (out_pos & 3u) == 0 && out_pos + 4 <= out_end) {
            *reinterpret_cast<uint32_t *>(out + out_pos) = (uint32_t)low;
            low = (low >> 32) | (mid << 32);
            mid = (mid >> 32) | (high << 32);
            high >>= 32;
            out_pos += 4;
            cursor -= 4;
        } else if (cursor >= 2 && (out_pos & 1u) == 0 && out_pos + 2 <= out_end) {
            *reinterpret_cast<uint16_t *>(out + out_pos) = (uint16_t)low;
            low = (low >> 16) | (mid << 48);
            mid = (mid >> 16) | (high << 48);
            high >>= 16;
            out_pos += 2;
            cursor -= 2;
        } else {
            // Single-byte fallback: always legal; covers the unaligned
            // prologue and the final tail.
            out[out_pos] = (uint8_t)low;
            low = (low >> 8) | (mid << 56);
            mid = (mid >> 8) | (high << 56);
            high >>= 8;
            out_pos += 1;
            cursor -= 1;
        }
    }
};
125+
126+
// Kernel arguments for FSST decompression, templated over the unsigned
// integer width of `codes_offsets`. NOTE: instances are aggregate-initialized
// positionally by the generated kernel entry points, so the field order here
// is part of the contract.
template <typename OffT>
struct FSSTArgs {
    // Compressed FSST code stream, contiguous across all strings. String
    // `sid`'s codes live in `[codes_offsets[sid], codes_offsets[sid + 1])`.
    const uint8_t *__restrict codes_bytes;
    // Per-string offsets into `codes_bytes`, length `num_strings + 1`.
    const OffT *__restrict codes_offsets;
    // FSST symbol table: 256 entries, one u64 each.
    const uint64_t *__restrict symbols;
    // Length in bytes (1..=8) of each entry in `symbols`. The remaining bits
    // are unspecified.
    const uint8_t *__restrict symbol_lengths;
    // Buffer to write decoded data into.
    uint8_t *__restrict output_bytes;
    // Per-string offsets into `output_bytes`, length `num_strings + 1`.
    const uint64_t *__restrict output_offsets;
    // Validity of each string, one bit per string, LSB-first within a byte.
    const uint8_t *__restrict validity_bits;
};
145+
146+
template <typename OffT>
147+
__device__ inline void fsst_decode_string(const FSSTArgs<OffT> &args, uint64_t sid) {
148+
if (((args.validity_bits[sid >> 3] >> (sid & 7u)) & 1u) == 0u) {
149+
return;
150+
}
151+
152+
OffT in_pos = args.codes_offsets[sid];
153+
const OffT in_end = args.codes_offsets[sid + 1];
154+
uint64_t out_pos = args.output_offsets[sid];
155+
const uint64_t out_end = args.output_offsets[sid + 1];
156+
157+
Scratch scratch;
158+
159+
while (in_pos < in_end) {
160+
// Drain to scratch.cursor ≤ 16 so the next ≤8-byte symbol fits in 24.
161+
while (scratch.cursor > 16) {
162+
scratch.drain(args.output_bytes, out_pos, out_end);
163+
}
164+
165+
// Decode next code. 255 is the escape for raw literal bytes.
166+
const uint8_t code = args.codes_bytes[in_pos];
167+
uint64_t sym;
168+
uint32_t len, consumed;
169+
if (code == 255) {
170+
sym = (uint64_t)args.codes_bytes[in_pos + 1];
171+
len = 1;
172+
consumed = 2;
173+
} else {
174+
sym = args.symbols[code];
175+
len = args.symbol_lengths[code];
176+
consumed = 1;
177+
}
178+
179+
// Zero out the symbol's high bytes beyond its valid length.
180+
const uint64_t mask = (len == 8) ? ~0ULL : ((1ULL << (8u * len)) - 1ULL);
181+
sym &= mask;
182+
183+
scratch.push(sym, len);
184+
in_pos += (OffT)consumed;
185+
}
186+
187+
// Epilogue: drain everything that's left.
188+
while (scratch.cursor > 0) {
189+
scratch.drain(args.output_bytes, out_pos, out_end);
190+
}
191+
}
192+
193+
// Generates one extern "C" kernel entry point per unsigned width of
// `codes_offsets`. The positional aggregate initializer below must match the
// field order of `FSSTArgs` exactly.
//
// Work distribution: each block owns a contiguous run of
// `blockDim.x * ELEMENTS_PER_THREAD` strings (ELEMENTS_PER_THREAD comes from
// config.cuh); within the block, threads start at `block_start + threadIdx.x`
// and stride by `blockDim.x`, so consecutive threads handle consecutive
// strings on each pass.
#define GENERATE_FSST_KERNEL(suffix, OffT)                                                 \
    extern "C" __global__ void fsst_##suffix(const uint8_t *__restrict codes_bytes,        \
                                             const OffT *__restrict codes_offsets,        \
                                             const uint64_t *__restrict symbols,          \
                                             const uint8_t *__restrict symbol_lengths,    \
                                             const uint64_t *__restrict output_offsets,   \
                                             const uint8_t *__restrict validity_bits,     \
                                             uint8_t *__restrict output_bytes,            \
                                             uint64_t num_strings) {                      \
        const FSSTArgs<OffT> args = {                                                     \
            codes_bytes,                                                                  \
            codes_offsets,                                                                \
            symbols,                                                                      \
            symbol_lengths,                                                               \
            output_bytes,                                                                 \
            output_offsets,                                                               \
            validity_bits,                                                                \
        };                                                                                \
                                                                                          \
        const uint64_t elements_per_block = (uint64_t)blockDim.x * ELEMENTS_PER_THREAD;   \
        const uint64_t block_start = (uint64_t)blockIdx.x * elements_per_block;           \
        const uint64_t block_end = (block_start + elements_per_block < num_strings)       \
                                       ? (block_start + elements_per_block)               \
                                       : num_strings;                                     \
                                                                                          \
        for (uint64_t sid = block_start + threadIdx.x; sid < block_end;                   \
             sid += blockDim.x) {                                                         \
            fsst_decode_string<OffT>(args, sid);                                          \
        }                                                                                 \
    }

GENERATE_FSST_KERNEL(u8, uint8_t)
GENERATE_FSST_KERNEL(u16, uint16_t)
GENERATE_FSST_KERNEL(u32, uint32_t)
GENERATE_FSST_KERNEL(u64, uint64_t)

0 commit comments

Comments
 (0)