
Commit f75ca61

fix: sign-extend narrow loads, zero padding, and harden plan builder
Correctness:

- load_element<T>() now dispatches on the raw PTypeTag instead of routing through ptype_to_unsigned(), which was zero-extending signed narrow types (e.g. i8(-1) → 255 instead of → 0xFFFFFFFF).
- as_bytes() copies through a zeroed buffer so struct padding holes are deterministically zero, avoiding UB from uninitialised padding.
- Reject F16 early in is_dyn_dispatch_compatible() — prevents unreachable panic in ptype_to_tag().

Plan builder hardening:

- Replace unwrap_or / or_else fallbacks with map_err + ? in walk_bitpacked, walk_for, walk_zigzag, walk_dict, walk_runend so dtype errors propagate instead of silently using a default.
- Extract walk_mixed_width_child() to short-circuit Primitive children (grab buffer directly) vs. deferring encoded children to subtrees.
- MaterializedPlan::execute reads output_ptype from the plan header instead of taking it as a redundant parameter.

FFI / DRY:

- Move BLOCK_SIZE and KERNEL_STATIC_SHARED_BYTES to the C header; Rust consumes them via bindgen instead of duplicating magic numbers.
- Remove dead ptype_byte_width() C function; Rust uses tag_to_ptype().byte_width() instead.
- Add bidirectional tag_to_ptype() alongside ptype_to_tag().

Tests:

- Add sign-extension tests (i8→u32, i16→u32) for load_element.
- Fix test_plan_structure to use byte offsets (256*4=1024).

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
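The sign-extension bug is easiest to see in isolation. A minimal host-side sketch of the before/after behaviour (the two helper functions and the check harness are illustrative, not code from this commit):

    #include <cstdint>
    #include <cassert>

    // Widening an i8 through its unsigned alias loses the sign: the value is
    // zero-extended, so -1 becomes 255 instead of 0xFFFFFFFF.
    uint32_t widen_via_unsigned(const void *ptr, uint64_t idx) {
        return static_cast<uint32_t>(static_cast<const uint8_t *>(ptr)[idx]);
    }

    // Widening through the signed type first sign-extends, which is what the
    // rest of the unsigned-representation pipeline expects for signed sources.
    uint32_t widen_via_signed(const void *ptr, uint64_t idx) {
        return static_cast<uint32_t>(static_cast<const int8_t *>(ptr)[idx]);
    }

    int main() {
        int8_t v = -1;
        assert(widen_via_unsigned(&v, 0) == 255u);        // old behaviour: wrong
        assert(widen_via_signed(&v, 0) == 0xFFFFFFFFu);   // new behaviour: correct
    }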
1 parent ad54f73 commit f75ca61

5 files changed

Lines changed: 273 additions & 166 deletions


vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 67 additions & 24 deletions
@@ -54,9 +54,6 @@
 #include "dynamic_dispatch.h"
 #include "types.cuh"
 
-/// Number of threads per CUDA block (must match the Rust launch config).
-constexpr uint32_t BLOCK_SIZE = 64;
-
 // ═══════════════════════════════════════════════════════════════════════════
 // Primitives
 // ═══════════════════════════════════════════════════════════════════════════
@@ -70,26 +67,41 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val
 }
 
 /// Read one element from global memory at `ptype` width, widen to T.
+/// Signed types are sign-extended; unsigned types are zero-extended.
 template <typename T>
 __device__ inline T load_element(const void *__restrict ptr, PTypeTag ptype, uint64_t idx) {
-    switch (ptype_to_unsigned(ptype)) {
+    switch (ptype) {
     case PTYPE_U8:
         return static_cast<T>(static_cast<const uint8_t *>(ptr)[idx]);
+    case PTYPE_I8:
+        return static_cast<T>(static_cast<const int8_t *>(ptr)[idx]);
     case PTYPE_U16:
         return static_cast<T>(static_cast<const uint16_t *>(ptr)[idx]);
+    case PTYPE_I16:
+        return static_cast<T>(static_cast<const int16_t *>(ptr)[idx]);
     case PTYPE_U32:
+    case PTYPE_F32:
         return static_cast<T>(static_cast<const uint32_t *>(ptr)[idx]);
+    case PTYPE_I32:
+        return static_cast<T>(static_cast<const int32_t *>(ptr)[idx]);
     case PTYPE_U64:
+    case PTYPE_F64:
         return static_cast<T>(static_cast<const uint64_t *>(ptr)[idx]);
+    case PTYPE_I64:
+        return static_cast<T>(static_cast<const int64_t *>(ptr)[idx]);
    default:
        __builtin_unreachable();
    }
 }
 
-/// RUNEND forward-scan cursor for each thread, stored in shared memory so
-/// source_op can advance it across calls. Pre-seeded with upper_bound before
-/// the tile loop; the RUNEND arm in source_op advances it monotonically.
-/// Sized to BLOCK_SIZE to match the kernel block size, one entry per thread.
+/// Per-thread run cursor for RUNEND forward-scan, one entry per thread.
+///
+/// Stored in shared memory so the cursor persists across successive
+/// source_op calls in the tile loop. Each thread's positions are
+/// monotonically increasing across tiles, so the cursor only advances
+/// forward — the next tile picks up exactly where the previous one
+/// stopped, avoiding a binary search per tile. The only binary search
+/// is the initial upper_bound seed before the tile loop begins.
 __shared__ uint64_t runend_cursors[BLOCK_SIZE];
 
 // ═══════════════════════════════════════════════════════════════════════════
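The cursor strategy documented above can be sketched on the CPU: seed with one binary search, then only scan forward. The helper and harness below are an illustrative analogue, not the device code:

    #include <cstdint>
    #include <cassert>

    // `ends` holds exclusive run endpoints; upper_bound returns the index of
    // the first end greater than `pos`, i.e. the run containing `pos`.
    static uint64_t upper_bound(const uint64_t *ends, uint64_t len, uint64_t pos) {
        uint64_t lo = 0, hi = len;
        while (lo < hi) {
            uint64_t mid = (lo + hi) / 2;
            if (ends[mid] <= pos) lo = mid + 1; else hi = mid;
        }
        return lo;
    }

    int main() {
        const uint64_t ends[] = {3, 7, 12};          // runs: [0,3) [3,7) [7,12)
        uint64_t cursor = upper_bound(ends, 3, 4);   // seed: position 4 lies in run 1
        assert(cursor == 1);
        for (uint64_t pos = 4; pos < 12; ++pos) {    // later positions only move forward
            while (ends[cursor] <= pos) ++cursor;
        }
        assert(cursor == 2);                         // no further binary searches needed
    }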
@@ -98,37 +110,37 @@ __shared__ uint64_t runend_cursors[BLOCK_SIZE];
 
 /// Apply one scalar operation to N values in registers.
 template <typename T, uint32_t N>
-__device__ inline void scalar_op(T *v, const struct ScalarOp &op, char *__restrict smem) {
+__device__ inline void scalar_op(T *values, const struct ScalarOp &op, char *__restrict smem) {
     switch (op.op_code) {
     case ScalarOp::FOR: {
         const T ref = static_cast<T>(op.params.frame_of_ref.reference);
         #pragma unroll
         for (uint32_t i = 0; i < N; ++i) {
-            v[i] += ref;
+            values[i] += ref;
         }
         break;
     }
     case ScalarOp::ZIGZAG: {
         #pragma unroll
         for (uint32_t i = 0; i < N; ++i) {
-            v[i] = (v[i] >> 1) ^ static_cast<T>(-(v[i] & 1));
+            values[i] = (values[i] >> 1) ^ static_cast<T>(-(values[i] & 1));
        }
        break;
    }
    case ScalarOp::ALP: {
        const float f = op.params.alp.f, e = op.params.alp.e;
        #pragma unroll
        for (uint32_t i = 0; i < N; ++i) {
-            float r = static_cast<float>(static_cast<int32_t>(v[i])) * f * e;
-            v[i] = static_cast<T>(__float_as_uint(r));
+            float r = static_cast<float>(static_cast<int32_t>(values[i])) * f * e;
+            values[i] = static_cast<T>(__float_as_uint(r));
        }
        break;
    }
    case ScalarOp::DICT: {
        const T *dict = reinterpret_cast<const T *>(smem + op.params.dict.values_smem_byte_offset);
        #pragma unroll
        for (uint32_t i = 0; i < N; ++i) {
-            v[i] = dict[static_cast<uint32_t>(v[i])];
+            values[i] = dict[static_cast<uint32_t>(values[i])];
        }
        break;
    }
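The ZIGZAG arm above relies on the standard zigzag decode identity applied to the unsigned representation. A tiny host-side check of that identity (the harness is illustrative):

    #include <cstdint>
    #include <cassert>

    // ZigZag decode as in scalar_op: (v >> 1) ^ -(v & 1) on the unsigned value.
    // Encoded 0, 1, 2, 3, 4 map back to 0, -1, 1, -2, 2.
    static uint32_t zigzag_decode(uint32_t v) {
        return (v >> 1) ^ static_cast<uint32_t>(-(v & 1));
    }

    int main() {
        assert(static_cast<int32_t>(zigzag_decode(0)) ==  0);
        assert(static_cast<int32_t>(zigzag_decode(1)) == -1);
        assert(static_cast<int32_t>(zigzag_decode(2)) ==  1);
        assert(static_cast<int32_t>(zigzag_decode(3)) == -2);
        assert(static_cast<int32_t>(zigzag_decode(4)) ==  2);
    }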
@@ -182,7 +194,7 @@ __device__ inline void bitunpack(const T *__restrict packed,
 /// Position calculation (via THREAD_POS macro):
 ///   N > 1 (batched): pos = base + j·blockDim.x + threadIdx.x.
 ///   Caller passes the tile base WITHOUT threadIdx.x.
-///   N = 1 (single): base IS the exact position. No stride added.
+///   N = 1 (single): base is the exact position. No stride added.
 template <typename T, uint32_t N>
 __device__ inline void source_op(T *out,
                                  const struct SourceOp &src,
@@ -247,7 +259,8 @@ __device__ inline void source_op(T *out,
 // ═══════════════════════════════════════════════════════════════════════════
 //
 // BITUNPACK tiles at SMEM_TILE_SIZE: cooperative unpack → smem → sync →
-// batched read. All other source ops: single pass, no smem scratch.
+// batched read. LOAD, SEQUENCE, and RUNEND need no smem scratch and
+// process the full block in a single outer iteration, tiled by tile_idx.
 
 /// How many elements to process in this BITUNPACK tile iteration.
 /// The first tile may be shorter due to `element_offset` alignment;
@@ -257,7 +270,7 @@ __device__ inline uint32_t bitunpack_tile_len(const Stage &stage, uint32_t block
     return min(SMEM_TILE_SIZE - off, block_len - tile_off);
 }
 
-/// Process the final (output) stage: decode source → apply scalar ops →
+/// Process the final / output stage: decode source → apply scalar ops →
 /// streaming-store to global memory. Handles the full block, tiling through
 /// smem scratch for BITUNPACK.
 template <typename T>
@@ -273,6 +286,10 @@ __device__ void execute_output_stage(T *__restrict output,
     const PTypeTag ptype = stage.source_ptype;
 
     if (src.op_code == SourceOp::RUNEND) {
+        // Seed each thread's cursor with the run containing its first
+        // strided position. The RUNEND arm in source_op advances the
+        // cursor monotonically, so this avoids a full binary search on
+        // every element.
         const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
         runend_cursors[threadIdx.x] = upper_bound(ends,
                                                   src.params.runend.num_runs,
@@ -283,9 +300,10 @@ __device__ void execute_output_stage(T *__restrict output,
     uint32_t chunk_len;
     const T *smem_src = nullptr;
 
-    // BITUNPACK chunks the block into shared-memory-sized tiles, so this
-    // advances by one tile per iteration. All other source ops process the
-    // entire block in a single iteration (chunk_len = block_len).
+    // BITUNPACK uses smem scratch, so the outer loop advances one
+    // chunk at a time. LOAD, SEQUENCE, and RUNEND need no smem
+    // scratch, so chunk_len = block_len (single outer iteration);
+    // tiling happens in the inner tile_idx loop.
     if (src.op_code == SourceOp::BITUNPACK) {
         chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
         T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -323,8 +341,12 @@ __device__ void execute_output_stage(T *__restrict output,
 
         #pragma unroll
         for (uint32_t j = 0; j < VALUES_PER_TILE; ++j) {
-            // __stcs bypasses L1 and invalidates the L2 line after write,
-            // avoiding eviction of input-side data from the cache.
+            // st.cs (cache streaming): marks this line for earliest
+            // eviction in L1 and L2. Output data is written once and
+            // never read again by this kernel, so keeping it cached
+            // would only compete with the packed input buffers and
+            // smem-resident dict/runend data that the next tiles still
+            // need to read. Evict-first lets those stay resident.
            __stcs(&output[tile_start + j * blockDim.x + threadIdx.x], values[j]);
        }
    }
@@ -338,8 +360,6 @@ __device__ void execute_output_stage(T *__restrict output,
        for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) {
            scalar_op<T, 1>(&val, stage.scalar_ops[op], smem);
        }
-        // __stcs bypasses L1 and invalidates the L2 line after write,
-        // avoiding eviction of input-side data from the cache.
        __stcs(&output[gpos], val);
    }
 
@@ -359,6 +379,12 @@ __device__ void execute_output_stage(T *__restrict output,
 /// Decode one input stage (dict values, run-end endpoints, etc.) into its
 /// shared memory region so the output stage can reference it later.
 /// Applies any scalar ops in-place before returning.
+///
+/// Unlike execute_output_stage, this does not tile — the entire stage is
+/// decoded in one pass. The output stage needs random access into these
+/// smem regions (e.g. DICT gathers by arbitrary code value), so the data
+/// must be fully resident. The smem limit check in the Rust plan builder
+/// ensures the stage fits; if it doesn't, the plan falls back to Unfused.
 template <typename T>
 __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
     T *smem_out = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -367,6 +393,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
     if (src.op_code == SourceOp::BITUNPACK) {
         bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
         smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
+        // Write barrier: cooperative bitunpack finished, safe to read
+        // decoded elements in the scalar-op loop below.
        __syncthreads();
 
        if (stage.num_scalar_ops > 0) {
@@ -377,10 +405,16 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
                }
                smem_out[i] = val;
            }
+            // Write barrier: scalar ops applied in-place, smem region is
+            // now fully populated for subsequent stages to read.
            __syncthreads();
        }
    } else {
        if (src.op_code == SourceOp::RUNEND) {
+            // Seed each thread's cursor with the run containing its first
+            // strided position. The RUNEND arm in source_op advances the
+            // cursor monotonically, so this avoids a full binary search on
+            // every element.
            const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
            runend_cursors[threadIdx.x] =
                upper_bound(ends, src.params.runend.num_runs, threadIdx.x + src.params.runend.offset);
@@ -394,6 +428,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
            }
            smem_out[i] = val;
        }
+        // Write barrier: smem region is fully populated for subsequent
+        // stages to read.
        __syncthreads();
    }
 }
@@ -428,6 +464,13 @@ dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__rest
                              static_cast<uint32_t>(block_end - block_start));
 }
 
+// Kernels are instantiated only for unsigned integer types. Signed and
+// floating-point arrays reuse the unsigned kernel of the same width —
+// the data is bit-identical under reinterpretation, and all arithmetic
+// in the pipeline (FoR add, ZigZag decode, ALP decode, DICT gather) is
+// correct on the unsigned representation. The one place where signedness
+// matters is load_element(), which dispatches on the per-op PTypeTag to
+// sign-extend or zero-extend when widening a narrow source to T.
 #define GENERATE_KERNEL(suffix, Type) \
     extern "C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \
                                                          uint64_t array_len, \
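Per the new comment, GENERATE_KERNEL is presumably expanded once per unsigned width. The suffix names and the host launch below are assumptions for illustration only; they are not shown in this hunk:

    // Presumed instantiations: one unsigned kernel per element width.
    // Signed/float arrays of the same width reuse these.
    GENERATE_KERNEL(u8,  uint8_t)
    GENERATE_KERNEL(u16, uint16_t)
    GENERATE_KERNEL(u32, uint32_t)
    GENERATE_KERNEL(u64, uint64_t)

    // Hypothetical host-side launch: the grid covers the array in
    // ELEMENTS_PER_BLOCK chunks, each block runs BLOCK_SIZE threads, and
    // dyn_smem is the plan builder's per-block dynamic shared-memory size.
    //   uint32_t blocks = (array_len + ELEMENTS_PER_BLOCK - 1) / ELEMENTS_PER_BLOCK;
    //   dynamic_dispatch_u32<<<blocks, BLOCK_SIZE, dyn_smem>>>(output, array_len, plan_bytes);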

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 20 additions & 34 deletions
@@ -23,7 +23,7 @@
 /// ALP transforms encoded integers (i32) into floats (f32).
 ///
 /// `PTypeTag` is a compact enum that identifies the primitive type at each
-/// point in the pipeline.  The kernel uses it to dispatch typed memory
+/// point in the pipeline. The kernel uses it to dispatch typed memory
 /// operations (LOAD, BITUNPACK) and cross-stage references (DICT gather,
 /// RUNEND lookup) at the correct element width and signedness.
 
@@ -34,7 +34,7 @@
 /// Compact tag identifying a Vortex PType for GPU dispatch.
 ///
 /// NOTE: These values intentionally skip F16 (which Rust PType includes),
-/// so numeric values do NOT match Rust PType directly.  The Rust
+/// so numeric values do NOT match Rust PType directly. The Rust
 /// `ptype_to_tag()` function handles the mapping at plan-build time.
 ///
 /// The kernel uses this to:
@@ -55,35 +55,13 @@ enum PTypeTag : uint8_t {
     PTYPE_F64 = 9,
 };
 
-/// Return the byte width of a PTypeTag.
+/// Return the unsigned equivalent of a PTypeTag (same width).
 #ifdef __cplusplus
 #ifdef __CUDACC__
 #define PTYPE_HOST_DEVICE __host__ __device__
 #else
 #define PTYPE_HOST_DEVICE
 #endif
-PTYPE_HOST_DEVICE constexpr uint8_t ptype_byte_width(PTypeTag tag) {
-    switch (tag) {
-    case PTYPE_U8:
-    case PTYPE_I8:
-        return 1;
-    case PTYPE_U16:
-    case PTYPE_I16:
-        return 2;
-    case PTYPE_U32:
-    case PTYPE_I32:
-    case PTYPE_F32:
-        return 4;
-    case PTYPE_U64:
-    case PTYPE_I64:
-    case PTYPE_F64:
-        return 8;
-    default:
-        return 0;
-    }
-}
-
-/// Return the unsigned equivalent of a PTypeTag (same width).
 PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
     switch (tag) {
     case PTYPE_I8:
@@ -102,12 +80,20 @@ PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
 }
 #endif
 
+/// Number of threads per CUDA block.
+#define BLOCK_SIZE 64
+
 /// Elements processed per CUDA block.
 #define ELEMENTS_PER_BLOCK 2048
 
 /// Each tile is flushed to global before the next is decoded.
 #define SMEM_TILE_SIZE 1024
 
+/// Fixed shared memory declared in the kernel (bytes), excluded from
+/// the dynamic shared memory budget. Accounts for
+/// `runend_cursors[BLOCK_SIZE]` — one uint64_t cursor per thread.
+#define KERNEL_FIXED_SHARED_BYTES (BLOCK_SIZE * sizeof(uint64_t))
+
 #ifdef __cplusplus
 extern "C" {
@@ -127,7 +113,7 @@ union SourceParams {
 
     /// Decode run-end encoding using ends and values already in shared memory.
     ///
-    /// The smem offsets are **byte offsets** so that ends and values can have
+    /// The smem offsets are byte offsets so that ends and values can have
     /// different element widths.
     struct RunEndParams {
         uint32_t ends_smem_byte_offset; // byte offset to decoded ends in smem
@@ -151,7 +137,7 @@ struct SourceOp {
 /// Scalar ops: element-wise transforms in registers.
 ///
 /// Each scalar op declares its `output_ptype` — the PType of the values it
-/// produces.  Most ops preserve the input type (FOR, ZIGZAG), but some
+/// produces. Most ops preserve the input type (FOR, ZIGZAG), but some
 /// change it:
 /// - ALP: encoded int → float (e.g. i32 → f32)
 /// - DICT: codes type → values type (e.g. u8 → u32)
@@ -171,8 +157,8 @@ union ScalarParams {
 
     /// Dictionary gather: use current value as index into decoded values in smem.
     ///
-    /// `values_smem_byte_offset` is a **byte offset** so that values can have
-    /// a different element width than the codes.  The plan builder uses
+    /// `values_smem_byte_offset` is a byte offset so that values can have
+    /// a different element width than the codes. The plan builder uses
     /// `output_ptype` (on the enclosing ScalarOp) to determine the values'
     /// element type.
     struct DictParams {
@@ -182,8 +168,8 @@ union ScalarParams {
 
 struct ScalarOp {
     enum ScalarOpCode { FOR, ZIGZAG, ALP, DICT } op_code;
-    /// The PType this op produces.  For type-preserving ops (FOR, ZIGZAG)
-    /// this equals the input PType.  For type-changing ops (ALP, DICT) this
+    /// The PType this op produces. For type-preserving ops (FOR, ZIGZAG)
+    /// this equals the input PType. For type-changing ops (ALP, DICT) this
     /// is the new output PType.
     enum PTypeTag output_ptype;
     union ScalarParams params;
@@ -192,10 +178,10 @@ struct ScalarOp {
 /// Packed stage header, followed by `num_scalar_ops` inline ScalarOps.
 ///
 /// `source_ptype` identifies the PType that the source op (BITUNPACK, LOAD,
-/// etc.) produces.  This may differ from the output PType when scalar ops
+/// etc.) produces. This may differ from the output PType when scalar ops
 /// change the type (e.g. DICT transforms u8 codes into u32 values).
 ///
-/// `smem_byte_offset` is a **byte offset** into the dynamic shared memory
+/// `smem_byte_offset` is a byte offset into the dynamic shared memory
 /// pool so that stages with different element widths can coexist.
 struct PackedStage {
     uint64_t input_ptr; // global memory pointer to this stage's encoded input
@@ -226,7 +212,7 @@ struct __attribute__((aligned(8))) PlanHeader {
 /// shared memory region for the output stage to reference. The output stage
 /// decodes the root encoding and writes to global memory.
 ///
-/// `source_ptype` is the PType produced by the source op.  Scalar ops may
+/// `source_ptype` is the PType produced by the source op. Scalar ops may
 /// change the type; the final output PType is given by the last scalar op's
 /// `output_ptype` (or `source_ptype` if there are no scalar ops).
 struct Stage {
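The Stage doc comment above pins down how a stage's final output PType is derived. A small sketch of that rule as a helper, not part of the commit; it assumes dynamic_dispatch.h is included and uses only fields the kernel already reads (num_scalar_ops, scalar_ops, source_ptype):

    #include "dynamic_dispatch.h"

    // Final PType of a stage: the last scalar op's output_ptype, or the
    // source op's ptype when there are no scalar ops. Mirrors the doc
    // comment on Stage.
    static inline PTypeTag stage_output_ptype(const Stage &stage) {
        return stage.num_scalar_ops > 0
                   ? stage.scalar_ops[stage.num_scalar_ops - 1].output_ptype
                   : stage.source_ptype;
    }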
