Skip to content

Commit 91a6b57

Browse files
authored
feat(cuda): fuse narrower-than-output Dict codes and RunEnd ends (#7617)
Dict codes and RunEnd ends that are narrower than the output type (e.g. u8 BitPacked codes in a u32 Dict) previously required a separate kernel launch. They are now fused by decoding at the source's native width and widening to T in shared memory. --- Fixes the race from #7603 by applying the type widening with one warp within a block. Further this PR adds benchmarks exercising the widening logic: ``` Benchmarking dict_widen_u8_to_u32/dynamic_dispatch_u32/100M: Warming up for 1.0000 ns Warning: Unable to complete 10 samples in 1.0ns. You may wish to increase target time to 48.4ms. dict_widen_u8_to_u32/dynamic_dispatch_u32/100M time: [203.97 µs 204.37 µs 204.78 µs] thrpt: [1819.2 GiB/s 1822.8 GiB/s 1826.4 GiB/s] Benchmarking dict_widen_u16_to_u32/dynamic_dispatch_u32/100M: Warming up for 1.0000 ns Warning: Unable to complete 10 samples in 1.0ns. You may wish to increase target time to 50.4ms. dict_widen_u16_to_u32/dynamic_dispatch_u32/100M time: [203.74 µs 204.92 µs 205.15 µs] thrpt: [1815.9 GiB/s 1817.9 GiB/s 1828.5 GiB/s] Benchmarking dict_nowiden_u32_to_u32/dynamic_dispatch_u32/100M: Warming up for 1.0000 ns Warning: Unable to complete 10 samples in 1.0ns. You may wish to increase target time to 49.6ms. dict_nowiden_u32_to_u32/dynamic_dispatch_u32/100M time: [170.86 µs 171.18 µs 171.59 µs] thrpt: [2171.0 GiB/s 2176.2 GiB/s 2180.3 GiB/s] ``` Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 543dbe7 commit 91a6b57

5 files changed

Lines changed: 529 additions & 133 deletions

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,12 +425,148 @@ fn bench_alp_for_bitpacked(c: &mut Criterion) {
425425
group.finish();
426426
}
427427

428+
// ---------------------------------------------------------------------------
429+
// Benchmark: Dict with narrower BitPacked codes (exercises widen_inplace)
430+
// ---------------------------------------------------------------------------
431+
432+
/// Dict(codes=BitPacked<u8>, values=Prim<u32>) — widens u8 → u32 in smem.
433+
fn bench_dict_bp_u8_codes_u32_values(c: &mut Criterion) {
434+
let mut group = c.benchmark_group("dict_widen_u8_to_u32");
435+
436+
let dict_size: usize = 4; // 2-bit codes
437+
let bit_width: u8 = 2;
438+
let dict_values: Vec<u32> = (0..dict_size as u32).map(|i| i * 1000 + 42).collect();
439+
440+
for (len, len_str) in BENCH_ARGS {
441+
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
442+
443+
let codes: Vec<u8> = (0..*len).map(|i| (i % dict_size) as u8).collect();
444+
let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable);
445+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
446+
let codes_bp = BitPackedData::encode(&codes_prim.into_array(), bit_width, &mut ctx)
447+
.vortex_expect("bitpack u8 codes");
448+
let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable);
449+
let dict = DictArray::new(codes_bp.into_array(), values_prim.into_array());
450+
let array = dict.into_array();
451+
452+
group.bench_with_input(
453+
BenchmarkId::new("dynamic_dispatch_u32", len_str),
454+
len,
455+
|b, &n| {
456+
let mut cuda_ctx =
457+
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
458+
459+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
460+
461+
b.iter_custom(|iters| {
462+
let mut total_time = Duration::ZERO;
463+
for _ in 0..iters {
464+
total_time += bench_runner.run(&mut cuda_ctx);
465+
}
466+
total_time
467+
});
468+
},
469+
);
470+
}
471+
472+
group.finish();
473+
}
474+
475+
/// Dict(codes=BitPacked<u16>, values=Prim<u32>) — widens u16 → u32 in smem.
476+
fn bench_dict_bp_u16_codes_u32_values(c: &mut Criterion) {
477+
let mut group = c.benchmark_group("dict_widen_u16_to_u32");
478+
479+
let dict_size: usize = 8; // 3-bit codes
480+
let bit_width: u8 = 3;
481+
let dict_values: Vec<u32> = (0..dict_size as u32).map(|i| i * 5000 + 100).collect();
482+
483+
for (len, len_str) in BENCH_ARGS {
484+
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
485+
486+
let codes: Vec<u16> = (0..*len).map(|i| (i % dict_size) as u16).collect();
487+
let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable);
488+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
489+
let codes_bp = BitPackedData::encode(&codes_prim.into_array(), bit_width, &mut ctx)
490+
.vortex_expect("bitpack u16 codes");
491+
let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable);
492+
let dict = DictArray::new(codes_bp.into_array(), values_prim.into_array());
493+
let array = dict.into_array();
494+
495+
group.bench_with_input(
496+
BenchmarkId::new("dynamic_dispatch_u32", len_str),
497+
len,
498+
|b, &n| {
499+
let mut cuda_ctx =
500+
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
501+
502+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
503+
504+
b.iter_custom(|iters| {
505+
let mut total_time = Duration::ZERO;
506+
for _ in 0..iters {
507+
total_time += bench_runner.run(&mut cuda_ctx);
508+
}
509+
total_time
510+
});
511+
},
512+
);
513+
}
514+
515+
group.finish();
516+
}
517+
518+
/// Dict(codes=BitPacked<u32>, values=Prim<u32>) — same-width baseline, no widen.
519+
fn bench_dict_bp_u32_codes_u32_values(c: &mut Criterion) {
520+
let mut group = c.benchmark_group("dict_nowiden_u32_to_u32");
521+
522+
let dict_size: usize = 8; // 3-bit codes
523+
let bit_width: u8 = 3;
524+
let dict_values: Vec<u32> = (0..dict_size as u32).map(|i| i * 5000 + 100).collect();
525+
526+
for (len, len_str) in BENCH_ARGS {
527+
group.throughput(Throughput::Bytes((len * size_of::<u32>()) as u64));
528+
529+
let codes: Vec<u32> = (0..*len).map(|i| (i % dict_size) as u32).collect();
530+
let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable);
531+
let mut ctx = LEGACY_SESSION.create_execution_ctx();
532+
let codes_bp = BitPackedData::encode(&codes_prim.into_array(), bit_width, &mut ctx)
533+
.vortex_expect("bitpack u32 codes");
534+
let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable);
535+
let dict = DictArray::new(codes_bp.into_array(), values_prim.into_array());
536+
let array = dict.into_array();
537+
538+
group.bench_with_input(
539+
BenchmarkId::new("dynamic_dispatch_u32", len_str),
540+
len,
541+
|b, &n| {
542+
let mut cuda_ctx =
543+
CudaSession::create_execution_ctx(&VortexSession::empty()).vortex_expect("ctx");
544+
545+
let bench_runner = BenchRunner::new(&array, n, &mut cuda_ctx);
546+
547+
b.iter_custom(|iters| {
548+
let mut total_time = Duration::ZERO;
549+
for _ in 0..iters {
550+
total_time += bench_runner.run(&mut cuda_ctx);
551+
}
552+
total_time
553+
});
554+
},
555+
);
556+
}
557+
558+
group.finish();
559+
}
560+
428561
/// Entry point for the dynamic-dispatch benchmark group: runs every
/// individual benchmark in sequence under one Criterion session.
fn benchmark_dynamic_dispatch(c: &mut Criterion) {
    bench_for_bitpacked(c);
    bench_dict_bp_codes(c);
    bench_runend(c);
    bench_dict_bp_codes_bp_for_values(c);
    bench_alp_for_bitpacked(c);
    // Mixed-width Dict benchmarks: the first two exercise the in-smem
    // widening path (u8→u32, u16→u32); the third is the same-width
    // baseline for comparison.
    bench_dict_bp_u8_codes_u32_values(c);
    bench_dict_bp_u16_codes_u32_values(c);
    bench_dict_bp_u32_codes_u32_values(c);
}
435571

