diff --git a/vortex-cuda/kernels/src/dynamic_dispatch.cu b/vortex-cuda/kernels/src/dynamic_dispatch.cu index 5ba884d3de5..29bf26efefb 100644 --- a/vortex-cuda/kernels/src/dynamic_dispatch.cu +++ b/vortex-cuda/kernels/src/dynamic_dispatch.cu @@ -1,15 +1,47 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -/// GPU kernel that decompresses a Vortex encoding tree in a single launch via dynamic dispatch. -/// -/// Stages communicate through shared memory: early input stages populate -/// persistent smem regions (e.g., dictionary values, run-end endpoints) that -/// later stages reference via smem offsets. -/// -/// The final output stage writes directly to global memory instead of back -/// to shared memory. Shared memory is dynamically sized at launch time to -/// fit all intermediate buffers that must coexist simultaneously. +// ═══════════════════════════════════════════════════════════════════════════ +// Dynamic dispatch kernel +// ═══════════════════════════════════════════════════════════════════════════ +// +// Vortex arrays are stored as nested encodings — e.g. ALP(FoR(BitPacked)) +// or Dict(codes=BitPacked, values=FoR(BitPacked)). This kernel walks +// such a tree in a single launch by decomposing it into a linear sequence +// of stages described by a packed plan buffer on the device. +// +// Each block produces ELEMENTS_PER_BLOCK output elements. Input stages +// are fully decoded per block (every block independently decodes the +// complete dict values, run-end endpoints, etc. into its own shared +// memory). +// +// ## Pipeline +// +// Input stages run first: each decodes a dependency (dict values, run-end +// endpoints) into shared memory that the output stage later references via +// byte offsets for DICT gathers and RUNEND binary searches. 
+// +// The output stage then processes the full block through: +// +// source_op → scalar_op (FoR/ZigZag/ALP/DICT) → streaming store +// +// in register batches of VALUES_PER_TILE (8 for u32) per thread. +// +// ## Source ops +// +// BITUNPACK Cooperative FastLanes unpack into smem scratch, sync, +// then batch-read from smem. Tiles at 1024 elements. +// LOAD Read from global memory, widening to T if narrower. +// SEQUENCE Compute base + i * multiplier in registers. +// RUNEND Forward-scan through ends/values arrays that input stages +// decoded into shared memory. Per-thread cursor in +// runend_cursors[] avoids re-searching across tile iterations. +// +// ## Mixed-width support +// +// LOAD sources from pending subtrees may have a narrower type than the +// output (e.g. u8 dict codes in a u32 plan). load_element() widens +// to T via static_cast — no separate widen kernel or smem intermediate. #include #include @@ -22,165 +54,93 @@ #include "dynamic_dispatch.h" #include "types.cuh" -/// Binary search for first element strictly greater than value. +// ═══════════════════════════════════════════════════════════════════════════ +// Primitives +// ═══════════════════════════════════════════════════════════════════════════ + +/// Binary search for the first element in `data[0..len)` strictly greater +/// than `value`. Returns `len` if all elements are ≤ value. template __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t value) { auto it = thrust::upper_bound(thrust::seq, data, data + len, value); return it - data; } -/// Executes a source operation to fill a shared memory region with decoded data. -/// -/// This function handles the first phase of each stage's pipeline. It reads -/// compressed or raw data from global memory and writes decoded elements into -/// the stage's shared memory region. 
-/// -/// @param input Global memory pointer to the stage's encoded input data -/// @param smem_output Shared memory pointer where decoded elements are written -/// @param chunk_start Starting index of the chunk to process (block-relative for output stage) -/// @param chunk_len Number of elements to produce (may be < ELEMENTS_PER_BLOCK for tail blocks) -/// @param source_op Source operation descriptor (BITUNPACK, LOAD, or RUNEND) -/// @param smem_base Base of the entire dynamic shared memory pool, used by RUNEND -/// to resolve offsets to ends/values decoded by earlier stages +/// Read one element from global memory at `ptype` width, widen to T. +/// Signed types are sign-extended; unsigned types are zero-extended. template -__device__ inline void dynamic_source_op(const T *__restrict input, - T *__restrict &smem_output, - uint64_t chunk_start, - uint32_t chunk_len, - const struct SourceOp &source_op, - T *__restrict smem_base) { - constexpr uint32_t T_BITS = sizeof(T) * 8; - - switch (source_op.op_code) { - case SourceOp::BITUNPACK: { - constexpr uint32_t FL_CHUNK_SIZE = 1024; - constexpr uint32_t LANES_PER_FL_BLOCK = FL_CHUNK_SIZE / T_BITS; - const uint32_t bit_width = source_op.params.bitunpack.bit_width; - const uint32_t packed_words_per_fl_block = LANES_PER_FL_BLOCK * bit_width; - - const uint32_t element_offset = source_op.params.bitunpack.element_offset; - const uint32_t smem_within_offset = (chunk_start + element_offset) % FL_CHUNK_SIZE; - const uint64_t first_fl_block = (chunk_start + element_offset) / FL_CHUNK_SIZE; - - // FL blocks must divide evenly. Otherwise, the last unpack would overflow smem. 
- static_assert((ELEMENTS_PER_BLOCK % FL_CHUNK_SIZE) == 0); - - const auto div_ceil = [](auto a, auto b) { - return (a + b - 1) / b; - }; - const uint32_t num_fl_chunks = div_ceil(chunk_len + smem_within_offset, FL_CHUNK_SIZE); - - for (uint32_t chunk_idx = 0; chunk_idx < num_fl_chunks; ++chunk_idx) { - const T *packed_chunk = input + (first_fl_block + chunk_idx) * packed_words_per_fl_block; - T *smem_lane = smem_output + chunk_idx * FL_CHUNK_SIZE; - // Distribute unpacking across threads via lane-wise decomposition. - for (uint32_t lane = threadIdx.x; lane < LANES_PER_FL_BLOCK; lane += blockDim.x) { - bit_unpack_lane(packed_chunk, smem_lane, 0, lane, bit_width); - } - } - smem_output += smem_within_offset; - return; - } - - case SourceOp::LOAD: { - // Copy elements verbatim from global memory into shared memory. - for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) { - smem_output[i] = input[chunk_start + i]; - } - return; - } - - case SourceOp::RUNEND: { - // Ends and values were decoded into shared memory by earlier stages. - const T *ends = &smem_base[source_op.params.runend.ends_smem_offset]; - const T *values = &smem_base[source_op.params.runend.values_smem_offset]; - const uint64_t num_runs = source_op.params.runend.num_runs; - const uint64_t offset = source_op.params.runend.offset; - - // Each thread binary-searches for its first position's run, then - // forward-scans for subsequent positions. Strided positions are - // monotonically increasing per thread, so current_run only advances. 
- uint64_t current_run = upper_bound(ends, num_runs, chunk_start + threadIdx.x + offset); - - for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) { - uint64_t pos = chunk_start + i + offset; - - while (current_run < num_runs && static_cast(ends[current_run]) <= pos) { - current_run++; - } - - smem_output[i] = values[min(current_run, num_runs - 1)]; - } - return; - } - - case SourceOp::SEQUENCE: { - // Generate a linear sequence: value[i] = base + i * multiplier. - // Used for SequenceArray (e.g. monotonic run-end endpoints). - const T base = static_cast(source_op.params.sequence.base); - const T mul = static_cast(source_op.params.sequence.multiplier); - for (uint32_t i = threadIdx.x; i < chunk_len; i += blockDim.x) { - smem_output[i] = base + static_cast(chunk_start + i) * mul; - } - break; - } - +__device__ inline T load_element(const void *__restrict ptr, PTypeTag ptype, uint64_t idx) { + switch (ptype) { + case PTYPE_U8: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_I8: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_U16: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_I16: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_U32: + case PTYPE_F32: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_I32: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_U64: + case PTYPE_F64: + return static_cast(static_cast(ptr)[idx]); + case PTYPE_I64: + return static_cast(static_cast(ptr)[idx]); default: __builtin_unreachable(); } } -/// Applies a single scalar operation to N values in registers. -/// -/// Scalar operations are applied element-wise after the source op fills shared -/// memory. All ops compose fluently in any order: FoR adds a constant, ZigZag -/// decodes signed integers, ALP decodes floats, and DICT gathers from a -/// dictionary in shared memory. +/// Per-thread run cursor for RUNEND forward-scan, one entry per thread. 
/// -/// @param values Array of N values to transform in-place -/// @param op The scalar operation descriptor -/// @param smem_base Base of dynamic shared memory pool (used by DICT to resolve offsets) +/// Stored in shared memory so the cursor persists across successive +/// source_op calls in the tile loop. Each thread's positions are +/// monotonically increasing across tiles, so the cursor only advances +/// forward — the next tile picks up exactly where the previous one +/// stopped, avoiding a binary search per tile. The only binary search +/// is the initial upper_bound seed before the tile loop begins. +__shared__ uint64_t runend_cursors[BLOCK_SIZE]; + +// ═══════════════════════════════════════════════════════════════════════════ +// Scalar ops +// ═══════════════════════════════════════════════════════════════════════════ + +/// Apply one scalar operation to N values in registers. template -__device__ inline void apply_scalar_op(T *values, const struct ScalarOp &op, T *__restrict smem_base) { +__device__ inline void scalar_op(T *values, const struct ScalarOp &op, char *__restrict smem) { switch (op.op_code) { case ScalarOp::FOR: { const T ref = static_cast(op.params.frame_of_ref.reference); - // clang-format off - #pragma unroll - // clang-format on +#pragma unroll for (uint32_t i = 0; i < N; ++i) { values[i] += ref; } break; } case ScalarOp::ZIGZAG: { - // clang-format off - #pragma unroll - // clang-format on +#pragma unroll for (uint32_t i = 0; i < N; ++i) { values[i] = (values[i] >> 1) ^ static_cast(-(values[i] & 1)); } break; } case ScalarOp::ALP: { - const float f = op.params.alp.f; - const float e = op.params.alp.e; - // clang-format off - #pragma unroll - // clang-format on + const float f = op.params.alp.f, e = op.params.alp.e; +#pragma unroll for (uint32_t i = 0; i < N; ++i) { - float result = static_cast(static_cast(values[i])) * f * e; - values[i] = static_cast(__float_as_uint(result)); + float r = static_cast(static_cast(values[i])) * f * e; + 
values[i] = static_cast(__float_as_uint(r)); } break; } case ScalarOp::DICT: { - const T *dict_values = &smem_base[op.params.dict.values_smem_offset]; - // clang-format off - #pragma unroll - // clang-format on + const T *dict = reinterpret_cast(smem + op.params.dict.values_smem_byte_offset); +#pragma unroll for (uint32_t i = 0; i < N; ++i) { - values[i] = dict_values[static_cast(values[i])]; + values[i] = dict[static_cast(values[i])]; } break; } @@ -189,168 +149,333 @@ __device__ inline void apply_scalar_op(T *values, const struct ScalarOp &op, T * } } -/// Store policy for global memory writes. -enum class StorePolicy { - /// Default write-back stores — data stays in L2 cache. - WRITEBACK, - /// Streaming stores (`__stcs` / `st.cs`) — hint L2 to evict early. - /// Use for write-only output data that this kernel will not read again. - /// `__stcs` is a regular synchronous store (not async like `cp.async`), - /// so the existing `__syncthreads()` barrier after each tile is - /// sufficient for ordering. - STREAMING, -}; - -/// Reads values from `smem_input`, applies scalar ops in registers, and -/// writes results to `write_dest` at `write_offset`. -template -__device__ void apply_scalar_ops(const T *__restrict smem_input, - T *__restrict write_dest, - uint64_t write_offset, +// ═══════════════════════════════════════════════════════════════════════════ +// Source ops +// ═══════════════════════════════════════════════════════════════════════════ + +/// FastLanes cooperative unpack — all threads in the block scatter-write +/// decoded elements into `dst`. Caller must issue __syncthreads() before +/// any thread reads from `dst`. 
+template +__device__ inline void bitunpack(const T *__restrict packed, + T *__restrict dst, + uint64_t chunk_start, uint32_t chunk_len, - uint8_t num_scalar_ops, - const struct ScalarOp *scalar_ops, - T *__restrict smem_base) { - constexpr uint32_t VALUES_PER_LOOP = 64 / sizeof(T); - const uint32_t tile_size = blockDim.x * VALUES_PER_LOOP; - const uint32_t num_full_tiles = chunk_len / tile_size; - - // Each thread holds multiple values in registers for instruction-level - // parallelism, hiding pipeline latency between independent operations. - for (uint32_t tile = 0; tile < num_full_tiles; ++tile) { - const uint32_t tile_base = tile * tile_size; - T values[VALUES_PER_LOOP]; - - // clang-format off - #pragma unroll - // clang-format on - for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) { - values[idx] = smem_input[tile_base + idx * blockDim.x + threadIdx.x]; + const struct SourceOp &src) { + constexpr uint32_t T_BITS = sizeof(T) * 8; + constexpr uint32_t FL_CHUNK = 1024; + constexpr uint32_t LANES = FL_CHUNK / T_BITS; + const uint32_t bw = src.params.bitunpack.bit_width; + const uint32_t words_per_block = LANES * bw; + const uint32_t elem_off = src.params.bitunpack.element_offset; + const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK; + const uint64_t first_block = (chunk_start + elem_off) / FL_CHUNK; + + static_assert((ELEMENTS_PER_BLOCK % FL_CHUNK) == 0); + const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK; + + for (uint32_t c = 0; c < n_chunks; ++c) { + const T *src_chunk = packed + (first_block + c) * words_per_block; + T *chunk_dst = dst + c * FL_CHUNK; + for (uint32_t lane = threadIdx.x; lane < LANES; lane += blockDim.x) { + bit_unpack_lane(src_chunk, chunk_dst, 0, lane, bw); } + } +} - for (uint8_t op_idx = 0; op_idx < num_scalar_ops; ++op_idx) { - apply_scalar_op(values, scalar_ops[op_idx], smem_base); +/// Read N values from a source op into `out`. 
+/// +/// Dispatches on `src.op_code` to handle each encoding: +/// BITUNPACK — read from `smem_src` at `smem_base` offset. +/// LOAD — read from `raw_input` via load_element (type-widening). +/// SEQUENCE — compute base + pos × multiplier in registers. +/// RUNEND — forward-scan ends/values in smem using runend_cursors. +/// +/// Position calculation (via THREAD_POS macro): +/// N > 1 (batched): pos = base + j·blockDim.x + threadIdx.x. +/// Caller passes the tile base WITHOUT threadIdx.x. +/// N = 1 (single): base is the exact position. No stride added. +template +__device__ inline void source_op(T *out, + const struct SourceOp &src, + const void *raw_input, + PTypeTag ptype, + const T *smem_src, + uint32_t smem_base, + uint64_t global_base, + char *__restrict smem) { + // Wrapped in a macro, rather than a lambda, to avoid allocating additional GPU registers. +#define THREAD_POS(base, j) ((N == 1) ? (base) : ((base) + (j) * blockDim.x + threadIdx.x)) + + switch (src.op_code) { + case SourceOp::BITUNPACK: { +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + out[j] = smem_src[THREAD_POS(smem_base, j)]; } - - // clang-format off - #pragma unroll - // clang-format on - for (uint32_t idx = 0; idx < VALUES_PER_LOOP; ++idx) { - if constexpr (S == StorePolicy::STREAMING) { - __stcs(&write_dest[write_offset + tile_base + idx * blockDim.x + threadIdx.x], values[idx]); - } else { - write_dest[write_offset + tile_base + idx * blockDim.x + threadIdx.x] = values[idx]; - } + return; + } + case SourceOp::LOAD: { +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + out[j] = load_element(raw_input, ptype, THREAD_POS(global_base, j)); } + return; } - - const uint32_t rem_start = num_full_tiles * tile_size; - for (uint32_t elem_idx = rem_start + threadIdx.x; elem_idx < chunk_len; elem_idx += blockDim.x) { - T val = smem_input[elem_idx]; - for (uint8_t op_idx = 0; op_idx < num_scalar_ops; ++op_idx) { - apply_scalar_op(&val, scalar_ops[op_idx], smem_base); + case 
SourceOp::SEQUENCE: { + const T base = static_cast(src.params.sequence.base); + const T mul = static_cast(src.params.sequence.multiplier); +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + out[j] = base + static_cast(THREAD_POS(global_base, j)) * mul; } - if constexpr (S == StorePolicy::STREAMING) { - __stcs(&write_dest[write_offset + elem_idx], val); - } else { - write_dest[write_offset + elem_idx] = val; + return; + } + case SourceOp::RUNEND: { + const T *ends = reinterpret_cast(smem + src.params.runend.ends_smem_byte_offset); + const T *values = reinterpret_cast(smem + src.params.runend.values_smem_byte_offset); + const uint64_t num_runs = src.params.runend.num_runs; + const uint64_t offset = src.params.runend.offset; + uint64_t &run = runend_cursors[threadIdx.x]; +#pragma unroll + for (uint32_t j = 0; j < N; ++j) { + const uint64_t pos = THREAD_POS(global_base, j) + offset; + while (run < num_runs && static_cast(ends[run]) <= pos) { + run++; + } + out[j] = values[min(run, num_runs - 1)]; } + return; + } + default: + __builtin_unreachable(); } + +#undef THREAD_POS +} + +// ═══════════════════════════════════════════════════════════════════════════ +// Output stage — source_op → scalar_op → streaming store +// ═══════════════════════════════════════════════════════════════════════════ +// +// BITUNPACK tiles at SMEM_TILE_SIZE: cooperative unpack → smem → sync → +// batched read. LOAD, SEQUENCE, and RUNEND need no smem scratch and +// process the full block in a single outer iteration, tiled by tile_idx. + +/// How many elements to process in this BITUNPACK tile iteration. +/// The first tile may be shorter due to `element_offset` alignment; +/// the last tile may be shorter because we've reached `block_len`. +__device__ inline uint32_t bitunpack_tile_len(const Stage &stage, uint32_t block_len, uint32_t tile_off) { + const uint32_t off = (tile_off == 0) ? 
stage.source.params.bitunpack.element_offset : 0; + return min(SMEM_TILE_SIZE - off, block_len - tile_off); } -/// Decodes and transforms a stage's data through shared memory, writing -/// final results to `write_dest` at `write_offset`. Input stages write -/// back to smem; the output stage writes to global memory. -template -__device__ void execute_stage(const Stage &stage, - T *__restrict smem_base, - uint64_t chunk_start, - uint32_t chunk_len, - T *__restrict write_dest, - uint64_t write_offset) { - T *smem_output = &smem_base[stage.smem_offset]; - - dynamic_source_op(reinterpret_cast(stage.input_ptr), - smem_output, - chunk_start, +/// Process the final / output stage: decode source → apply scalar ops → +/// streaming-store to global memory. Handles the full block, tiling through +/// smem scratch for BITUNPACK. +template +__device__ void execute_output_stage(T *__restrict output, + const Stage &stage, + char *__restrict smem, + uint64_t block_start, + uint32_t block_len) { + constexpr uint32_t VALUES_PER_TILE = 32 / sizeof(T); + const uint32_t tile_size = blockDim.x * VALUES_PER_TILE; + const auto &src = stage.source; + const void *raw_input = reinterpret_cast(stage.input_ptr); + const PTypeTag ptype = stage.source_ptype; + + if (src.op_code == SourceOp::RUNEND) { + // Seed each thread's cursor with the run containing its first + // strided position. The RUNEND arm in source_op advances the + // cursor monotonically, so this avoids a full binary search on + // every element. + const T *ends = reinterpret_cast(smem + src.params.runend.ends_smem_byte_offset); + runend_cursors[threadIdx.x] = upper_bound(ends, + src.params.runend.num_runs, + block_start + threadIdx.x + src.params.runend.offset); + } + + for (uint32_t elem_idx = 0; elem_idx < block_len;) { + uint32_t chunk_len; + const T *smem_src = nullptr; + + // BITUNPACK uses smem scratch, so the outer loop advances one + // chunk at a time. 
LOAD, SEQUENCE, and RUNEND need no smem + // scratch, so chunk_len = block_len (single outer iteration); + // tiling happens in the inner tile_idx loop. + if (src.op_code == SourceOp::BITUNPACK) { + chunk_len = bitunpack_tile_len(stage, block_len, elem_idx); + T *scratch = reinterpret_cast(smem + stage.smem_byte_offset); + bitunpack(reinterpret_cast(stage.input_ptr), + scratch, + block_start + elem_idx, chunk_len, - stage.source, - smem_base); - __syncthreads(); - - apply_scalar_ops(smem_output, - write_dest, - write_offset, - chunk_len, - stage.num_scalar_ops, - stage.scalar_ops, - smem_base); - __syncthreads(); + src); + constexpr uint32_t FL_CHUNK = 1024; // FastLanes chunk size + const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK; + smem_src = scratch + align; + // Write barrier: all threads finished bitunpack, safe to read from scratch. + __syncthreads(); + } else { + chunk_len = block_len; + } + + const uint32_t tile_count = chunk_len / tile_size; + for (uint32_t tile_idx = 0; tile_idx < tile_count; ++tile_idx) { + const uint64_t tile_start = block_start + elem_idx + static_cast(tile_idx) * tile_size; + T values[VALUES_PER_TILE]; + + source_op(values, + src, + raw_input, + ptype, + smem_src, + tile_idx * tile_size, + tile_start, + smem); + + for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) { + scalar_op(values, stage.scalar_ops[op], smem); + } + +#pragma unroll + for (uint32_t j = 0; j < VALUES_PER_TILE; ++j) { + // st.cs (cache streaming): marks this line for earliest + // eviction in L1 and L2. Output data is written once and + // never read again by this kernel, so keeping it cached + // would only compete with the packed input buffers and + // smem-resident dict/runend data that the next tiles still + // need to read. Evict-first lets those stay resident. 
+ __stcs(&output[tile_start + j * blockDim.x + threadIdx.x], values[j]); + } + } + + const uint32_t rem = tile_count * tile_size; + for (uint32_t i = rem + threadIdx.x; i < chunk_len; i += blockDim.x) { + const uint64_t gpos = block_start + elem_idx + i; + T val; + source_op(&val, src, raw_input, ptype, smem_src, i, gpos, smem); + + for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) { + scalar_op(&val, stage.scalar_ops[op], smem); + } + __stcs(&output[gpos], val); + } + + if (src.op_code == SourceOp::BITUNPACK) { + // Read barrier: all threads finished reading scratch, safe to + // overwrite it with the next chunk's bitunpack. + __syncthreads(); + } + elem_idx += chunk_len; + } } -/// Computes the number of elements to process in an output tile. +// ═══════════════════════════════════════════════════════════════════════════ +// Input stages — decode dependencies into shared memory for the output stage +// ═══════════════════════════════════════════════════════════════════════════ + +/// Decode one input stage (dict values, run-end endpoints, etc.) into its +/// shared memory region so the output stage can reference it later. +/// Applies any scalar ops in-place before returning. /// -/// Each tile decodes exactly one FL block == SMEM_TILE_SIZE elements into -/// shared memory. In case BITUNPACK is sliced, we need to account for the -/// sub-byte element offset. -__device__ inline uint32_t output_tile_len(const Stage &stage, uint32_t block_len, uint32_t tile_off) { - const uint32_t element_offset = (tile_off == 0 && stage.source.op_code == SourceOp::BITUNPACK) - ? stage.source.params.bitunpack.element_offset - : 0; - return min(SMEM_TILE_SIZE - element_offset, block_len - tile_off); +/// Unlike execute_output_stage, this does not tile — the entire stage is +/// decoded in one pass. The output stage needs random access into these +/// smem regions (e.g. DICT gathers by arbitrary code value), so the data +/// must be fully resident. 
The smem limit check in the Rust plan builder +/// ensures the stage fits; if it doesn't, the plan falls back to Unfused. +template +__device__ void execute_input_stage(const Stage &stage, char *__restrict smem) { + T *smem_out = reinterpret_cast(smem + stage.smem_byte_offset); + const auto &src = stage.source; + + if (src.op_code == SourceOp::BITUNPACK) { + bitunpack(reinterpret_cast(stage.input_ptr), smem_out, 0, stage.len, src); + smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE; + // Write barrier: cooperative bitunpack finished, safe to read + // decoded elements in the scalar-op loop below. + __syncthreads(); + + if (stage.num_scalar_ops > 0) { + for (uint32_t i = threadIdx.x; i < stage.len; i += blockDim.x) { + T val = smem_out[i]; + for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) { + scalar_op(&val, stage.scalar_ops[op], smem); + } + smem_out[i] = val; + } + // Write barrier: scalar ops applied in-place, smem region is + // now fully populated for subsequent stages to read. + __syncthreads(); + } + } else { + if (src.op_code == SourceOp::RUNEND) { + // Seed each thread's cursor with the run containing its first + // strided position. The RUNEND arm in source_op advances the + // cursor monotonically, so this avoids a full binary search on + // every element. + const T *ends = reinterpret_cast(smem + src.params.runend.ends_smem_byte_offset); + runend_cursors[threadIdx.x] = + upper_bound(ends, src.params.runend.num_runs, threadIdx.x + src.params.runend.offset); + } + const void *raw_input = reinterpret_cast(stage.input_ptr); + for (uint32_t i = threadIdx.x; i < stage.len; i += blockDim.x) { + T val; + source_op(&val, src, raw_input, stage.source_ptype, nullptr, 0, i, smem); + for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) { + scalar_op(&val, stage.scalar_ops[op], smem); + } + smem_out[i] = val; + } + // Write barrier: smem region is fully populated for subsequent + // stages to read. 
+ __syncthreads(); + } } -/// Entry point of the dynamic dispatch kernel. -/// -/// 1. Input stages populate shared memory (e.g. dict values, run-end -/// endpoints) for the output stage to reference. -/// 2. The output stage decodes the root encoding and writes to global -/// memory. -/// -/// @param output Output buffer -/// @param array_len Total number of elements to produce -/// @param packed_plan Pointer to the packed plan byte buffer +// ═══════════════════════════════════════════════════════════════════════════ +// Kernel entry +// ═══════════════════════════════════════════════════════════════════════════ + +/// Kernel entry point. Parses the packed plan, runs all input stages to +/// populate shared memory, then runs the output stage to produce results. template __device__ void dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__restrict packed_plan) { + extern __shared__ char smem[]; - extern __shared__ char smem_bytes[]; - T *smem_base = reinterpret_cast(smem_bytes); + const auto *hdr = reinterpret_cast(packed_plan); + const uint8_t *cursor = packed_plan + sizeof(struct PlanHeader); + const uint8_t last = hdr->num_stages - 1; - const auto *header = reinterpret_cast(packed_plan); - const uint8_t *stage_cursor = packed_plan + sizeof(struct PlanHeader); - const uint8_t last = header->num_stages - 1; - - // Input stages: Decode inputs into smem regions. 
- for (uint8_t idx = 0; idx < last; ++idx) { - Stage stage = parse_stage(stage_cursor); - T *smem_output = &smem_base[stage.smem_offset]; - execute_stage(stage, smem_base, 0, stage.len, smem_output, 0); + for (uint8_t i = 0; i < last; ++i) { + Stage input_stage = parse_stage(cursor); + execute_input_stage(input_stage, smem); } - Stage output_stage = parse_stage(stage_cursor); + Stage output_stage = parse_stage(cursor); const uint64_t block_start = static_cast(blockIdx.x) * ELEMENTS_PER_BLOCK; const uint64_t block_end = min(block_start + ELEMENTS_PER_BLOCK, array_len); - const uint32_t block_len = static_cast(block_end - block_start); - - for (uint32_t tile_off = 0; tile_off < block_len;) { - const uint32_t tile_len = output_tile_len(output_stage, block_len, tile_off); - execute_stage(output_stage, - smem_base, - block_start + tile_off, - tile_len, - output, - block_start + tile_off); - tile_off += tile_len; - } + execute_output_stage(output, + output_stage, + smem, + block_start, + static_cast(block_end - block_start)); } -/// Generates a dynamic dispatch kernel entry point for each unsigned integer type. -#define GENERATE_DYNAMIC_DISPATCH_KERNEL(suffix, Type) \ +// Kernels are instantiated only for unsigned integer types. Signed and +// floating-point arrays reuse the unsigned kernel of the same width — +// the data is bit-identical under reinterpretation, and all arithmetic +// in the pipeline (FoR add, ZigZag decode, ALP decode, DICT gather) is +// correct on the unsigned representation. The one place where signedness +// matters is load_element(), which dispatches on the per-op PTypeTag to +// sign-extend or zero-extend when widening a narrow source to T. 
+#define GENERATE_KERNEL(suffix, Type) \ extern "C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \ uint64_t array_len, \ const uint8_t *__restrict packed_plan) { \ dynamic_dispatch(output, array_len, packed_plan); \ } -FOR_EACH_UNSIGNED_INT(GENERATE_DYNAMIC_DISPATCH_KERNEL) +FOR_EACH_UNSIGNED_INT(GENERATE_KERNEL) diff --git a/vortex-cuda/kernels/src/dynamic_dispatch.h b/vortex-cuda/kernels/src/dynamic_dispatch.h index 03aee530b56..95540c51581 100644 --- a/vortex-cuda/kernels/src/dynamic_dispatch.h +++ b/vortex-cuda/kernels/src/dynamic_dispatch.h @@ -15,17 +15,89 @@ /// [PackedStage 0][ScalarOp × N0] /// [PackedStage 1][ScalarOp × N1] /// ... +/// +/// ## Per-op type tracking +/// +/// Each source op and scalar op may produce a different PType than its input. +/// For example, DICT transforms codes (e.g. u8) into values (e.g. f32), and +/// ALP transforms encoded integers (i32) into floats (f32). +/// +/// `PTypeTag` is a compact enum that identifies the primitive type at each +/// point in the pipeline. The kernel uses it to dispatch typed memory +/// operations (LOAD, BITUNPACK) and cross-stage references (DICT gather, +/// RUNEND lookup) at the correct element width and signedness. #pragma once #include +/// Compact tag identifying a Vortex PType for GPU dispatch. +/// +/// NOTE: These values intentionally skip F16 (which Rust PType includes), +/// so numeric values do NOT match Rust PType directly. The Rust +/// `ptype_to_tag()` function handles the mapping at plan-build time. +/// +/// The kernel uses this to: +/// - Select the correct element width for LOAD / BITUNPACK source ops. +/// - Index shared memory at the correct stride for DICT / RUNEND cross-stage +/// references. +/// - Distinguish int vs float for ALP decode. 
+enum PTypeTag : uint8_t { + PTYPE_U8 = 0, + PTYPE_U16 = 1, + PTYPE_U32 = 2, + PTYPE_U64 = 3, + PTYPE_I8 = 4, + PTYPE_I16 = 5, + PTYPE_I32 = 6, + PTYPE_I64 = 7, + PTYPE_F32 = 8, + PTYPE_F64 = 9, +}; + +/// Return the unsigned equivalent of a PTypeTag (same width). +#ifdef __cplusplus +#ifdef __CUDACC__ +#define PTYPE_HOST_DEVICE __host__ __device__ +#else +#define PTYPE_HOST_DEVICE +#endif +PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) { + switch (tag) { + case PTYPE_I8: + return PTYPE_U8; + case PTYPE_I16: + return PTYPE_U16; + case PTYPE_I32: + case PTYPE_F32: + return PTYPE_U32; + case PTYPE_I64: + case PTYPE_F64: + return PTYPE_U64; + default: + return tag; + } +} +#endif + +/// Number of threads per CUDA block. +#define BLOCK_SIZE 64 + /// Elements processed per CUDA block. #define ELEMENTS_PER_BLOCK 2048 /// Each tile is flushed to global before the next is decoded. #define SMEM_TILE_SIZE 1024 +/// Fixed shared memory declared in the kernel (bytes), excluded from +/// the dynamic shared memory budget. Accounts for +/// `runend_cursors[BLOCK_SIZE]` — one uint64_t cursor per thread. +/// +/// Uses a literal (64 * 8 = 512) instead of `BLOCK_SIZE * sizeof(uint64_t)` +/// so that bindgen can export it as a Rust constant (bindgen cannot evaluate +/// expressions involving other macros or sizeof). +#define KERNEL_FIXED_SHARED_BYTES 512 + #ifdef __cplusplus extern "C" { #endif @@ -44,9 +116,12 @@ union SourceParams { } load; /// Decode run-end encoding using ends and values already in shared memory. + /// + /// The smem offsets are byte offsets so that ends and values can have + /// different element widths. 
struct RunEndParams { - uint32_t ends_smem_offset; // element offset to decoded ends in smem - uint32_t values_smem_offset; // element offset to decoded values in smem + uint32_t ends_smem_byte_offset; // byte offset to decoded ends in smem + uint32_t values_smem_byte_offset; // byte offset to decoded values in smem uint64_t num_runs; uint64_t offset; // slice offset into the run-end encoded array } runend; @@ -64,7 +139,16 @@ struct SourceOp { }; /// Scalar ops: element-wise transforms in registers. -/// All ops compose fluently in any order. +/// +/// Each scalar op declares its `output_ptype` — the PType of the values it +/// produces. Most ops preserve the input type (FOR, ZIGZAG), but some +/// change it: +/// - ALP: encoded int → float (e.g. i32 → f32) +/// - DICT: codes type → values type (e.g. u8 → u32) +/// +/// The plan builder uses `output_ptype` to determine the element width +/// for shared memory allocation and to propagate type information +/// through the pipeline. union ScalarParams { struct FoRParams { uint64_t reference; @@ -76,30 +160,48 @@ union ScalarParams { } alp; /// Dictionary gather: use current value as index into decoded values in smem. + /// + /// `values_smem_byte_offset` is a byte offset so that values can have + /// a different element width than the codes. The plan builder uses + /// `output_ptype` (on the enclosing ScalarOp) to determine the values' + /// element type. struct DictParams { - uint32_t values_smem_offset; // element offset to decoded dict values in smem + uint32_t values_smem_byte_offset; // byte offset to decoded dict values in smem } dict; }; struct ScalarOp { enum ScalarOpCode { FOR, ZIGZAG, ALP, DICT } op_code; + /// The PType this op produces. For type-preserving ops (FOR, ZIGZAG) + /// this equals the input PType. For type-changing ops (ALP, DICT) this + /// is the new output PType. 
+ enum PTypeTag output_ptype; union ScalarParams params; }; /// Packed stage header, followed by `num_scalar_ops` inline ScalarOps. +/// +/// `source_ptype` identifies the PType that the source op (BITUNPACK, LOAD, +/// etc.) produces. This may differ from the output PType when scalar ops +/// change the type (e.g. DICT transforms u8 codes into u32 values). +/// +/// `smem_byte_offset` is a byte offset into the dynamic shared memory +/// pool so that stages with different element widths can coexist. struct PackedStage { - uint64_t input_ptr; // global memory pointer to this stage's encoded input - uint32_t smem_offset; // element offset within dynamic shared memory for output - uint32_t len; // number of elements this stage produces + uint64_t input_ptr; // global memory pointer to this stage's encoded input + uint32_t smem_byte_offset; // byte offset within dynamic shared memory for output + uint32_t len; // number of elements this stage produces struct SourceOp source; uint8_t num_scalar_ops; + enum PTypeTag source_ptype; // PType produced by the source op }; /// Header for the packed plan byte buffer. struct __attribute__((aligned(8))) PlanHeader { uint8_t num_stages; - uint16_t plan_size_bytes; // total size of the packed plan including this header + enum PTypeTag output_ptype; // PType of the final output array + uint16_t plan_size_bytes; // total size of the packed plan including this header }; #ifdef __cplusplus @@ -113,13 +215,18 @@ struct __attribute__((aligned(8))) PlanHeader { /// Input stages decode data (e.g. dict values, run-end endpoints) into a /// shared memory region for the output stage to reference. The output stage /// decodes the root encoding and writes to global memory. +/// +/// `source_ptype` is the PType produced by the source op. Scalar ops may +/// change the type; the final output PType is given by the last scalar op's +/// `output_ptype` (or `source_ptype` if there are no scalar ops). 
struct Stage { uint64_t input_ptr; // encoded input in global memory - uint32_t smem_offset; // output offset in shared memory (elements) + uint32_t smem_byte_offset; // byte offset within dynamic shared memory uint32_t len; // elements produced + enum PTypeTag source_ptype; // PType produced by the source op struct SourceOp source; // source decode op uint8_t num_scalar_ops; // number of scalar ops - const struct ScalarOp *scalar_ops; // scalar deoode ops + const struct ScalarOp *scalar_ops; // scalar decode ops }; /// Parse a single stage from the packed plan byte buffer and advance the cursor. @@ -136,8 +243,9 @@ __device__ inline Stage parse_stage(const uint8_t *&cursor) { return Stage { .input_ptr = packed_stage->input_ptr, - .smem_offset = packed_stage->smem_offset, + .smem_byte_offset = packed_stage->smem_byte_offset, .len = packed_stage->len, + .source_ptype = packed_stage->source_ptype, .source = packed_stage->source, .num_scalar_ops = packed_stage->num_scalar_ops, .scalar_ops = ops, diff --git a/vortex-cuda/src/dynamic_dispatch/mod.rs b/vortex-cuda/src/dynamic_dispatch/mod.rs index 4cf2560a766..e6444327683 100644 --- a/vortex-cuda/src/dynamic_dispatch/mod.rs +++ b/vortex-cuda/src/dynamic_dispatch/mod.rs @@ -18,7 +18,6 @@ use std::borrow::Borrow; use std::mem::size_of; -use std::slice::from_raw_parts; use std::sync::Arc; use cudarc::driver::DevicePtr; @@ -49,17 +48,54 @@ pub use plan_builder::MaterializedPlan; include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs")); -/// Reinterpret a `&T` as a byte slice for serialization into the packed plan. -/// -/// # Safety +/// Convert a Rust `PType` to the C `PTypeTag` constant. 
+pub fn ptype_to_tag(ptype: PType) -> PTypeTag { + match ptype { + PType::U8 => PTypeTag_PTYPE_U8, + PType::U16 => PTypeTag_PTYPE_U16, + PType::U32 => PTypeTag_PTYPE_U32, + PType::U64 => PTypeTag_PTYPE_U64, + PType::I8 => PTypeTag_PTYPE_I8, + PType::I16 => PTypeTag_PTYPE_I16, + PType::I32 => PTypeTag_PTYPE_I32, + PType::I64 => PTypeTag_PTYPE_I64, + PType::F16 => unreachable!("F16 is not supported by CUDA dynamic dispatch"), + PType::F32 => PTypeTag_PTYPE_F32, + PType::F64 => PTypeTag_PTYPE_F64, + } +} + +/// Convert a C `PTypeTag` back to a Rust `PType`. +pub fn tag_to_ptype(tag: PTypeTag) -> PType { + match tag { + PTypeTag_PTYPE_U8 => PType::U8, + PTypeTag_PTYPE_U16 => PType::U16, + PTypeTag_PTYPE_U32 => PType::U32, + PTypeTag_PTYPE_U64 => PType::U64, + PTypeTag_PTYPE_I8 => PType::I8, + PTypeTag_PTYPE_I16 => PType::I16, + PTypeTag_PTYPE_I32 => PType::I32, + PTypeTag_PTYPE_I64 => PType::I64, + PTypeTag_PTYPE_F32 => PType::F32, + PTypeTag_PTYPE_F64 => PType::F64, + _ => unreachable!("unknown PTypeTag {tag}"), + } +} + +/// Serialize a `#[repr(C)]` struct to a byte vector for the packed plan. /// -/// The caller must ensure `T` is a `#[repr(C)]` type whose layout is -/// compatible with the C ABI. All the types we serialise (`PlanHeader`, -/// `PackedStage`, `ScalarOp`) are bindgen-generated `#[repr(C)]` structs. -/// Padding bytes may be uninitialised on the Rust side, but the C reader -/// never inspects them, so the values are irrelevant. -fn as_bytes(val: &T) -> &[u8] { - unsafe { from_raw_parts(std::ptr::addr_of!(*val).cast(), size_of::()) } +/// Copies field data into a pre-zeroed buffer so padding holes are +/// deterministically zero, avoiding UB from reading uninitialised bytes. +fn as_bytes(val: &T) -> Vec { + let n = size_of::(); + let mut buf = vec![0u8; n]; + // SAFETY: T is a bindgen-generated #[repr(C)] struct with only + // integer/float/enum fields. We overwrite the zeroed buffer with + // the struct's bytes; padding holes keep their zero value. 
+ unsafe { + std::ptr::copy_nonoverlapping(std::ptr::addr_of!(*val).cast::(), buf.as_mut_ptr(), n); + } + buf } /// A stage used to build a [`CudaDispatchPlan`] on the host side. @@ -71,9 +107,11 @@ pub struct MaterializedStage { /// Device pointer to the input buffer for this stage. pub input_ptr: u64, /// Byte offset into shared memory where this stage's data is stored. - pub smem_offset: u32, + pub smem_byte_offset: u32, /// Number of elements in this stage. pub len: u32, + /// PType tag for the source op's output type. + pub source_ptype: PTypeTag, /// The source operation that produces the initial values (e.g. load, bitunpack, sequence). pub source: SourceOp, /// Chain of element-wise scalar operations applied after the source (e.g. frame-of-reference, zigzag, ALP). @@ -83,15 +121,17 @@ pub struct MaterializedStage { impl MaterializedStage { pub fn new( input_ptr: u64, - smem_offset: u32, + smem_byte_offset: u32, len: u32, + source_ptype: PTypeTag, source: SourceOp, scalar_ops: &[ScalarOp], ) -> Self { Self { input_ptr, - smem_offset, + smem_byte_offset, len, + source_ptype, source, scalar_ops: scalar_ops.to_vec(), } @@ -104,8 +144,9 @@ impl MaterializedStage { #[derive(Clone)] pub struct ParsedStage { pub input_ptr: u64, - pub smem_offset: u32, + pub smem_byte_offset: u32, pub len: u32, + pub source_ptype: PTypeTag, pub source: SourceOp, pub num_scalar_ops: u8, pub scalar_ops: Vec, @@ -134,7 +175,7 @@ impl CudaDispatchPlan { /// # Panics /// /// Panics if `stages` is empty or the serialized plan exceeds 65535 bytes. - pub fn new(stages: I) -> Self + pub fn new(stages: I, output_ptype: PTypeTag) -> Self where I: IntoIterator, I::Item: Borrow, @@ -160,25 +201,25 @@ impl CudaDispatchPlan { let mut buffer = ByteBufferMut::with_capacity_aligned(total_size, Alignment::of::()); - // Write header. 
let header = PlanHeader { num_stages: stages.len() as u8, + output_ptype, plan_size_bytes: total_size as u16, }; - buffer.extend_from_slice(as_bytes(&header)); + buffer.extend_from_slice(&as_bytes(&header)); - // Write each stage header followed by its scalar ops. for stage in &stages { let packed_stage = PackedStage { input_ptr: stage.input_ptr, - smem_offset: stage.smem_offset, + smem_byte_offset: stage.smem_byte_offset, len: stage.len, source: stage.source, num_scalar_ops: stage.scalar_ops.len() as u8, + source_ptype: stage.source_ptype, }; - buffer.extend_from_slice(as_bytes(&packed_stage)); + buffer.extend_from_slice(&as_bytes(&packed_stage)); for op in &stage.scalar_ops { - buffer.extend_from_slice(as_bytes(op)); + buffer.extend_from_slice(&as_bytes(op)); } } @@ -195,8 +236,16 @@ impl CudaDispatchPlan { /// Number of stages in the plan. pub fn num_stages(&self) -> u8 { - let header: PlanHeader = unsafe { *self.buffer.as_ptr().cast() }; - header.num_stages + self.header().num_stages + } + + /// PType of the final output array. + pub fn output_ptype(&self) -> PType { + tag_to_ptype(self.header().output_ptype) + } + + fn header(&self) -> PlanHeader { + unsafe { *self.buffer.as_ptr().cast() } } /// Parse and return a read-only view of the stage at `index`. 
@@ -232,8 +281,9 @@ impl CudaDispatchPlan { ParsedStage { input_ptr: ps.input_ptr, - smem_offset: ps.smem_offset, + smem_byte_offset: ps.smem_byte_offset, len: ps.len, + source_ptype: ps.source_ptype, source: ps.source, num_scalar_ops: ps.num_scalar_ops, scalar_ops, @@ -272,13 +322,13 @@ impl SourceOp { /// /// # Arguments /// - /// * `ends_smem_offset` - smem region holding run-end endpoints - /// * `values_smem_offset` - smem region holding per-run values + /// * `ends_smem_byte_offset` - byte offset to decoded ends in smem + /// * `values_smem_byte_offset` - byte offset to decoded values in smem /// * `num_runs` - number of runs (length of ends/values) /// * `offset` - logical offset for sliced arrays pub fn runend( - ends_smem_offset: u32, - values_smem_offset: u32, + ends_smem_byte_offset: u32, + values_smem_byte_offset: u32, num_runs: u64, offset: u64, ) -> Self { @@ -286,8 +336,8 @@ impl SourceOp { op_code: SourceOp_SourceOpCode_RUNEND, params: SourceParams { runend: SourceParams_RunEndParams { - ends_smem_offset, - values_smem_offset, + ends_smem_byte_offset, + values_smem_byte_offset, num_runs, offset, }, @@ -309,9 +359,10 @@ impl SourceOp { impl ScalarOp { /// Frame-of-reference: add a constant. - pub fn frame_of_ref(reference: u64) -> Self { + pub fn frame_of_ref(reference: u64, output_ptype: PTypeTag) -> Self { Self { op_code: ScalarOp_ScalarOpCode_FOR, + output_ptype, params: ScalarParams { frame_of_ref: ScalarParams_FoRParams { reference }, }, @@ -319,10 +370,11 @@ impl ScalarOp { } /// Zigzag decode. - pub fn zigzag() -> Self { + pub fn zigzag(output_ptype: PTypeTag) -> Self { // SAFETY: Zigzag has no parameters; zeroed union is valid. 
Self { op_code: ScalarOp_ScalarOpCode_ZIGZAG, + output_ptype, params: unsafe { std::mem::zeroed() }, } } @@ -331,6 +383,7 @@ impl ScalarOp { pub fn alp(f: f32, e: f32) -> Self { Self { op_code: ScalarOp_ScalarOpCode_ALP, + output_ptype: PTypeTag_PTYPE_F32, params: ScalarParams { alp: ScalarParams_AlpParams { f, e }, }, @@ -339,23 +392,24 @@ impl ScalarOp { /// Dictionary gather: use current value as index into decoded values /// in shared memory (populated by an earlier input stage). - pub fn dict(values_smem_offset: u32) -> Self { + pub fn dict(values_smem_byte_offset: u32, output_ptype: PTypeTag) -> Self { Self { op_code: ScalarOp_ScalarOpCode_DICT, + output_ptype, params: ScalarParams { - dict: ScalarParams_DictParams { values_smem_offset }, + dict: ScalarParams_DictParams { + values_smem_byte_offset, + }, }, } } } impl MaterializedPlan { - pub fn execute( - self, - output_ptype: PType, - len: usize, - ctx: &mut CudaExecutionCtx, - ) -> VortexResult { + pub fn execute(self, len: usize, ctx: &mut CudaExecutionCtx) -> VortexResult { + let output_ptype = self.dispatch_plan.output_ptype(); + // The CUDA kernels are instantiated for unsigned integer types only; + // map signed/float ptypes to their same-width unsigned counterpart. 
let unsigned_ptype = match output_ptype { PType::U8 | PType::I8 => PType::U8, PType::U16 | PType::I16 => PType::U16, @@ -393,10 +447,10 @@ impl MaterializedPlan { ); let cuda_function = ctx.load_function("dynamic_dispatch", &[T::PTYPE])?; - let num_blocks = u32::try_from(len.div_ceil(2048))?; + let num_blocks = u32::try_from(len.div_ceil(ELEMENTS_PER_BLOCK as usize))?; let config = LaunchConfig { grid_dim: (num_blocks, 1, 1), - block_dim: (64, 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), shared_mem_bytes: self.shared_mem_bytes, }; @@ -449,16 +503,12 @@ mod tests { use vortex::error::VortexResult; use vortex::session::VortexSession; - use super::CudaDispatchPlan; - use super::DispatchPlan; - use super::MaterializedStage; - use super::SMEM_TILE_SIZE; - use super::ScalarOp; - use super::SourceOp; use super::*; + use crate::CanonicalCudaExt; use crate::CudaBufferExt; use crate::CudaDeviceBuffer; use crate::CudaExecutionCtx; + use crate::hybrid_dispatch::try_gpu_dispatch; use crate::session::CudaSession; fn bitpacked_array_u32(bit_width: u8, len: usize) -> BitPackedArray { @@ -501,16 +551,20 @@ mod tests { let scalar_ops: Vec = references .iter() - .map(|&r| ScalarOp::frame_of_ref(r as u64)) + .map(|&r| ScalarOp::frame_of_ref(r as u64, PTypeTag_PTYPE_U32)) .collect(); - let plan = CudaDispatchPlan::new([MaterializedStage::new( - input_ptr, - 0, - len as u32, - SourceOp::bitunpack(bit_width, 0), - &scalar_ops, - )]); + let plan = CudaDispatchPlan::new( + [MaterializedStage::new( + input_ptr, + 0, + len as u32, + PTypeTag_PTYPE_U32, + SourceOp::bitunpack(bit_width, 0), + &scalar_ops, + )], + PTypeTag_PTYPE_U32, + ); assert_eq!(plan.stage(0).num_scalar_ops, 4); let actual = run_dynamic_dispatch_plan(&cuda_ctx, len, &plan, SMEM_TILE_SIZE * 4)?; @@ -521,41 +575,52 @@ mod tests { #[crate::test] fn test_plan_structure() { - // Stage 0: input dict values (BP→FoR) into smem[0..256) - // Stage 1: output codes (BP→FoR→DICT) into smem[256..1280), gather from smem[0] - let plan = 
CudaDispatchPlan::new([ - MaterializedStage::new( - 0xAAAA, - 0, - 256, - SourceOp::bitunpack(4, 0), - &[ScalarOp::frame_of_ref(10)], - ), - MaterializedStage::new( - 0xBBBB, - 256, - 1024, - SourceOp::bitunpack(6, 0), - &[ScalarOp::frame_of_ref(42), ScalarOp::dict(0)], - ), - ]); + // Stage 0: input dict values (BP→FoR), 256 u32 elements → smem bytes [0..1024) + // Stage 1: output codes (BP→FoR→DICT), 1024 elements, gather from smem byte 0 + let values_smem_bytes: u32 = 256 * 4; // 256 u32 elements × 4 bytes + let plan = CudaDispatchPlan::new( + [ + MaterializedStage::new( + 0xAAAA, + 0, + 256, + PTypeTag_PTYPE_U32, + SourceOp::bitunpack(4, 0), + &[ScalarOp::frame_of_ref(10, PTypeTag_PTYPE_U32)], + ), + MaterializedStage::new( + 0xBBBB, + values_smem_bytes, + 1024, + PTypeTag_PTYPE_U32, + SourceOp::bitunpack(6, 0), + &[ + ScalarOp::frame_of_ref(42, PTypeTag_PTYPE_U32), + ScalarOp::dict(0, PTypeTag_PTYPE_U32), + ], + ), + ], + PTypeTag_PTYPE_U32, + ); assert_eq!(plan.num_stages(), 2); // Input stage let s0 = plan.stage(0); - assert_eq!(s0.smem_offset, 0); + assert_eq!(s0.smem_byte_offset, 0); assert_eq!(s0.len, 256); + assert_eq!(s0.source_ptype, PTypeTag_PTYPE_U32); assert_eq!(s0.input_ptr, 0xAAAA); // Output stage let s1 = plan.stage(1); - assert_eq!(s1.smem_offset, 256); + assert_eq!(s1.smem_byte_offset, values_smem_bytes); assert_eq!(s1.len, SMEM_TILE_SIZE); + assert_eq!(s1.source_ptype, PTypeTag_PTYPE_U32); assert_eq!(s1.input_ptr, 0xBBBB); assert_eq!(s1.num_scalar_ops, 2); assert_eq!( - unsafe { s1.scalar_ops[1].params.dict.values_smem_offset }, + unsafe { s1.scalar_ops[1].params.dict.values_smem_byte_offset }, 0 ); } @@ -593,17 +658,21 @@ mod tests { let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; let (input_ptr, _di) = copy_raw_to_device(&cuda_ctx, &data)?; - let plan = CudaDispatchPlan::new([MaterializedStage::new( - input_ptr, - 0, - len as u32, - SourceOp::load(), - &[ - ScalarOp::frame_of_ref(reference as u64), - 
ScalarOp::zigzag(), - ScalarOp::alp(alp_f, alp_e), - ], - )]); + let plan = CudaDispatchPlan::new( + [MaterializedStage::new( + input_ptr, + 0, + len as u32, + PTypeTag_PTYPE_U32, + SourceOp::load(), + &[ + ScalarOp::frame_of_ref(reference as u64, PTypeTag_PTYPE_U32), + ScalarOp::zigzag(PTypeTag_PTYPE_U32), + ScalarOp::alp(alp_f, alp_e), + ], + )], + PTypeTag_PTYPE_U32, + ); let actual = run_dynamic_dispatch_plan(&cuda_ctx, len, &plan, SMEM_TILE_SIZE * 4)?; assert_eq!(actual, expected); @@ -644,10 +713,10 @@ mod tests { launch_builder.arg(&array_len_u64); launch_builder.arg(&plan_ptr); - let num_blocks = u32::try_from(output_len.div_ceil(2048))?; + let num_blocks = u32::try_from(output_len.div_ceil(ELEMENTS_PER_BLOCK as usize))?; let config = LaunchConfig { grid_dim: (num_blocks, 1, 1), - block_dim: (64, 1, 1), + block_dim: (BLOCK_SIZE, 1, 1), shared_mem_bytes, }; unsafe { @@ -949,38 +1018,106 @@ mod tests { } #[crate::test] - fn test_dict_mismatched_ptypes_rejected() -> VortexResult<()> { + async fn test_dict_mixed_width_u8_codes_u32_values() -> VortexResult<()> { let dict_values: Vec = vec![100, 200, 300, 400]; let len = 3000; let codes: Vec = (0..len).map(|i| (i % dict_values.len()) as u8).collect(); - let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable); - let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable); + let codes_prim = PrimitiveArray::new(Buffer::from(codes.clone()), NonNullable); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); + let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; + let array = dict.into_array(); + + // Mixed-width Dict (u8 codes, u32 values): both are Primitive, so + // walk_mixed_width_child grabs the codes buffer directly as a LOAD + // source. No pending subtrees → Fused. 
+ let plan = DispatchPlan::new(&array)?; + assert!( + matches!(plan, DispatchPlan::Fused(..)), + "expected Fused for mixed-width Dict with primitive codes" + ); + + // Execute through the hybrid dispatch path (handles widening). + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected: Vec = codes.iter().map(|&c| dict_values[c as usize]).collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + + Ok(()) + } + + #[crate::test] + async fn test_dict_mixed_width_u16_codes_u32_values() -> VortexResult<()> { + let dict_values: Vec = vec![1000, 2000, 3000, 4000, 5000]; + let len = 2048; + let codes: Vec = (0..len).map(|i| (i % dict_values.len()) as u16).collect(); + + let codes_prim = PrimitiveArray::new(Buffer::from(codes.clone()), NonNullable); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; + let array = dict.into_array(); + + // Mixed-width Dict (u16 codes, u32 values): both are Primitive, so + // walk_mixed_width_child grabs the codes buffer directly as a LOAD + // source. No pending subtrees → Fused. + let plan = DispatchPlan::new(&array)?; + assert!( + matches!(plan, DispatchPlan::Fused(..)), + "expected Fused for mixed-width Dict with primitive codes" + ); + + // Execute through the hybrid dispatch path (handles widening). + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); - // DispatchPlan::new should return Unfused because u8 codes != u32 values in byte width. 
- assert!(matches!( - DispatchPlan::new(&dict.into_array())?, - DispatchPlan::Unfused - )); + let expected: Vec = codes.iter().map(|&c| dict_values[c as usize]).collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); Ok(()) } #[crate::test] - fn test_runend_mismatched_ptypes_rejected() -> VortexResult<()> { + async fn test_runend_mixed_width_u64_ends_u32_values() -> VortexResult<()> { let ends: Vec = vec![1000, 2000, 3000]; - let values: Vec = vec![10, 20, 30]; + let values: Vec = vec![10, 20, 30]; + let len = 3000; let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); let re = RunEnd::new(ends_arr, values_arr); + let array = re.into_array(); + + // Ends (u64) are wider than values (u32), so the kernel would truncate + // ends via load_element. The plan builder rejects this as Unfused. + let plan = DispatchPlan::new(&array)?; + assert!( + matches!(plan, DispatchPlan::Unfused), + "expected Unfused for RunEnd with wider ends" + ); - // DispatchPlan::new should return Unfused because u64 ends != i32 values in byte width. - assert!(matches!( - DispatchPlan::new(&re.into_array())?, - DispatchPlan::Unfused - )); + // Execute through the non-fused dispatch path. 
+ let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected: Vec = (0..len as u64) + .map(|i| { + if i < 1000 { + 10 + } else if i < 2000 { + 20 + } else { + 30 + } + }) + .collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); Ok(()) } @@ -1345,4 +1482,338 @@ mod tests { Ok(()) } + + #[crate::test] + async fn test_for_bitpacked_u8() -> VortexResult<()> { + let bit_width: u8 = 4; + let len = 3000; + let reference = 100u8; + let max_val = (1u64 << bit_width).saturating_sub(1); + let residuals: Vec = (0..len).map(|i| (i as u64 % (max_val + 1)) as u8).collect(); + let expected: Vec = residuals + .iter() + .map(|&r| r.wrapping_add(reference)) + .collect(); + + let primitive = PrimitiveArray::new(Buffer::from(residuals), NonNullable); + let bp = BitPacked::encode(&primitive.into_array(), bit_width).vortex_expect("bitpack u8"); + let for_arr = FoR::try_new( + bp.into_array(), + Scalar::primitive(reference, Nullability::NonNullable), + )?; + let array = for_arr.into_array(); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + Ok(()) + } + + #[crate::test] + async fn test_for_bitpacked_u16() -> VortexResult<()> { + let bit_width: u8 = 10; + let len = 3000; + let reference = 1000u16; + let max_val = (1u64 << bit_width).saturating_sub(1); + let residuals: Vec = (0..len) + .map(|i| (i as u64 % (max_val + 1)) as u16) + .collect(); + let expected: Vec = residuals + .iter() + 
.map(|&r| r.wrapping_add(reference)) + .collect(); + + let primitive = PrimitiveArray::new(Buffer::from(residuals), NonNullable); + let bp = BitPacked::encode(&primitive.into_array(), bit_width).vortex_expect("bitpack u16"); + let for_arr = FoR::try_new( + bp.into_array(), + Scalar::primitive(reference, Nullability::NonNullable), + )?; + let array = for_arr.into_array(); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + Ok(()) + } + + #[crate::test] + async fn test_for_bitpacked_u64() -> VortexResult<()> { + let bit_width: u8 = 20; + let len = 3000; + let reference = 100_000u64; + let max_val = (1u64 << bit_width).saturating_sub(1); + let residuals: Vec = (0..len).map(|i| i as u64 % (max_val + 1)).collect(); + let expected: Vec = residuals + .iter() + .map(|&r| r.wrapping_add(reference)) + .collect(); + + let primitive = PrimitiveArray::new(Buffer::from(residuals), NonNullable); + let bp = BitPacked::encode(&primitive.into_array(), bit_width).vortex_expect("bitpack u64"); + let for_arr = FoR::try_new( + bp.into_array(), + Scalar::primitive(reference, Nullability::NonNullable), + )?; + let array = for_arr.into_array(); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + Ok(()) + } + + #[crate::test] + async fn test_empty_array() -> VortexResult<()> { + let values: Vec = vec![]; + let primitive = 
PrimitiveArray::new(Buffer::from(values), NonNullable); + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&primitive.into_array(), &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + assert_eq!(result.len(), 0); + Ok(()) + } + + #[crate::test] + async fn test_single_element() -> VortexResult<()> { + let values: Vec = vec![42]; + let primitive = PrimitiveArray::new(Buffer::from(values.clone()), NonNullable); + let bp = BitPacked::encode(&primitive.into_array(), 6).vortex_expect("bitpack"); + let for_arr = FoR::try_new( + bp.into_array(), + Scalar::primitive(0u32, Nullability::NonNullable), + )?; + let array = for_arr.into_array(); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected, result); + Ok(()) + } + + #[crate::test] + async fn test_exactly_elements_per_block() -> VortexResult<()> { + // Exactly 2048 elements — one full block, no remainder + let bit_width: u8 = 6; + let len = 2048; + let reference = 1000u32; + let max_val = (1u64 << bit_width).saturating_sub(1); + let residuals: Vec = (0..len) + .map(|i| (i as u64 % (max_val + 1)) as u32) + .collect(); + let expected: Vec = residuals.iter().map(|&r| r + reference).collect(); + + let primitive = PrimitiveArray::new(Buffer::from(residuals), NonNullable); + let bp = BitPacked::encode(&primitive.into_array(), bit_width).vortex_expect("bitpack"); + let for_arr = FoR::try_new( + bp.into_array(), + Scalar::primitive(reference, Nullability::NonNullable), + )?; + let array = for_arr.into_array(); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = 
try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + Ok(()) + } + + #[crate::test] + fn test_f64_rejected() { + // F64 arrays should be rejected by the plan builder, not silently accepted. + let values: Vec = vec![1.0, 2.0, 3.0]; + let primitive = PrimitiveArray::new(Buffer::from(values), NonNullable); + let plan = DispatchPlan::new(&primitive.into_array()) + .expect("DispatchPlan::new should not fail for f64"); + assert!( + matches!(plan, DispatchPlan::Unfused), + "expected F64 to be classified as Unfused" + ); + } + + #[crate::test] + async fn test_runend_u32_ends_u16_values() -> VortexResult<()> { + // RunEnd with u32 ends, u16 values. Output type = u16. + // Ends (u32) differ from output (u16) → rejected as Unfused. + let ends: Vec = vec![500, 1000, 1500, 2000]; + let values: Vec = vec![100, 200, 300, 400]; + let len = 2000; + + let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array(); + let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array(); + let re = RunEnd::new(ends_arr, values_arr); + let array = re.into_array(); + + // Ends (u32) are wider than values (u16), so the kernel would truncate + // ends via load_element. The plan builder rejects this as Unfused. 
+ let plan = DispatchPlan::new(&array)?; + assert!( + matches!(plan, DispatchPlan::Unfused), + "expected Unfused for RunEnd with wider ends" + ); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected: Vec = (0..len as u64) + .map(|i| { + if i < 500 { + 100u16 + } else if i < 1000 { + 200 + } else if i < 1500 { + 300 + } else { + 400 + } + }) + .collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + + Ok(()) + } + + #[crate::test] + async fn test_dict_bitpacked_u8_codes_u32_values() -> VortexResult<()> { + // Dict with BitPacked u8 codes (narrower than u32 output) and u32 values. + // Codes become a pending subtree, values fuse. + let dict_values: Vec = vec![100, 200, 300, 400]; + let len = 2048; + let codes: Vec = (0..len).map(|i| (i % dict_values.len()) as u8).collect(); + + let codes_prim = PrimitiveArray::new(Buffer::from(codes.clone()), NonNullable); + // BitPack the u8 codes at 2 bits (4 values need 2 bits) + let codes_bp = + BitPacked::encode(&codes_prim.into_array(), 2).vortex_expect("bitpack codes"); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); + let dict = DictArray::try_new(codes_bp.into_array(), values_prim.into_array())?; + let array = dict.into_array(); + + let plan = DispatchPlan::new(&array)?; + assert!( + matches!(plan, DispatchPlan::PartiallyFused { .. 
}), + "expected PartiallyFused for mixed-width Dict with BitPacked codes" + ); + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&array, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected: Vec = codes.iter().map(|&c| dict_values[c as usize]).collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + + Ok(()) + } + + #[crate::test] + async fn test_sliced_dict_mixed_width() -> VortexResult<()> { + // Sliced Dict with u8 codes and u32 values — combines PartiallyFused + slice handling. + let dict_values: Vec = vec![100, 200, 300, 400]; + let full_len = 4096; + let codes: Vec = (0..full_len) + .map(|i| (i % dict_values.len()) as u8) + .collect(); + + let codes_prim = PrimitiveArray::new(Buffer::from(codes.clone()), NonNullable); + let values_prim = PrimitiveArray::new(Buffer::from(dict_values.clone()), NonNullable); + let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?; + + // Slice from 1000..3000 + let sliced = dict.into_array().slice(1000..3000)?; + + let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + let canonical = try_gpu_dispatch(&sliced, &mut cuda_ctx).await?; + let result = CanonicalCudaExt::into_host(canonical).await?.into_array(); + + let expected: Vec = codes[1000..3000] + .iter() + .map(|&c| dict_values[c as usize]) + .collect(); + let expected_arr = PrimitiveArray::new(Buffer::from(expected), NonNullable).into_array(); + vortex::array::assert_arrays_eq!(expected_arr, result); + + Ok(()) + } + + /// Verify that `load_element` sign-extends signed narrow types when + /// widening to a wider T. E.g. i8(-1) = 0xFF must become u32(0xFFFFFFFF) + /// (the bit-pattern for i32(-1)), not u32(0x000000FF) = 255. 
+ #[crate::test] + fn test_load_element_sign_extends_i8_to_u32() -> VortexResult<()> { + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + + let i8_values: Vec = vec![-1, -2, -3, 127, -128, 0, 1, 42]; + let len = i8_values.len(); + let device_buf = Arc::new(cuda_ctx.stream().clone_htod(&i8_values).expect("htod")); + let (input_ptr, _) = device_buf.device_ptr(cuda_ctx.stream()); + + // Build a single-stage LOAD plan: source ptype = I8, output ptype = U32. + // The kernel (instantiated as u32) must sign-extend each i8 element. + let plan = CudaDispatchPlan::new( + [MaterializedStage::new( + input_ptr, + 0, + len as u32, + PTypeTag_PTYPE_I8, + SourceOp::load(), + &[], + )], + PTypeTag_PTYPE_U32, + ); + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, len, &plan, SMEM_TILE_SIZE * 4)?; + + // Expected: each i8 sign-extended to i32, then viewed as u32. + let expected: Vec = i8_values.iter().map(|&v| (v as i32) as u32).collect(); + assert_eq!(actual, expected); + + Ok(()) + } + + /// Same as above but for i16 → u32 widening. 
+ #[crate::test] + fn test_load_element_sign_extends_i16_to_u32() -> VortexResult<()> { + let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?; + + let i16_values: Vec = vec![-1, -256, -32768, 32767, 0, 1, -100, 12345]; + let len = i16_values.len(); + let device_buf = Arc::new(cuda_ctx.stream().clone_htod(&i16_values).expect("htod")); + let (input_ptr, _) = device_buf.device_ptr(cuda_ctx.stream()); + + let plan = CudaDispatchPlan::new( + [MaterializedStage::new( + input_ptr, + 0, + len as u32, + PTypeTag_PTYPE_I16, + SourceOp::load(), + &[], + )], + PTypeTag_PTYPE_U32, + ); + + let actual = run_dynamic_dispatch_plan(&cuda_ctx, len, &plan, SMEM_TILE_SIZE * 4)?; + + let expected: Vec = i16_values.iter().map(|&v| (v as i32) as u32).collect(); + assert_eq!(actual, expected); + + Ok(()) + } } diff --git a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs index 1ea06e522ba..4bec033d4ff 100644 --- a/vortex-cuda/src/dynamic_dispatch/plan_builder.rs +++ b/vortex-cuda/src/dynamic_dispatch/plan_builder.rs @@ -34,9 +34,12 @@ use vortex::error::vortex_err; use super::CudaDispatchPlan; use super::MaterializedStage; +use super::PTypeTag; use super::SMEM_TILE_SIZE; use super::ScalarOp; use super::SourceOp; +use super::ptype_to_tag; +use super::tag_to_ptype; use crate::CudaBufferExt; use crate::CudaExecutionCtx; @@ -52,6 +55,12 @@ pub struct MaterializedPlan { /// Checks whether the encoding of an array can be fused into a dynamic-dispatch plan. fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { + // The dynamic dispatch kernel only supports F32 floats (via ALP). + // F16 and F64 have no reinterpret path in the kernel. 
+ if matches!(PType::try_from(array.dtype()), Ok(PType::F16 | PType::F64)) { + return false; + } + let id = array.encoding_id(); if id == ALP::ID { let arr = array.as_::(); @@ -62,25 +71,27 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool { } if id == Dict::ID { let arr = array.as_::(); - // As of now the dict dyn dispatch kernel requires - // codes and values to have the same byte width. - return match ( - PType::try_from(arr.values().dtype()), - PType::try_from(arr.codes().dtype()), - ) { - (Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(), + // Dict codes and values may have different byte widths. + // The kernel handles mixed widths via widening input stages, + // but only when codes are no wider than values (the output type). + // Wider codes would be truncated by load_element(). + let values_ptype = PType::try_from(arr.values().dtype()); + let codes_ptype = PType::try_from(arr.codes().dtype()); + return match (values_ptype, codes_ptype) { + (Ok(vp), Ok(cp)) => cp.byte_width() <= vp.byte_width(), _ => false, }; } if id == RunEnd::ID { let arr = array.as_::(); - // As of now the run-end dyn dispatch kernel requires - // ends and values to have the same byte width. - return match ( - PType::try_from(arr.ends().dtype()), - PType::try_from(arr.values().dtype()), - ) { - (Ok(e), Ok(v)) => e.byte_width() == v.byte_width(), + // RunEnd ends and values may have different byte widths. + // The kernel handles mixed widths via widening input stages, + // but only when ends are no wider than values (the output type). + // Wider ends would be truncated by load_element(). + let ends_ptype = PType::try_from(arr.ends().dtype()); + let values_ptype = PType::try_from(arr.values().dtype()); + return match (ends_ptype, values_ptype) { + (Ok(ep), Ok(vp)) => ep.byte_width() <= vp.byte_width(), _ => false, }; } @@ -98,19 +109,22 @@ struct Stage { /// Index into `FusedPlan::source_buffers`, or `None` /// for sources that don't read from a device buffer. 
source_buffer_index: Option, + /// PType tag for the source op's output type. + source_ptype: PTypeTag, } impl Stage { - fn new(source: SourceOp, source_buffer_index: Option) -> Self { + fn new(source: SourceOp, source_buffer_index: Option, source_ptype: PTypeTag) -> Self { Self { source, scalar_ops: vec![], source_buffer_index, + source_ptype, } } } -type SmemOffset = u32; +type SmemByteOffset = u32; type OutputLen = u32; /// A dispatch plan before device materialization. @@ -146,16 +160,35 @@ pub enum DispatchPlan { /// reference data from the earlier stages), and writing the result to /// global memory. /// +/// # Per-stage PType tracking +/// +/// Each stage carries a `source_ptype` (`PTypeTag`) that identifies the +/// primitive type produced by its source op (LOAD, BITUNPACK, etc.). +/// Scalar ops may change the type (e.g. DICT transforms codes → values, +/// ALP transforms encoded ints → floats); each `ScalarOp` declares its +/// `output_ptype`. The kernel uses these tags to dispatch typed memory +/// operations and cross-stage references at the correct element width. +/// /// # Shared memory allocation /// -/// Total shared memory = (`smem_cursor` + `SMEM_TILE_SIZE`) × `elem_bytes`. +/// Total shared memory = `smem_byte_cursor` + `SMEM_TILE_SIZE` × `output_elem_bytes`. +/// +/// `smem_byte_cursor` is tracked in bytes and covers the preceding +/// fully-decoded stages (dict values, run-end endpoints). Each stage's +/// shared memory footprint is `len × final_ptype_byte_width`, where the +/// final ptype is determined by the last scalar op's `output_ptype` (or +/// `source_ptype` if there are no scalar ops). +/// +/// All shared memory offsets are byte offsets — the C ABI uses byte +/// offsets and per-field `PTypeTag` values so that stages with different +/// element widths can coexist in the same shared memory pool. /// /// This is sufficient because: /// /// - Earlier stages only originate from dict (values) and run-end (ends, -/// values). 
`push_smem_stage` reserves the full auxiliary data length in -/// `smem_cursor`, so each stage's source op has room to decode the complete -/// input. +/// values). `push_smem_stage` reserves the appropriate number of bytes +/// in `smem_byte_cursor`, so each stage's source op has room to decode +/// the complete input. /// /// - The output stage (last) tiles at `SMEM_TILE_SIZE` (1024 elements), /// so its source op never writes more than 1024 elements into the @@ -169,13 +202,16 @@ pub enum DispatchPlan { pub struct FusedPlan { /// Stages in kernel execution order; all but the last decode into /// shared memory, the last decodes into global memory. - stages: Vec<(Stage, SmemOffset, OutputLen)>, - /// Shared memory reserved by the non-output stages. - smem_cursor: SmemOffset, + stages: Vec<(Stage, SmemByteOffset, OutputLen)>, + /// Shared memory reserved by the non-output stages, in bytes. + smem_byte_cursor: SmemByteOffset, /// Source buffers. `None` entries are placeholder slots for pending subtrees, /// filled by [`materialize_with_subtrees`] before device copy. source_buffers: Vec>, - elem_bytes: u32, + /// Bytes per element of the root (output) array. + output_elem_bytes: u32, + /// PType of the root (output) array, as a C ABI tag. + output_ptype: PTypeTag, } impl DispatchPlan { @@ -212,56 +248,64 @@ impl DispatchPlan { } impl FusedPlan { - /// Maximum shared memory per block in bytes (48 KB). + /// Maximum shared memory per block in bytes (48 KB, static + dynamic). /// - /// 48 KB is the default per-block dynamic shared memory limit across - /// all CUDA architectures. Higher limits (up to 227 KB on Hopper) - /// require an explicit opt-in via `cuFuncSetAttribute`. + /// 48 KB is the default per-block shared memory limit across all CUDA + /// architectures. Higher limits (up to 227 KB on Hopper) require an + /// explicit opt-in via `cuFuncSetAttribute`. const MAX_SHARED_MEM_BYTES: u32 = 48 * 1024; + /// Fixed shared memory used by the kernel (bytes). 
+ /// Sourced from the C header via bindgen. + const FIXED_SHARED_MEM_BYTES: u32 = super::KERNEL_FIXED_SHARED_BYTES; + /// Build a plan by walking the encoding tree from root to leaf. /// /// During the walk, incompatible nodes are discovered and recorded in the /// returned `Vec`. fn build(array: &ArrayRef) -> VortexResult<(Self, Vec)> { - let elem_bytes = PType::try_from(array.dtype()) - .map_err(|_| { - vortex_err!( - "dyn dispatch requires primitive dtype, got {:?}", - array.dtype() - ) - })? - .byte_width() as u32; + let output_ptype_rust = PType::try_from(array.dtype()).map_err(|_| { + vortex_err!( + "dyn dispatch requires primitive dtype, got {:?}", + array.dtype() + ) + })?; + if output_ptype_rust == PType::F64 { + vortex_bail!("dynamic dispatch does not support f64 output"); + } + let output_elem_bytes = output_ptype_rust.byte_width() as u32; + let output_ptype = ptype_to_tag(output_ptype_rust); let mut pending_subtrees: Vec = Vec::new(); let mut plan = Self { stages: Vec::new(), - smem_cursor: SmemOffset::from(0u32), + smem_byte_cursor: 0u32, source_buffers: Vec::new(), - elem_bytes, + output_elem_bytes, + output_ptype, }; let len = array.len() as u32; let output = plan.walk(array.clone(), &mut pending_subtrees)?; - plan.stages.push((output, plan.smem_cursor, len)); + plan.stages.push((output, plan.smem_byte_cursor, len)); Ok((plan, pending_subtrees)) } - /// Shared memory bytes needed to launch this plan. - /// - /// `smem_cursor` covers the preceding fully-decoded stages (dict values, - /// run-end ). `SMEM_TILE_SIZE` covers the output stage's scratch region — - /// the output stage processes `ELEMENTS_PER_BLOCK` (2048) elements per - /// block by tiling through this 1024-element window. - fn shared_mem_bytes(&self) -> u32 { - (self.smem_cursor + SMEM_TILE_SIZE) * self.elem_bytes + /// Dynamic shared memory bytes passed to the CUDA launch config. 
+ fn dynamic_shared_mem_bytes(&self) -> u32 { + self.smem_byte_cursor + SMEM_TILE_SIZE * self.output_elem_bytes + } + + /// Total shared memory (fixed + dynamic) for limit checking. + fn total_shared_mem_bytes(&self) -> u32 { + Self::FIXED_SHARED_MEM_BYTES + self.dynamic_shared_mem_bytes() } /// Returns `true` if this plan's shared memory requirement exceeds /// the per-block limit, logging a trace message when it does. fn exceeds_shared_mem_limit(&self) -> bool { - let required = self.shared_mem_bytes(); + let required = self.total_shared_mem_bytes(); if required > Self::MAX_SHARED_MEM_BYTES { trace!( required, @@ -275,7 +319,7 @@ impl FusedPlan { /// Copy source buffers to the device, producing a [`MaterializedPlan`]. pub fn materialize(self, ctx: &CudaExecutionCtx) -> VortexResult { - let shared_mem_bytes = self.shared_mem_bytes(); + let shared_mem_bytes = self.dynamic_shared_mem_bytes(); let mut device_buffers = Vec::new(); let mut device_ptrs: Vec = Vec::new(); @@ -298,14 +342,18 @@ impl FusedPlan { } }; + // Byte offsets are passed directly to the C ABI — the kernel now + // indexes shared memory by byte offset and casts to the correct type + // using source_ptype / output_ptype. let stages: Vec = self .stages .iter() - .map(|(stage, smem_offset, len)| { + .map(|(stage, smem_byte_offset, len)| { MaterializedStage::new( resolve_ptr(stage), - *smem_offset, + *smem_byte_offset, *len, + stage.source_ptype, stage.source, &stage.scalar_ops, ) @@ -313,7 +361,7 @@ impl FusedPlan { .collect(); Ok(MaterializedPlan { - dispatch_plan: CudaDispatchPlan::new(stages), + dispatch_plan: CudaDispatchPlan::new(stages, self.output_ptype), device_buffers, shared_mem_bytes, }) @@ -347,16 +395,20 @@ impl FusedPlan { if !is_dyn_dispatch_compatible(&array) { // Subtree can't be fused — record it as a deferred LOAD source. // Bail if dtype is non-primitive (can't become a LOAD stage). 
- if PType::try_from(array.dtype()).is_err() { - vortex_bail!( + let ptype = PType::try_from(array.dtype()).map_err(|_| { + vortex_err!( "unfusable subtree has non-primitive dtype {:?}, cannot partially fuse", array.dtype() - ); - } + ) + })?; let buf_idx = self.source_buffers.len(); self.source_buffers.push(None); // placeholder, filled at materialize time pending_subtrees.push(array); - return Ok(Stage::new(SourceOp::load(), Some(buf_idx))); + return Ok(Stage::new( + SourceOp::load(), + Some(buf_idx), + ptype_to_tag(ptype), + )); } let id = array.encoding_id(); @@ -413,7 +465,11 @@ impl FusedPlan { let prim = array.as_::(); let buf_index = self.source_buffers.len(); self.source_buffers.push(Some(prim.buffer_handle().clone())); - Ok(Stage::new(SourceOp::load(), Some(buf_index))) + Ok(Stage::new( + SourceOp::load(), + Some(buf_index), + ptype_to_tag(prim.ptype()), + )) } fn walk_bitpacked(&mut self, array: ArrayRef) -> VortexResult { @@ -423,11 +479,15 @@ impl FusedPlan { vortex_bail!("Dynamic dispatch does not support BitPackedArray with patches"); } + let source_ptype = ptype_to_tag(PType::try_from(bp.dtype()).map_err(|_| { + vortex_err!("BitPacked must have primitive dtype, got {:?}", bp.dtype()) + })?); let buf_index = self.source_buffers.len(); self.source_buffers.push(Some(bp.packed().clone())); Ok(Stage::new( SourceOp::bitunpack(bp.bit_width(), bp.offset()), Some(buf_index), + source_ptype, )) } @@ -443,12 +503,18 @@ impl FusedPlan { .pvalue() .ok_or_else(|| vortex_err!("FoR reference scalar is null"))?; let encoded = for_arr.encoded().clone(); + let output_ptype = + ptype_to_tag(PType::try_from(array.dtype()).map_err(|_| { + vortex_err!("FoR must have primitive dtype, got {:?}", array.dtype()) + })?); let mut pipeline = self.walk(encoded, pending_subtrees)?; let ref_u64 = ref_pvalue .reinterpret_cast(ref_pvalue.ptype().to_unsigned()) .cast::()?; - pipeline.scalar_ops.push(ScalarOp::frame_of_ref(ref_u64)); + pipeline + .scalar_ops + 
.push(ScalarOp::frame_of_ref(ref_u64, output_ptype)); Ok(pipeline) } @@ -459,9 +525,12 @@ impl FusedPlan { ) -> VortexResult { let zz = array.as_::(); let encoded = zz.encoded().clone(); + let output_ptype = ptype_to_tag(PType::try_from(array.dtype()).map_err(|_| { + vortex_err!("ZigZag must have primitive dtype, got {:?}", array.dtype()) + })?); let mut pipeline = self.walk(encoded, pending_subtrees)?; - pipeline.scalar_ops.push(ScalarOp::zigzag()); + pipeline.scalar_ops.push(ScalarOp::zigzag(output_ptype)); Ok(pipeline) } @@ -494,6 +563,31 @@ impl FusedPlan { Ok(pipeline) } + /// Handle a child array whose element width differs from the output type. + /// + /// If the child is a `Primitive`, its buffer is grabbed directly as a LOAD + /// source — no separate kernel launch needed, since `load_element()` + /// handles the widening in-kernel. Otherwise, the child is recorded as a + /// pending subtree for separate execution. + fn walk_mixed_width_child( + &mut self, + child: ArrayRef, + pending_subtrees: &mut Vec, + ) -> VortexResult { + let ptype = PType::try_from(child.dtype())?; + if child.encoding_id() == Primitive::ID { + return self.walk_primitive(child); + } + let buf_idx = self.source_buffers.len(); + self.source_buffers.push(None); + pending_subtrees.push(child); + Ok(Stage::new( + SourceOp::load(), + Some(buf_idx), + ptype_to_tag(ptype), + )) + } + fn walk_dict( &mut self, array: ArrayRef, @@ -503,12 +597,35 @@ impl FusedPlan { let values = dict.values().clone(); let codes = dict.codes().clone(); + let values_ptype = PType::try_from(values.dtype())?; + let values_elem_bytes = values_ptype.byte_width() as u32; + let codes_ptype = PType::try_from(codes.dtype())?; + let codes_elem_bytes = codes_ptype.byte_width() as u32; + + // If values have a different width than the output type, they + // can't be fused into the same kernel instantiation. 
Primitives + // are handled directly (just grab the buffer); other encodings + // become pending subtrees executed by a separate kernel. let values_len = values.len() as u32; - let values_spec = self.walk(values, pending_subtrees)?; - let values_smem_offset = self.push_smem_stage(values_spec, values_len); + let values_spec = if values_elem_bytes != self.output_elem_bytes { + self.walk_mixed_width_child(values, pending_subtrees)? + } else { + self.walk(values, pending_subtrees)? + }; + let values_smem_byte_offset = self.push_smem_stage(values_spec, values_len); - let mut pipeline = self.walk(codes, pending_subtrees)?; - pipeline.scalar_ops.push(ScalarOp::dict(values_smem_offset)); + // Same for codes. + let mut pipeline = if codes_elem_bytes != self.output_elem_bytes { + self.walk_mixed_width_child(codes, pending_subtrees)? + } else { + self.walk(codes, pending_subtrees)? + }; + // DICT scalar op: pass byte offset directly (C ABI uses byte offsets). + // output_ptype is the values' ptype — DICT transforms codes → values. 
+ pipeline.scalar_ops.push(ScalarOp::dict( + values_smem_byte_offset, + ptype_to_tag(values_ptype), + )); Ok(pipeline) } @@ -518,6 +635,7 @@ impl FusedPlan { Ok(Stage::new( SourceOp::sequence(seq.base().cast()?, seq.multiplier().cast()?), None, + self.output_ptype, )) } @@ -533,23 +651,66 @@ impl FusedPlan { let num_runs = ends.len() as u32; let num_values = values.len() as u32; - let ends_spec = self.walk(ends, pending_subtrees)?; - let ends_smem = self.push_smem_stage(ends_spec, num_runs); - let values_spec = self.walk(values, pending_subtrees)?; - let values_smem = self.push_smem_stage(values_spec, num_values); + let ends_ptype = PType::try_from(ends.dtype())?; + let ends_elem_bytes = ends_ptype.byte_width() as u32; + let values_ptype = PType::try_from(values.dtype())?; + let values_elem_bytes = values_ptype.byte_width() as u32; + // If ends or values have a different width than the output type, + // they can't be fused into the same kernel instantiation. + // Primitives are handled directly; others become pending subtrees. + let ends_spec = if ends_elem_bytes != self.output_elem_bytes { + self.walk_mixed_width_child(ends, pending_subtrees)? + } else { + self.walk(ends, pending_subtrees)? + }; + let ends_smem_byte_offset = self.push_smem_stage(ends_spec, num_runs); + + let values_spec = if values_elem_bytes != self.output_elem_bytes { + self.walk_mixed_width_child(values, pending_subtrees)? + } else { + self.walk(values, pending_subtrees)? + }; + let values_smem_byte_offset = self.push_smem_stage(values_spec, num_values); + + // Pass byte offsets and PTypeTags directly — the C ABI now uses + // byte offsets and per-field ptype tags for cross-stage references. 
Ok(Stage::new( - SourceOp::runend(ends_smem, values_smem, num_runs as u64, offset), + SourceOp::runend( + ends_smem_byte_offset, + values_smem_byte_offset, + num_runs as u64, + offset, + ), None, + self.output_ptype, )) } /// Add a stage that decodes fully into shared memory before the output - /// stage runs. Returns the shared memory offset where the data starts. + /// stage runs. Returns the shared memory byte offset where the data starts. + /// + /// The smem region is sized at the stage's output ptype width — i.e. + /// the ptype after all scalar ops have run. For stages that go through + /// type-changing scalar ops (e.g. dict values with FoR→ALP), the final + /// smem footprint is `len × final_ptype_byte_width`. If there are no + /// scalar ops, the source_ptype determines the width. fn push_smem_stage(&mut self, spec: Stage, len: u32) -> u32 { - let smem_offset = self.smem_cursor; - self.stages.push((spec, smem_offset, len)); - self.smem_cursor += len; - smem_offset + let smem_byte_offset = self.smem_byte_cursor; + // The kernel's execute_input_stage always writes T-wide elements + // into smem (reinterpret_cast), so we must allocate at least + // output_elem_bytes per element — even if the stage's final ptype + // is narrower. Otherwise the writes overflow into the next region. + let final_ptype = spec + .scalar_ops + .last() + .map(|op| op.output_ptype) + .unwrap_or(spec.source_ptype); + let final_elem_bytes = tag_to_ptype(final_ptype).byte_width() as u32; + let elem_bytes = final_elem_bytes.max(self.output_elem_bytes); + let stage_bytes = len * elem_bytes; + self.stages.push((spec, smem_byte_offset, len)); + self.smem_byte_cursor += stage_bytes; + smem_byte_offset } } diff --git a/vortex-cuda/src/hybrid_dispatch/mod.rs b/vortex-cuda/src/hybrid_dispatch/mod.rs index 61708c8248a..0f7f45eeeb2 100644 --- a/vortex-cuda/src/hybrid_dispatch/mod.rs +++ b/vortex-cuda/src/hybrid_dispatch/mod.rs @@ -25,6 +25,10 @@ //! 
variant are executed first (sequentially, same stream), their device buffers //! become `LOAD` ops in a fused plan via `FusedPlan::materialize_with_subtrees`. //! Each subtree re-enters [`try_gpu_dispatch`] and may itself fuse. +//! When a subtree's ptype differs from the output ptype (e.g. `u8` dict +//! codes in a `u32` Dict), widening from the subtree's native width to `T` +//! happens in-kernel via `load_element()` in the LOAD source op — no +//! separate widen pass is needed. //! //! 3. Fallback — root is not fusable. Delegate to its registered //! `CudaExecute` kernel; its children re-enter [`try_gpu_dispatch`]. @@ -44,7 +48,6 @@ use tracing::trace; use vortex::array::ArrayRef; use vortex::array::Canonical; -use vortex::dtype::PType; use vortex::error::VortexResult; use vortex::error::vortex_err; @@ -71,30 +74,29 @@ pub async fn try_gpu_dispatch( match DispatchPlan::new(array)? { DispatchPlan::Fused(plan) => { - let output_ptype = PType::try_from(array.dtype())?; let materialized = plan.materialize(ctx)?; let num_stages = materialized.dispatch_plan.num_stages(); trace!(encoding = %array.encoding_id(), num_stages, "fully-fused dispatch"); - materialized.execute(output_ptype, array.len(), ctx) + materialized.execute(array.len(), ctx) } DispatchPlan::PartiallyFused { plan, pending_subtrees, } => { - let output_ptype = PType::try_from(array.dtype())?; let mut subtree_buffers = Vec::with_capacity(pending_subtrees.len()); // TODO(0ax1): execute subtrees concurrently using separate CUDA streams. 
for subtree in &pending_subtrees { let canonical = subtree.clone().execute_cuda(ctx).await?; - subtree_buffers.push(canonical.into_primitive().into_data_parts().buffer); + let buffer = canonical.into_primitive().into_data_parts().buffer; + subtree_buffers.push(buffer); } let num_subtrees = subtree_buffers.len(); let materialized = plan.materialize_with_subtrees(subtree_buffers, ctx)?; let num_stages = materialized.dispatch_plan.num_stages(); trace!(encoding = %array.encoding_id(), num_stages, num_subtrees, "partially-fused dispatch"); - materialized.execute(output_ptype, array.len(), ctx) + materialized.execute(array.len(), ctx) } DispatchPlan::Unfused => { // Unfused kernel dispatch fallback.