Skip to content

Commit d1dedfc

Browse files
committed
fix: sign-extend narrow loads, zero padding, and harden plan builder
Correctness: - load_element<T>() now dispatches on the raw PTypeTag instead of routing through ptype_to_unsigned(), which was zero-extending signed narrow types (e.g. i8(-1) → 255 instead of → 0xFFFFFFFF). - as_bytes() copies through a zeroed buffer so struct padding holes are deterministically zero, avoiding UB from uninitialised padding. - Reject F16 early in is_dyn_dispatch_compatible() — prevents unreachable panic in ptype_to_tag(). Plan builder hardening: - Replace unwrap_or / or_else fallbacks with map_err + ? in walk_bitpacked, walk_for, walk_zigzag, walk_dict, walk_runend so dtype errors propagate instead of silently using a default. - Extract walk_mixed_width_child() to short-circuit Primitive children (grab buffer directly) vs. deferring encoded children to subtrees. - MaterializedPlan::execute reads output_ptype from the plan header instead of taking it as a redundant parameter. FFI / DRY: - Move BLOCK_SIZE and KERNEL_STATIC_SHARED_BYTES to the C header; Rust consumes them via bindgen instead of duplicating magic numbers. - Remove dead ptype_byte_width() C function; Rust uses tag_to_ptype().byte_width() instead. - Add bidirectional tag_to_ptype() alongside ptype_to_tag(). Tests: - Add sign-extension tests (i8→u32, i16→u32) for load_element. - Fix test_plan_structure to use byte offsets (256*4=1024). Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent ad54f73 commit d1dedfc

File tree

5 files changed

+273
-165
lines changed

5 files changed

+273
-165
lines changed

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 67 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@
5454
#include "dynamic_dispatch.h"
5555
#include "types.cuh"
5656

57-
/// Number of threads per CUDA block (must match the Rust launch config).
58-
constexpr uint32_t BLOCK_SIZE = 64;
5957

6058
// ═══════════════════════════════════════════════════════════════════════════
6159
// Primitives
@@ -70,26 +68,41 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val
7068
}
7169