436572
criterion::criterion_group! {

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 106 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,16 @@
3939
//
4040
// ## Mixed-width support
4141
//
42-
// LOAD sources from pending subtrees may have a narrower type than the
43-
// output (e.g. u8 dict codes in a u32 plan). load_element() widens
44-
// to T via static_cast — no separate widen kernel or smem intermediate.
42+
// Dict codes, RunEnd ends, and other child arrays may have a narrower
43+
// element type than the output T. Two mechanisms handle this:
44+
//
45+
// LOAD load_element() dispatches on the per-stage PTypeTag to
46+
// read at the source's native width and static_cast to T.
47+
// BITUNPACK bitunpack_typed() unpacks at the source's native width,
48+
// then widens to T in-place via a backward scan
49+
// (widen_inplace). The smem region is pre-allocated at
50+
// max(source_width, T) bytes per element by the Rust plan
51+
// builder, so the widen never overflows.
4552

4653
#include <assert.h>
4754
#include <cuda.h>
@@ -203,6 +210,28 @@ scatter_patches_chunk(const GPUPatches &patches, T *__restrict out, uint32_t chu
203210
// Source ops
204211
// ═══════════════════════════════════════════════════════════════════════════
205212

213+
/// Widen SOURCE-sized elements in shared memory to DESTINATION-sized in-place.
214+
///
215+
/// A single warp performs the backward scan so that lockstep execution
216+
/// guarantees every load at index i retires before the store at i, and
217+
/// higher indices are already consumed. Using multiple warps would introduce
218+
/// a cross-warp race: a fast warp writing dst[low] can clobber source
219+
/// bytes that a slow warp has not yet read.
220+
template <typename DESTINATION, typename SOURCE>
221+
__device__ inline void widen_inplace(DESTINATION *dst, uint32_t len) {
222+
if constexpr (sizeof(DESTINATION) <= sizeof(SOURCE)) {
223+
return;
224+
}
225+
const SOURCE *src = reinterpret_cast<const SOURCE *>(dst);
226+
if (threadIdx.x < warpSize) {
227+
for (int32_t i = static_cast<int32_t>(len) - 1 - static_cast<int32_t>(threadIdx.x); i >= 0;
228+
i -= warpSize) {
229+
dst[i] = static_cast<DESTINATION>(src[i]);
230+
}
231+
}
232+
__syncthreads();
233+
}
234+
206235
/// FastLanes cooperative unpack — all threads in the block scatter-write
207236
/// decoded elements into `dst`. Caller must issue __syncthreads() before
208237
/// any thread reads from `dst`.
@@ -236,6 +265,68 @@ __device__ inline void bitunpack(const T *__restrict packed,
236265
}
237266
}
238267

268+
/// Dispatch bitunpack at the source's native element width, then widen
269+
/// to T in-place so all downstream scalar ops and smem consumers see
270+
/// T-sized elements. Falls back to the direct `bitunpack<T>` path when
271+
/// the source ptype already matches T. Issues __syncthreads() before
272+
/// returning on all paths.
273+
///
274+
/// Accepts explicit chunk_start / chunk_len so it works for both input
275+
/// stages (full decode with chunk_start=0, chunk_len=stage.len) and
276+
/// the output stage (tiled with varying chunk_start / chunk_len).
277+
template <typename T>
278+
__device__ inline void bitunpack_typed(T *__restrict dst,
279+
const void *__restrict packed,
280+
uint64_t chunk_start,
281+
uint32_t chunk_len,
282+
const struct SourceOp &src,
283+
PTypeTag source_ptype) {
284+
// Fast path: source width matches T — no widening needed.
285+
if (ptype_byte_width(source_ptype) == sizeof(T)) {
286+
bitunpack<T>(reinterpret_cast<const T *>(packed), dst, chunk_start, chunk_len, src);
287+
__syncthreads();
288+
return;
289+
}
290+
291+
// Compute total elements written by bitunpack (including alignment
292+
// padding) so widen_inplace covers the full scratch region.
293+
const uint32_t elem_off = src.params.bitunpack.element_offset;
294+
const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
295+
const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK;
296+
const uint32_t total_elems = n_chunks * FL_CHUNK;
297+
298+
// Narrow source: unpack at native width, then widen to T.
299+
switch (source_ptype) {
300+
case PTYPE_U8:
301+
case PTYPE_I8: {
302+
auto *narrow = reinterpret_cast<uint8_t *>(dst);
303+
bitunpack<uint8_t>(reinterpret_cast<const uint8_t *>(packed), narrow, chunk_start, chunk_len, src);
304+
__syncthreads();
305+
widen_inplace<T, uint8_t>(dst, total_elems);
306+
break;
307+
}
308+
case PTYPE_U16:
309+
case PTYPE_I16: {
310+
auto *narrow = reinterpret_cast<uint16_t *>(dst);
311+
bitunpack<uint16_t>(reinterpret_cast<const uint16_t *>(packed), narrow, chunk_start, chunk_len, src);
312+
__syncthreads();
313+
widen_inplace<T, uint16_t>(dst, total_elems);
314+
break;
315+
}
316+
case PTYPE_U32:
317+
case PTYPE_I32:
318+
case PTYPE_F32: {
319+
auto *narrow = reinterpret_cast<uint32_t *>(dst);
320+
bitunpack<uint32_t>(reinterpret_cast<const uint32_t *>(packed), narrow, chunk_start, chunk_len, src);
321+
__syncthreads();
322+
widen_inplace<T, uint32_t>(dst, total_elems);
323+
break;
324+
}
325+
default:
326+
__builtin_unreachable();
327+
}
328+
}
329+
239330
/// Read N values from a source op into `out`.
240331
///
241332
/// Dispatches on `src.op_code` to handle each encoding:
@@ -354,16 +445,14 @@ __device__ void execute_output_stage(T *__restrict output,
354445
if (src.op_code == SourceOp::BITUNPACK) {
355446
chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
356447
T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
357-
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr),
358-
scratch,
359-
block_start + elem_idx,
360-
chunk_len,
361-
src);
448+
bitunpack_typed<T>(scratch,
449+
reinterpret_cast<const void *>(stage.input_ptr),
450+
block_start + elem_idx,
451+
chunk_len,
452+
src,
453+
ptype);
362454
const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK;
363455
smem_src = scratch + align;
364-
// Write barrier: all threads finished bitunpack (and any
365-
// patches), safe to read from scratch.
366-
__syncthreads();
367456
} else {
368457
chunk_len = block_len;
369458
}
@@ -438,11 +527,12 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
438527
const auto &src = stage.source;
439528

