#include "dynamic_dispatch.h"
#include "types.cuh"
5656
/// Number of threads per CUDA block (must match the Rust launch config).
constexpr uint32_t BLOCK_SIZE = 64;
5957
6058// ═══════════════════════════════════════════════════════════════════════════
6159// Primitives
@@ -70,26 +68,41 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val
7068}
7169
/// Read one element from global memory at `ptype` width, widen to T.
///
/// Signed source types go through the matching intN_t pointer, so the
/// static_cast to T sign-extends; unsigned sources zero-extend. F32/F64
/// share the U32/U64 arms: float payloads are moved bit-exactly as
/// unsigned words of the same width (no value conversion happens here).
///
/// `ptr`   — base of the packed input buffer (read-only).
/// `ptype` — per-op physical type tag selecting the element width/signedness.
/// `idx`   — element index (not byte offset) into the buffer.
template <typename T>
__device__ inline T load_element(const void *__restrict ptr, PTypeTag ptype, uint64_t idx) {
    switch (ptype) {
        case PTYPE_U8:
            return static_cast<T>(static_cast<const uint8_t *>(ptr)[idx]);
        case PTYPE_I8:
            return static_cast<T>(static_cast<const int8_t *>(ptr)[idx]);
        case PTYPE_U16:
            return static_cast<T>(static_cast<const uint16_t *>(ptr)[idx]);
        case PTYPE_I16:
            return static_cast<T>(static_cast<const int16_t *>(ptr)[idx]);
        case PTYPE_U32:
        case PTYPE_F32:
            // f32 payload handled as a raw 32-bit word (bit-exact move).
            return static_cast<T>(static_cast<const uint32_t *>(ptr)[idx]);
        case PTYPE_I32:
            return static_cast<T>(static_cast<const int32_t *>(ptr)[idx]);
        case PTYPE_U64:
        case PTYPE_F64:
            // f64 payload handled as a raw 64-bit word (bit-exact move).
            return static_cast<T>(static_cast<const uint64_t *>(ptr)[idx]);
        case PTYPE_I64:
            return static_cast<T>(static_cast<const int64_t *>(ptr)[idx]);
        default:
            // All PTypeTag values are covered above; tell the compiler so
            // it can drop the fall-through path.
            __builtin_unreachable();
    }
}
8897
/// Per-thread run cursor for RUNEND forward-scan, one entry per thread.
///
/// Stored in shared memory so the cursor persists across successive
/// source_op calls in the tile loop. Each thread's positions are
/// monotonically increasing across tiles, so the cursor only advances
/// forward — the next tile picks up exactly where the previous one
/// stopped, avoiding a binary search per tile. The only binary search
/// is the initial upper_bound seed before the tile loop begins.
/// Sized to BLOCK_SIZE to match the kernel block size.
93106__shared__ uint64_t runend_cursors[BLOCK_SIZE];
94107
95108// ═══════════════════════════════════════════════════════════════════════════
@@ -98,37 +111,37 @@ __shared__ uint64_t runend_cursors[BLOCK_SIZE];
98111
99112// / Apply one scalar operation to N values in registers.
100113template <typename T, uint32_t N>
101- __device__ inline void scalar_op (T *v , const struct ScalarOp &op, char *__restrict smem) {
114+ __device__ inline void scalar_op (T *values , const struct ScalarOp &op, char *__restrict smem) {
102115 switch (op.op_code ) {
103116 case ScalarOp::FOR: {
104117 const T ref = static_cast <T>(op.params .frame_of_ref .reference );
105118#pragma unroll
106119 for (uint32_t i = 0 ; i < N; ++i) {
107- v [i] += ref;
120+ values [i] += ref;
108121 }
109122 break ;
110123 }
111124 case ScalarOp::ZIGZAG: {
112125#pragma unroll
113126 for (uint32_t i = 0 ; i < N; ++i) {
114- v [i] = (v [i] >> 1 ) ^ static_cast <T>(-(v [i] & 1 ));
127+ values [i] = (values [i] >> 1 ) ^ static_cast <T>(-(values [i] & 1 ));
115128 }
116129 break ;
117130 }
118131 case ScalarOp::ALP: {
119132 const float f = op.params .alp .f , e = op.params .alp .e ;
120133#pragma unroll
121134 for (uint32_t i = 0 ; i < N; ++i) {
122- float r = static_cast <float >(static_cast <int32_t >(v [i])) * f * e;
123- v [i] = static_cast <T>(__float_as_uint (r));
135+ float r = static_cast <float >(static_cast <int32_t >(values [i])) * f * e;
136+ values [i] = static_cast <T>(__float_as_uint (r));
124137 }
125138 break ;
126139 }
127140 case ScalarOp::DICT: {
128141 const T *dict = reinterpret_cast <const T *>(smem + op.params .dict .values_smem_byte_offset );
129142#pragma unroll
130143 for (uint32_t i = 0 ; i < N; ++i) {
131- v [i] = dict[static_cast <uint32_t >(v [i])];
144+ values [i] = dict[static_cast <uint32_t >(values [i])];
132145 }
133146 break ;
134147 }
@@ -182,7 +195,7 @@ __device__ inline void bitunpack(const T *__restrict packed,
182195// / Position calculation (via THREAD_POS macro):
183196// / N > 1 (batched): pos = base + j·blockDim.x + threadIdx.x.
184197// / Caller passes the tile base WITHOUT threadIdx.x.
185- // / N = 1 (single): base IS the exact position. No stride added.
198+ // / N = 1 (single): base is the exact position. No stride added.
186199template <typename T, uint32_t N>
187200__device__ inline void source_op (T *out,
188201 const struct SourceOp &src,
@@ -247,7 +260,8 @@ __device__ inline void source_op(T *out,
247260// ═══════════════════════════════════════════════════════════════════════════
248261//
249262// BITUNPACK tiles at SMEM_TILE_SIZE: cooperative unpack → smem → sync →
250- // batched read. All other source ops: single pass, no smem scratch.
263+ // batched read. LOAD, SEQUENCE, and RUNEND need no smem scratch and
264+ // process the full block in a single outer iteration, tiled by tile_idx.
251265
252266// / How many elements to process in this BITUNPACK tile iteration.
253267// / The first tile may be shorter due to `element_offset` alignment;
@@ -257,7 +271,7 @@ __device__ inline uint32_t bitunpack_tile_len(const Stage &stage, uint32_t block
257271 return min (SMEM_TILE_SIZE - off, block_len - tile_off);
258272}
259273
260- // / Process the final ( output) stage: decode source → apply scalar ops →
274+ // / Process the final / output stage: decode source → apply scalar ops →
261275// / streaming-store to global memory. Handles the full block, tiling through
262276// / smem scratch for BITUNPACK.
263277template <typename T>
@@ -273,6 +287,10 @@ __device__ void execute_output_stage(T *__restrict output,
273287 const PTypeTag ptype = stage.source_ptype ;
274288
275289 if (src.op_code == SourceOp::RUNEND) {
290+ // Seed each thread's cursor with the run containing its first
291+ // strided position. The RUNEND arm in source_op advances the
292+ // cursor monotonically, so this avoids a full binary search on
293+ // every element.
276294 const T *ends = reinterpret_cast <const T *>(smem + src.params .runend .ends_smem_byte_offset );
277295 runend_cursors[threadIdx .x ] = upper_bound (ends,
278296 src.params .runend .num_runs ,
@@ -283,9 +301,10 @@ __device__ void execute_output_stage(T *__restrict output,
283301 uint32_t chunk_len;
284302 const T *smem_src = nullptr ;
285303
286- // BITUNPACK chunks the block into shared-memory-sized tiles, so this
287- // advances by one tile per iteration. All other source ops process the
288- // entire block in a single iteration (chunk_len = block_len).
304+ // BITUNPACK uses smem scratch, so the outer loop advances one
305+ // chunk at a time. LOAD, SEQUENCE, and RUNEND need no smem
306+ // scratch, so chunk_len = block_len (single outer iteration);
307+ // tiling happens in the inner tile_idx loop.
289308 if (src.op_code == SourceOp::BITUNPACK) {
290309 chunk_len = bitunpack_tile_len (stage, block_len, elem_idx);
291310 T *scratch = reinterpret_cast <T *>(smem + stage.smem_byte_offset );
@@ -323,8 +342,12 @@ __device__ void execute_output_stage(T *__restrict output,
323342
324343#pragma unroll
325344 for (uint32_t j = 0 ; j < VALUES_PER_TILE; ++j) {
326- // __stcs bypasses L1 and invalidates the L2 line after write,
327- // avoiding eviction of input-side data from the cache.
345+ // st.cs (cache streaming): marks this line for earliest
346+ // eviction in L1 and L2. Output data is written once and
347+ // never read again by this kernel, so keeping it cached
348+ // would only compete with the packed input buffers and
349+ // smem-resident dict/runend data that the next tiles still
350+ // need to read. Evict-first lets those stay resident.
328351 __stcs (&output[tile_start + j * blockDim .x + threadIdx .x ], values[j]);
329352 }
330353 }
@@ -338,8 +361,6 @@ __device__ void execute_output_stage(T *__restrict output,
338361 for (uint8_t op = 0 ; op < stage.num_scalar_ops ; ++op) {
339362 scalar_op<T, 1 >(&val, stage.scalar_ops [op], smem);
340363 }
341- // __stcs bypasses L1 and invalidates the L2 line after write,
342- // avoiding eviction of input-side data from the cache.
343364 __stcs (&output[gpos], val);
344365 }
345366
@@ -359,6 +380,12 @@ __device__ void execute_output_stage(T *__restrict output,
359380// / Decode one input stage (dict values, run-end endpoints, etc.) into its
360381// / shared memory region so the output stage can reference it later.
361382// / Applies any scalar ops in-place before returning.
383+ // /
384+ // / Unlike execute_output_stage, this does not tile — the entire stage is
385+ // / decoded in one pass. The output stage needs random access into these
386+ // / smem regions (e.g. DICT gathers by arbitrary code value), so the data
387+ // / must be fully resident. The smem limit check in the Rust plan builder
388+ // / ensures the stage fits; if it doesn't, the plan falls back to Unfused.
362389template <typename T>
363390__device__ void execute_input_stage (const Stage &stage, char *__restrict smem) {
364391 T *smem_out = reinterpret_cast <T *>(smem + stage.smem_byte_offset );
@@ -367,6 +394,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
367394 if (src.op_code == SourceOp::BITUNPACK) {
368395 bitunpack<T>(reinterpret_cast <const T *>(stage.input_ptr ), smem_out, 0 , stage.len , src);
369396 smem_out += src.params .bitunpack .element_offset % SMEM_TILE_SIZE;
397+ // Write barrier: cooperative bitunpack finished, safe to read
398+ // decoded elements in the scalar-op loop below.
370399 __syncthreads ();
371400
372401 if (stage.num_scalar_ops > 0 ) {
@@ -377,10 +406,16 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
377406 }
378407 smem_out[i] = val;
379408 }
409+ // Write barrier: scalar ops applied in-place, smem region is
410+ // now fully populated for subsequent stages to read.
380411 __syncthreads ();
381412 }
382413 } else {
383414 if (src.op_code == SourceOp::RUNEND) {
415+ // Seed each thread's cursor with the run containing its first
416+ // strided position. The RUNEND arm in source_op advances the
417+ // cursor monotonically, so this avoids a full binary search on
418+ // every element.
384419 const T *ends = reinterpret_cast <const T *>(smem + src.params .runend .ends_smem_byte_offset );
385420 runend_cursors[threadIdx .x ] =
386421 upper_bound (ends, src.params .runend .num_runs , threadIdx .x + src.params .runend .offset );
@@ -394,6 +429,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
394429 }
395430 smem_out[i] = val;
396431 }
432+ // Write barrier: smem region is fully populated for subsequent
433+ // stages to read.
397434 __syncthreads ();
398435 }
399436}
@@ -428,6 +465,13 @@ dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__rest
428465 static_cast <uint32_t >(block_end - block_start));
429466}
430467
// Kernels are instantiated only for unsigned integer types. Signed and
// floating-point arrays reuse the unsigned kernel of the same width —
// the data is bit-identical under reinterpretation, and all arithmetic
// in the pipeline (FoR add, ZigZag decode, ALP decode, DICT gather) is
// correct on the unsigned representation. The one place where signedness
// matters is load_element(), which dispatches on the per-op PTypeTag to
// sign-extend or zero-extend when widening a narrow source to T.
431475#define GENERATE_KERNEL (suffix, Type ) \
432476 extern " C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \
433477 uint64_t array_len, \
0 commit comments