7270
/// Read one element from global memory at `ptype` width, widen to T.
71+
/// Signed types are sign-extended; unsigned types are zero-extended.
7372
template <typename T>
7473
__device__ inline T load_element(const void *__restrict ptr, PTypeTag ptype, uint64_t idx) {
75-
switch (ptype_to_unsigned(ptype)) {
74+
switch (ptype) {
7675
case PTYPE_U8:
7776
return static_cast<T>(static_cast<const uint8_t *>(ptr)[idx]);
77+
case PTYPE_I8:
78+
return static_cast<T>(static_cast<const int8_t *>(ptr)[idx]);
7879
case PTYPE_U16:
7980
return static_cast<T>(static_cast<const uint16_t *>(ptr)[idx]);
81+
case PTYPE_I16:
82+
return static_cast<T>(static_cast<const int16_t *>(ptr)[idx]);
8083
case PTYPE_U32:
84+
case PTYPE_F32:
8185
return static_cast<T>(static_cast<const uint32_t *>(ptr)[idx]);
86+
case PTYPE_I32:
87+
return static_cast<T>(static_cast<const int32_t *>(ptr)[idx]);
8288
case PTYPE_U64:
89+
case PTYPE_F64:
8390
return static_cast<T>(static_cast<const uint64_t *>(ptr)[idx]);
91+
case PTYPE_I64:
92+
return static_cast<T>(static_cast<const int64_t *>(ptr)[idx]);
8493
default:
8594
__builtin_unreachable();
8695
}
8796
}
8897

89-
/// RUNEND forward-scan cursor for each thread, stored in shared memory so
90-
/// source_op can advance it across calls. Pre-seeded with upper_bound before
91-
/// the tile loop; the RUNEND arm in source_op advances it monotonically.
92-
/// Sized to BLOCK_SIZE to match the kernel block size, one entry per thread.
98+
/// Per-thread run cursor for RUNEND forward-scan, one entry per thread.
99+
///
100+
/// Stored in shared memory so the cursor persists across successive
101+
/// source_op calls in the tile loop. Each thread's positions are
102+
/// monotonically increasing across tiles, so the cursor only advances
103+
/// forward — the next tile picks up exactly where the previous one
104+
/// stopped, avoiding a binary search per tile. The only binary search
105+
/// is the initial upper_bound seed before the tile loop begins.
93106
__shared__ uint64_t runend_cursors[BLOCK_SIZE];
94107

95108
// ═══════════════════════════════════════════════════════════════════════════
@@ -98,37 +111,37 @@ __shared__ uint64_t runend_cursors[BLOCK_SIZE];
98111

99112
/// Apply one scalar operation to N values in registers.
100113
template <typename T, uint32_t N>
101-
__device__ inline void scalar_op(T *v, const struct ScalarOp &op, char *__restrict smem) {
114+
__device__ inline void scalar_op(T *values, const struct ScalarOp &op, char *__restrict smem) {
102115
switch (op.op_code) {
103116
case ScalarOp::FOR: {
104117
const T ref = static_cast<T>(op.params.frame_of_ref.reference);
105118
#pragma unroll
106119
for (uint32_t i = 0; i < N; ++i) {
107-
v[i] += ref;
120+
values[i] += ref;
108121
}
109122
break;
110123
}
111124
case ScalarOp::ZIGZAG: {
112125
#pragma unroll
113126
for (uint32_t i = 0; i < N; ++i) {
114-
v[i] = (v[i] >> 1) ^ static_cast<T>(-(v[i] & 1));
127+
values[i] = (values[i] >> 1) ^ static_cast<T>(-(values[i] & 1));
115128
}
116129
break;
117130
}
118131
case ScalarOp::ALP: {
119132
const float f = op.params.alp.f, e = op.params.alp.e;
120133
#pragma unroll
121134
for (uint32_t i = 0; i < N; ++i) {
122-
float r = static_cast<float>(static_cast<int32_t>(v[i])) * f * e;
123-
v[i] = static_cast<T>(__float_as_uint(r));
135+
float r = static_cast<float>(static_cast<int32_t>(values[i])) * f * e;
136+
values[i] = static_cast<T>(__float_as_uint(r));
124137
}
125138
break;
126139
}
127140
case ScalarOp::DICT: {
128141
const T *dict = reinterpret_cast<const T *>(smem + op.params.dict.values_smem_byte_offset);
129142
#pragma unroll
130143
for (uint32_t i = 0; i < N; ++i) {
131-
v[i] = dict[static_cast<uint32_t>(v[i])];
144+
values[i] = dict[static_cast<uint32_t>(values[i])];
132145
}
133146
break;
134147
}
@@ -182,7 +195,7 @@ __device__ inline void bitunpack(const T *__restrict packed,
182195
/// Position calculation (via THREAD_POS macro):
183196
/// N > 1 (batched): pos = base + j·blockDim.x + threadIdx.x.
184197
/// Caller passes the tile base WITHOUT threadIdx.x.
185-
/// N = 1 (single): base IS the exact position. No stride added.
198+
/// N = 1 (single): base is the exact position. No stride added.
186199
template <typename T, uint32_t N>
187200
__device__ inline void source_op(T *out,
188201
const struct SourceOp &src,
@@ -247,7 +260,8 @@ __device__ inline void source_op(T *out,
247260
// ═══════════════════════════════════════════════════════════════════════════
248261
//
249262
// BITUNPACK tiles at SMEM_TILE_SIZE: cooperative unpack → smem → sync →
250-
// batched read. All other source ops: single pass, no smem scratch.
263+
// batched read. LOAD, SEQUENCE, and RUNEND need no smem scratch and
264+
// process the full block in a single outer iteration, tiled by tile_idx.
251265

252266
/// How many elements to process in this BITUNPACK tile iteration.
253267
/// The first tile may be shorter due to `element_offset` alignment;
@@ -257,7 +271,7 @@ __device__ inline uint32_t bitunpack_tile_len(const Stage &stage, uint32_t block
257271
return min(SMEM_TILE_SIZE - off, block_len - tile_off);
258272
}
259273

260-
/// Process the final (output) stage: decode source → apply scalar ops →
274+
/// Process the final / output stage: decode source → apply scalar ops →
261275
/// streaming-store to global memory. Handles the full block, tiling through
262276
/// smem scratch for BITUNPACK.
263277
template <typename T>
@@ -273,6 +287,10 @@ __device__ void execute_output_stage(T *__restrict output,
273287
const PTypeTag ptype = stage.source_ptype;
274288

275289
if (src.op_code == SourceOp::RUNEND) {
290+
// Seed each thread's cursor with the run containing its first
291+
// strided position. The RUNEND arm in source_op advances the
292+
// cursor monotonically, so this avoids a full binary search on
293+
// every element.
276294
const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
277295
runend_cursors[threadIdx.x] = upper_bound(ends,
278296
src.params.runend.num_runs,
@@ -283,9 +301,10 @@ __device__ void execute_output_stage(T *__restrict output,
283301
uint32_t chunk_len;
284302
const T *smem_src = nullptr;
285303

286-
// BITUNPACK chunks the block into shared-memory-sized tiles, so this
287-
// advances by one tile per iteration. All other source ops process the
288-
// entire block in a single iteration (chunk_len = block_len).
304+
// BITUNPACK uses smem scratch, so the outer loop advances one
305+
// chunk at a time. LOAD, SEQUENCE, and RUNEND need no smem
306+
// scratch, so chunk_len = block_len (single outer iteration);
307+
// tiling happens in the inner tile_idx loop.
289308
if (src.op_code == SourceOp::BITUNPACK) {
290309
chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
291310
T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -323,8 +342,12 @@ __device__ void execute_output_stage(T *__restrict output,
323342

324343
#pragma unroll
325344
for (uint32_t j = 0; j < VALUES_PER_TILE; ++j) {
326-
// __stcs bypasses L1 and invalidates the L2 line after write,
327-
// avoiding eviction of input-side data from the cache.
345+
// st.cs (cache streaming): marks this line for earliest
346+
// eviction in L1 and L2. Output data is written once and
347+
// never read again by this kernel, so keeping it cached
348+
// would only compete with the packed input buffers and
349+
// smem-resident dict/runend data that the next tiles still
350+
// need to read. Evict-first lets those stay resident.
328351
__stcs(&output[tile_start + j * blockDim.x + threadIdx.x], values[j]);
329352
}
330353
}
@@ -338,8 +361,6 @@ __device__ void execute_output_stage(T *__restrict output,
338361
for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) {
339362
scalar_op<T, 1>(&val, stage.scalar_ops[op], smem);
340363
}
341-
// __stcs bypasses L1 and invalidates the L2 line after write,
342-
// avoiding eviction of input-side data from the cache.
343364
__stcs(&output[gpos], val);
344365
}
345366

@@ -359,6 +380,12 @@ __device__ void execute_output_stage(T *__restrict output,
359380
/// Decode one input stage (dict values, run-end endpoints, etc.) into its
360381
/// shared memory region so the output stage can reference it later.
361382
/// Applies any scalar ops in-place before returning.
383+
///
384+
/// Unlike execute_output_stage, this does not tile — the entire stage is
385+
/// decoded in one pass. The output stage needs random access into these
386+
/// smem regions (e.g. DICT gathers by arbitrary code value), so the data
387+
/// must be fully resident. The smem limit check in the Rust plan builder
388+
/// ensures the stage fits; if it doesn't, the plan falls back to Unfused.
362389
template <typename T>
363390
__device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
364391
T *smem_out = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -367,6 +394,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
367394
if (src.op_code == SourceOp::BITUNPACK) {
368395
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
369396
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
397+
// Write barrier: cooperative bitunpack finished, safe to read
398+
// decoded elements in the scalar-op loop below.
370399
__syncthreads();
371400

372401
if (stage.num_scalar_ops > 0) {
@@ -377,10 +406,16 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
377406
}
378407
smem_out[i] = val;
379408
}
409+
// Write barrier: scalar ops applied in-place, smem region is
410+
// now fully populated for subsequent stages to read.
380411
__syncthreads();
381412
}
382413
} else {
383414
if (src.op_code == SourceOp::RUNEND) {
415+
// Seed each thread's cursor with the run containing its first
416+
// strided position. The RUNEND arm in source_op advances the
417+
// cursor monotonically, so this avoids a full binary search on
418+
// every element.
384419
const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
385420
runend_cursors[threadIdx.x] =
386421
upper_bound(ends, src.params.runend.num_runs, threadIdx.x + src.params.runend.offset);
@@ -394,6 +429,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
394429
}
395430
smem_out[i] = val;
396431
}
432+
// Write barrier: smem region is fully populated for subsequent
433+
// stages to read.
397434
__syncthreads();
398435
}
399436
}
@@ -428,6 +465,13 @@ dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__rest
428465
static_cast<uint32_t>(block_end - block_start));
429466
}
430467

468+
// Kernels are instantiated only for unsigned integer types. Signed and
469+
// floating-point arrays reuse the unsigned kernel of the same width —
470+
// the data is bit-identical under reinterpretation, and all arithmetic
471+
// in the pipeline (FoR add, ZigZag decode, ALP decode, DICT gather) is
472+
// correct on the unsigned representation. The one place where signedness
473+
// matters is load_element(), which dispatches on the per-op PTypeTag to
474+
// sign-extend or zero-extend when widening a narrow source to T.
431475
#define GENERATE_KERNEL(suffix, Type) \
432476
extern "C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \
433477
uint64_t array_len, \

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
/// ALP transforms encoded integers (i32) into floats (f32).
2424
///
2525
/// `PTypeTag` is a compact enum that identifies the primitive type at each
26-
/// point in the pipeline. The kernel uses it to dispatch typed memory
26+
/// point in the pipeline. The kernel uses it to dispatch typed memory
2727
/// operations (LOAD, BITUNPACK) and cross-stage references (DICT gather,
2828
/// RUNEND lookup) at the correct element width and signedness.
2929

@@ -34,7 +34,7 @@
3434
/// Compact tag identifying a Vortex PType for GPU dispatch.
3535
///
3636
/// NOTE: These values intentionally skip F16 (which Rust PType includes),
37-
/// so numeric values do NOT match Rust PType directly. The Rust
37+
/// so numeric values do NOT match Rust PType directly. The Rust
3838
/// `ptype_to_tag()` function handles the mapping at plan-build time.
3939
///
4040
/// The kernel uses this to:
@@ -55,35 +55,13 @@ enum PTypeTag : uint8_t {
5555
PTYPE_F64 = 9,
5656
};
5757

58-
/// Return the byte width of a PTypeTag.
58+
/// Return the unsigned equivalent of a PTypeTag (same width).
5959
#ifdef __cplusplus
6060
#ifdef __CUDACC__
6161
#define PTYPE_HOST_DEVICE __host__ __device__
6262
#else
6363
#define PTYPE_HOST_DEVICE
6464
#endif
65-
PTYPE_HOST_DEVICE constexpr uint8_t ptype_byte_width(PTypeTag tag) {
66-
switch (tag) {
67-
case PTYPE_U8:
68-
case PTYPE_I8:
69-
return 1;
70-
case PTYPE_U16:
71-
case PTYPE_I16:
72-
return 2;
73-
case PTYPE_U32:
74-
case PTYPE_I32:
75-
case PTYPE_F32:
76-
return 4;
77-
case PTYPE_U64:
78-
case PTYPE_I64:
79-
case PTYPE_F64:
80-
return 8;
81-
default:
82-
return 0;
83-
}
84-
}
85-
86-
/// Return the unsigned equivalent of a PTypeTag (same width).
8765
PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
8866
switch (tag) {
8967
case PTYPE_I8:
@@ -102,12 +80,20 @@ PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
10280
}
10381
#endif
10482

83+
/// Number of threads per CUDA block.
84+
#define BLOCK_SIZE 64
85+
10586
/// Elements processed per CUDA block.
10687
#define ELEMENTS_PER_BLOCK 2048
10788

10889
/// Each tile is flushed to global before the next is decoded.
10990
#define SMEM_TILE_SIZE 1024
11091

92+
/// Fixed shared memory declared in the kernel (bytes), excluded from
93+
/// the dynamic shared memory budget. Accounts for
94+
/// `runend_cursors[BLOCK_SIZE]` — one uint64_t cursor per thread.
95+
#define KERNEL_FIXED_SHARED_BYTES (BLOCK_SIZE * sizeof(uint64_t))
96+
11197
#ifdef __cplusplus
11298
extern "C" {
11399
#endif
@@ -127,7 +113,7 @@ union SourceParams {
127113

128114
/// Decode run-end encoding using ends and values already in shared memory.
129115
///
130-
/// The smem offsets are **byte offsets** so that ends and values can have
116+
/// The smem offsets are byte offsets so that ends and values can have
131117
/// different element widths.
132118
struct RunEndParams {
133119
uint32_t ends_smem_byte_offset; // byte offset to decoded ends in smem
@@ -151,7 +137,7 @@ struct SourceOp {
151137
/// Scalar ops: element-wise transforms in registers.
152138
///
153139
/// Each scalar op declares its `output_ptype` — the PType of the values it
154-
/// produces. Most ops preserve the input type (FOR, ZIGZAG), but some
140+
/// produces. Most ops preserve the input type (FOR, ZIGZAG), but some
155141
/// change it:
156142
/// - ALP: encoded int → float (e.g. i32 → f32)
157143
/// - DICT: codes type → values type (e.g. u8 → u32)
@@ -171,8 +157,8 @@ union ScalarParams {
171157

172158
/// Dictionary gather: use current value as index into decoded values in smem.
173159
///
174-
/// `values_smem_byte_offset` is a **byte offset** so that values can have
175-
/// a different element width than the codes. The plan builder uses
160+
/// `values_smem_byte_offset` is a byte offset so that values can have
161+
/// a different element width than the codes. The plan builder uses
176162
/// `output_ptype` (on the enclosing ScalarOp) to determine the values'
177163
/// element type.
178164
struct DictParams {
@@ -182,8 +168,8 @@ union ScalarParams {
182168

183169
struct ScalarOp {
184170
enum ScalarOpCode { FOR, ZIGZAG, ALP, DICT } op_code;
185-
/// The PType this op produces. For type-preserving ops (FOR, ZIGZAG)
186-
/// this equals the input PType. For type-changing ops (ALP, DICT) this
171+
/// The PType this op produces. For type-preserving ops (FOR, ZIGZAG)
172+
/// this equals the input PType. For type-changing ops (ALP, DICT) this
187173
/// is the new output PType.
188174
enum PTypeTag output_ptype;
189175
union ScalarParams params;
@@ -192,10 +178,10 @@ struct ScalarOp {
192178
/// Packed stage header, followed by `num_scalar_ops` inline ScalarOps.
193179
///
194180
/// `source_ptype` identifies the PType that the source op (BITUNPACK, LOAD,
195-
/// etc.) produces. This may differ from the output PType when scalar ops
181+
/// etc.) produces. This may differ from the output PType when scalar ops
196182
/// change the type (e.g. DICT transforms u8 codes into u32 values).
197183
///
198-
/// `smem_byte_offset` is a **byte offset** into the dynamic shared memory
184+
/// `smem_byte_offset` is a byte offset into the dynamic shared memory
199185
/// pool so that stages with different element widths can coexist.
200186
struct PackedStage {
201187
uint64_t input_ptr; // global memory pointer to this stage's encoded input
@@ -226,7 +212,7 @@ struct __attribute__((aligned(8))) PlanHeader {
226212
/// shared memory region for the output stage to reference. The output stage
227213
/// decodes the root encoding and writes to global memory.
228214
///
229-
/// `source_ptype` is the PType produced by the source op. Scalar ops may
215+
/// `source_ptype` is the PType produced by the source op. Scalar ops may
230216
/// change the type; the final output PType is given by the last scalar op's
231217
/// `output_ptype` (or `source_ptype` if there are no scalar ops).
232218
struct Stage {

0 commit comments

Comments
 (0)