440529
if (src.op_code == SourceOp::BITUNPACK) {
441-
T *raw_smem = smem_out;
442-
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
443-
// Write barrier: cooperative bitunpack finished, safe to read
444-
// decoded elements below.
445-
__syncthreads();
530+
bitunpack_typed<T>(smem_out,
531+
reinterpret_cast<const void *>(stage.input_ptr),
532+
0,
533+
stage.len,
534+
src,
535+
stage.source_ptype);
446536

447537
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
448538

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,27 @@ PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
7878
return tag;
7979
}
8080
}
81+
82+
/// Byte width of the primitive element type identified by `tag`
/// (1, 2, 4, or 8), or 0 for any tag not listed below.
PTYPE_HOST_DEVICE constexpr uint8_t ptype_byte_width(PTypeTag tag) {
  if (tag == PTYPE_U8 || tag == PTYPE_I8) {
    return 1;
  }
  if (tag == PTYPE_U16 || tag == PTYPE_I16) {
    return 2;
  }
  if (tag == PTYPE_U32 || tag == PTYPE_I32 || tag == PTYPE_F32) {
    return 4;
  }
  if (tag == PTYPE_U64 || tag == PTYPE_I64 || tag == PTYPE_F64) {
    return 8;
  }
  // Unknown / non-fixed-width tag.
  return 0;
}
81102
#endif
82103

83104
/// Number of threads per CUDA block.

0 commit comments

Comments
 (0)