
Commit 09ae129

feat(cuda): fuse narrower-than-output Dict codes and RunEnd ends (#7603)
Dict codes and RunEnd ends that are narrower than the output type (e.g. u8 BitPacked codes in a u32 Dict) previously required a separate kernel launch. They are now fused by decoding at the source's native width and widening to T in shared memory.

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent dfb9992 commit 09ae129

4 files changed: 387 additions & 133 deletions
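The heart of the change: decode narrow codes into the front of a buffer already sized for the wide output type, then walk backward and fan each value out to its final slot. Below is a standalone CUDA sketch of that widening step, using a global-memory buffer instead of Vortex's plan-managed shared memory; the kernel name and the explicit per-round barrier are illustrative choices, not code from this commit.

// widen_u8_to_u32.cu (illustrative sketch, not code from this commit)
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// The buffer holds `len` u8 codes in its first `len` bytes but is
// allocated at u32 width. Walk from the top: dst[i] occupies bytes
// [4i, 4i+4), all at or beyond src[i]'s byte, so higher elements are
// consumed before their bytes are overwritten. The barrier between
// the read and write phases of each round keeps warps that run ahead
// from clobbering bytes another warp has yet to read.
__global__ void widen_u8_to_u32(uint32_t *buf, uint32_t len) {
  const uint8_t *src = reinterpret_cast<const uint8_t *>(buf);
  for (int32_t base = static_cast<int32_t>(len) - 1; base >= 0;
       base -= blockDim.x) {
    const int32_t i = base - static_cast<int32_t>(threadIdx.x);
    const uint8_t v = (i >= 0) ? src[i] : 0;
    __syncthreads();
    if (i >= 0) buf[i] = static_cast<uint32_t>(v);
    __syncthreads();
  }
}

int main() {
  constexpr uint32_t n = 8;
  uint32_t *buf;
  cudaMallocManaged(&buf, n * sizeof(uint32_t));
  // Pack n u8 codes into the front of the u32-sized buffer.
  auto *bytes = reinterpret_cast<uint8_t *>(buf);
  for (uint32_t i = 0; i < n; ++i) bytes[i] = static_cast<uint8_t>(10 * i);
  widen_u8_to_u32<<<1, 32>>>(buf, n);
  cudaDeviceSynchronize();
  for (uint32_t i = 0; i < n; ++i) printf("%u ", buf[i]);  // 0 10 20 ... 70
  printf("\n");
  cudaFree(buf);
  return 0;
}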


vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 100 additions & 16 deletions
@@ -39,9 +39,16 @@
 //
 // ## Mixed-width support
 //
-// LOAD sources from pending subtrees may have a narrower type than the
-// output (e.g. u8 dict codes in a u32 plan). load_element() widens
-// to T via static_cast — no separate widen kernel or smem intermediate.
+// Dict codes, RunEnd ends, and other child arrays may have a narrower
+// element type than the output T. Two mechanisms handle this:
+//
+//   LOAD       load_element() dispatches on the per-stage PTypeTag to
+//              read at the source's native width and static_cast to T.
+//   BITUNPACK  bitunpack_typed() unpacks at the source's native width,
+//              then widens to T in-place via a backward scan
+//              (widen_inplace). The smem region is pre-allocated at
+//              max(source_width, T) bytes per element by the Rust plan
+//              builder, so the widen never overflows.
 
 #include <assert.h>
 #include <cuda.h>
@@ -203,6 +210,22 @@ scatter_patches_chunk(const GPUPatches &patches, T *__restrict out, uint32_t chu
 // Source ops
 // ═══════════════════════════════════════════════════════════════════════════
 
+/// Widen U-sized elements in shared memory to T-sized, in-place.
+/// Backward scan ensures no unread element is overwritten since
+/// sizeof(T) >= sizeof(U) guarantees the write at index i touches
+/// only bytes beyond those of src[i].
+template <typename T, typename U>
+__device__ inline void widen_inplace(T *dst, uint32_t len) {
+  if constexpr (sizeof(T) <= sizeof(U)) {
+    return;
+  }
+  const U *src = reinterpret_cast<const U *>(dst);
+  for (int32_t i = static_cast<int32_t>(len) - 1 - threadIdx.x; i >= 0; i -= blockDim.x) {
+    dst[i] = static_cast<T>(src[i]);
+  }
+  __syncthreads();
+}
+
 /// FastLanes cooperative unpack — all threads in the block scatter-write
 /// decoded elements into `dst`. Caller must issue __syncthreads() before
 /// any thread reads from `dst`.
@@ -236,6 +259,68 @@ __device__ inline void bitunpack(const T *__restrict packed,
   }
 }
 
+/// Dispatch bitunpack at the source's native element width, then widen
+/// to T in-place so all downstream scalar ops and smem consumers see
+/// T-sized elements. Falls back to the direct `bitunpack<T>` path when
+/// the source ptype already matches T. Issues __syncthreads() before
+/// returning on all paths.
+///
+/// Accepts explicit chunk_start / chunk_len so it works for both input
+/// stages (full decode with chunk_start=0, chunk_len=stage.len) and
+/// the output stage (tiled with varying chunk_start / chunk_len).
+template <typename T>
+__device__ inline void bitunpack_typed(T *__restrict dst,
+                                       const void *__restrict packed,
+                                       uint64_t chunk_start,
+                                       uint32_t chunk_len,
+                                       const struct SourceOp &src,
+                                       PTypeTag source_ptype) {
+  // Fast path: source width matches T — no widening needed.
+  if (ptype_byte_width(source_ptype) == sizeof(T)) {
+    bitunpack<T>(reinterpret_cast<const T *>(packed), dst, chunk_start, chunk_len, src);
+    __syncthreads();
+    return;
+  }
+
+  // Compute total elements written by bitunpack (including alignment
+  // padding) so widen_inplace covers the full scratch region.
+  const uint32_t elem_off = src.params.bitunpack.element_offset;
+  const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
+  const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK;
+  const uint32_t total_elems = n_chunks * FL_CHUNK;
+
+  // Narrow source: unpack at native width, then widen to T.
+  switch (source_ptype) {
+  case PTYPE_U8:
+  case PTYPE_I8: {
+    auto *narrow = reinterpret_cast<uint8_t *>(dst);
+    bitunpack<uint8_t>(reinterpret_cast<const uint8_t *>(packed), narrow, chunk_start, chunk_len, src);
+    __syncthreads();
+    widen_inplace<T, uint8_t>(dst, total_elems);
+    break;
+  }
+  case PTYPE_U16:
+  case PTYPE_I16: {
+    auto *narrow = reinterpret_cast<uint16_t *>(dst);
+    bitunpack<uint16_t>(reinterpret_cast<const uint16_t *>(packed), narrow, chunk_start, chunk_len, src);
+    __syncthreads();
+    widen_inplace<T, uint16_t>(dst, total_elems);
+    break;
+  }
+  case PTYPE_U32:
+  case PTYPE_I32:
+  case PTYPE_F32: {
+    auto *narrow = reinterpret_cast<uint32_t *>(dst);
+    bitunpack<uint32_t>(reinterpret_cast<const uint32_t *>(packed), narrow, chunk_start, chunk_len, src);
+    __syncthreads();
+    widen_inplace<T, uint32_t>(dst, total_elems);
+    break;
+  }
+  default:
+    __builtin_unreachable();
+  }
+}
+
 /// Read N values from a source op into `out`.
 ///
 /// Dispatches on `src.op_code` to handle each encoding:
@@ -354,16 +439,14 @@ __device__ void execute_output_stage(T *__restrict output,
   if (src.op_code == SourceOp::BITUNPACK) {
     chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
     T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
-    bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr),
-                 scratch,
-                 block_start + elem_idx,
-                 chunk_len,
-                 src);
+    bitunpack_typed<T>(scratch,
+                       reinterpret_cast<const void *>(stage.input_ptr),
+                       block_start + elem_idx,
+                       chunk_len,
+                       src,
+                       ptype);
     const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK;
     smem_src = scratch + align;
-    // Write barrier: all threads finished bitunpack (and any
-    // patches), safe to read from scratch.
-    __syncthreads();
   } else {
     chunk_len = block_len;
   }
@@ -438,11 +521,12 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
   const auto &src = stage.source;
 
   if (src.op_code == SourceOp::BITUNPACK) {
-    T *raw_smem = smem_out;
-    bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
-    // Write barrier: cooperative bitunpack finished, safe to read
-    // decoded elements below.
-    __syncthreads();
+    bitunpack_typed<T>(smem_out,
+                       reinterpret_cast<const void *>(stage.input_ptr),
+                       0,
+                       stage.len,
+                       src,
+                       stage.source_ptype);
 
     smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
 
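The scratch-coverage arithmetic in bitunpack_typed is easy to replay on the host. A minimal sketch, assuming FL_CHUNK = 1024 (the FastLanes chunk length); the helper name is illustrative:

// Host-side replay of bitunpack_typed's total_elems computation.
#include <cstdint>

constexpr uint32_t FL_CHUNK = 1024;  // assumed FastLanes chunk length

constexpr uint32_t total_unpacked_elems(uint64_t chunk_start,
                                        uint32_t elem_off,
                                        uint32_t chunk_len) {
  // bitunpack writes whole FastLanes chunks, so the scratch region
  // spans every chunk the requested range touches, including the
  // misalignment at the front of the first chunk.
  const uint32_t dst_off = static_cast<uint32_t>((chunk_start + elem_off) % FL_CHUNK);
  const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK;
  return n_chunks * FL_CHUNK;
}

// A 1024-element tile that starts 100 elements into a chunk straddles
// two chunks, so the widen must cover 2048 slots, not 1024.
static_assert(total_unpacked_elems(100, 0, 1024) == 2048);
static_assert(total_unpacked_elems(0, 0, 1024) == 1024);

This is why bitunpack_typed hands widen_inplace total_elems rather than chunk_len: the widen must also move the alignment padding that bitunpack wrote.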

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 21 additions & 0 deletions
@@ -78,6 +78,27 @@ PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
     return tag;
   }
 }
+
+PTYPE_HOST_DEVICE constexpr uint8_t ptype_byte_width(PTypeTag tag) {
+  switch (tag) {
+  case PTYPE_U8:
+  case PTYPE_I8:
+    return 1;
+  case PTYPE_U16:
+  case PTYPE_I16:
+    return 2;
+  case PTYPE_U32:
+  case PTYPE_I32:
+  case PTYPE_F32:
+    return 4;
+  case PTYPE_U64:
+  case PTYPE_I64:
+  case PTYPE_F64:
+    return 8;
+  default:
+    return 0;
+  }
+}
 #endif
 
 /// Number of threads per CUDA block.
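As a compile-time spot-check of the fast-path predicate in bitunpack_typed (source width equal to sizeof(T) means no widen pass), here is a self-contained mock; the enum layout is illustrative and only the widths mirror the header:

// Standalone mock: enum values illustrative, widths as in the header.
#include <cstdint>

enum PTypeTag {
  PTYPE_U8, PTYPE_I8, PTYPE_U16, PTYPE_I16,
  PTYPE_U32, PTYPE_I32, PTYPE_F32,
  PTYPE_U64, PTYPE_I64, PTYPE_F64,
};

constexpr uint8_t ptype_byte_width(PTypeTag tag) {
  switch (tag) {
  case PTYPE_U8:  case PTYPE_I8:                  return 1;
  case PTYPE_U16: case PTYPE_I16:                 return 2;
  case PTYPE_U32: case PTYPE_I32: case PTYPE_F32: return 4;
  case PTYPE_U64: case PTYPE_I64: case PTYPE_F64: return 8;
  }
  return 0;  // unreachable for valid tags
}

// u8 Dict codes in a u32 plan take the widen path; u32 codes do not.
static_assert(ptype_byte_width(PTYPE_U8) != sizeof(uint32_t));
static_assert(ptype_byte_width(PTYPE_U32) == sizeof(uint32_t));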
