#include "dynamic_dispatch.h"
#include "types.cuh"

-/// Number of threads per CUDA block (must match the Rust launch config).
-constexpr uint32_t BLOCK_SIZE = 64;
-
// ═══════════════════════════════════════════════════════════════════════════
// Primitives
// ═══════════════════════════════════════════════════════════════════════════
@@ -70,26 +67,41 @@ __device__ inline uint64_t upper_bound(const T *data, uint64_t len, uint64_t val
}

/// Read one element from global memory at `ptype` width, widen to T.
+/// Signed types are sign-extended; unsigned types are zero-extended.
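+/// e.g. a source byte 0xFF widens to all-ones (sign-extended) under PTYPE_I8 but to 255 under PTYPE_U8.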
template <typename T>
__device__ inline T load_element(const void *__restrict ptr, PTypeTag ptype, uint64_t idx) {
-  switch (ptype_to_unsigned(ptype)) {
+  switch (ptype) {
    case PTYPE_U8:
      return static_cast<T>(static_cast<const uint8_t *>(ptr)[idx]);
+    case PTYPE_I8:
+      return static_cast<T>(static_cast<const int8_t *>(ptr)[idx]);
    case PTYPE_U16:
      return static_cast<T>(static_cast<const uint16_t *>(ptr)[idx]);
+    case PTYPE_I16:
+      return static_cast<T>(static_cast<const int16_t *>(ptr)[idx]);
    case PTYPE_U32:
+    case PTYPE_F32:
      return static_cast<T>(static_cast<const uint32_t *>(ptr)[idx]);
+    case PTYPE_I32:
+      return static_cast<T>(static_cast<const int32_t *>(ptr)[idx]);
    case PTYPE_U64:
+    case PTYPE_F64:
      return static_cast<T>(static_cast<const uint64_t *>(ptr)[idx]);
+    case PTYPE_I64:
+      return static_cast<T>(static_cast<const int64_t *>(ptr)[idx]);
    default:
      __builtin_unreachable();
  }
}

-/// RUNEND forward-scan cursor for each thread, stored in shared memory so
-/// source_op can advance it across calls. Pre-seeded with upper_bound before
-/// the tile loop; the RUNEND arm in source_op advances it monotonically.
-/// Sized to BLOCK_SIZE to match the kernel block size, one entry per thread.
+/// Per-thread run cursor for RUNEND forward-scan, one entry per thread.
+///
+/// Stored in shared memory so the cursor persists across successive
+/// source_op calls in the tile loop. Each thread's positions are
+/// monotonically increasing across tiles, so the cursor only advances
+/// forward — the next tile picks up exactly where the previous one
+/// stopped, avoiding a binary search per tile. The only binary search
+/// is the initial upper_bound seed before the tile loop begins.
__shared__ uint64_t runend_cursors[BLOCK_SIZE];

// ═══════════════════════════════════════════════════════════════════════════
@@ -98,37 +110,37 @@ __shared__ uint64_t runend_cursors[BLOCK_SIZE];

/// Apply one scalar operation to N values in registers.
template <typename T, uint32_t N>
-__device__ inline void scalar_op(T *v, const struct ScalarOp &op, char *__restrict smem) {
+__device__ inline void scalar_op(T *values, const struct ScalarOp &op, char *__restrict smem) {
  switch (op.op_code) {
    case ScalarOp::FOR: {
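+      // Frame-of-reference decode: add the reference value back to every element.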
      const T ref = static_cast<T>(op.params.frame_of_ref.reference);
#pragma unroll
      for (uint32_t i = 0; i < N; ++i) {
-        v[i] += ref;
+        values[i] += ref;
      }
      break;
    }
    case ScalarOp::ZIGZAG: {
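+      // ZigZag decode: 0,1,2,3,... map back to 0,-1,1,-2,...; the low bit carries the sign.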
#pragma unroll
      for (uint32_t i = 0; i < N; ++i) {
-        v[i] = (v[i] >> 1) ^ static_cast<T>(-(v[i] & 1));
+        values[i] = (values[i] >> 1) ^ static_cast<T>(-(values[i] & 1));
      }
      break;
    }
    case ScalarOp::ALP: {
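+      // ALP decode: rescale the encoded integer to a float, then keep its raw IEEE-754 bits in the lane.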
      const float f = op.params.alp.f, e = op.params.alp.e;
#pragma unroll
      for (uint32_t i = 0; i < N; ++i) {
-        float r = static_cast<float>(static_cast<int32_t>(v[i])) * f * e;
-        v[i] = static_cast<T>(__float_as_uint(r));
+        float r = static_cast<float>(static_cast<int32_t>(values[i])) * f * e;
+        values[i] = static_cast<T>(__float_as_uint(r));
      }
      break;
    }
    case ScalarOp::DICT: {
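+      // Dictionary decode: each value is a code indexing the dict table staged in shared memory.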
      const T *dict = reinterpret_cast<const T *>(smem + op.params.dict.values_smem_byte_offset);
#pragma unroll
      for (uint32_t i = 0; i < N; ++i) {
-        v[i] = dict[static_cast<uint32_t>(v[i])];
+        values[i] = dict[static_cast<uint32_t>(values[i])];
      }
      break;
    }
@@ -182,7 +194,7 @@ __device__ inline void bitunpack(const T *__restrict packed,
/// Position calculation (via THREAD_POS macro):
/// N > 1 (batched): pos = base + j·blockDim.x + threadIdx.x.
/// Caller passes the tile base WITHOUT threadIdx.x.
-/// N = 1 (single): base IS the exact position. No stride added.
+/// N = 1 (single): base is the exact position. No stride added.
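+/// e.g. with blockDim.x = 64 (illustrative) and N = 4, thread t covers positions base+t, base+64+t, base+128+t, base+192+t.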
template <typename T, uint32_t N>
__device__ inline void source_op(T *out,
                                 const struct SourceOp &src,
@@ -247,7 +259,8 @@ __device__ inline void source_op(T *out,
// ═══════════════════════════════════════════════════════════════════════════
//
// BITUNPACK tiles at SMEM_TILE_SIZE: cooperative unpack → smem → sync →
-// batched read. All other source ops: single pass, no smem scratch.
+// batched read. LOAD, SEQUENCE, and RUNEND need no smem scratch and
+// process the full block in a single outer iteration, tiled by tile_idx.

/// How many elements to process in this BITUNPACK tile iteration.
/// The first tile may be shorter due to `element_offset` alignment;
@@ -257,7 +270,7 @@ __device__ inline uint32_t bitunpack_tile_len(const Stage &stage, uint32_t block
  return min(SMEM_TILE_SIZE - off, block_len - tile_off);
}

-/// Process the final (output) stage: decode source → apply scalar ops →
+/// Process the final / output stage: decode source → apply scalar ops →
/// streaming-store to global memory. Handles the full block, tiling through
/// smem scratch for BITUNPACK.
template <typename T>
@@ -273,6 +286,10 @@ __device__ void execute_output_stage(T *__restrict output,
  const PTypeTag ptype = stage.source_ptype;

  if (src.op_code == SourceOp::RUNEND) {
+    // Seed each thread's cursor with the run containing its first
+    // strided position. The RUNEND arm in source_op advances the
+    // cursor monotonically, so this avoids a full binary search on
+    // every element.
    const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
    runend_cursors[threadIdx.x] = upper_bound(ends,
                                              src.params.runend.num_runs,
@@ -283,9 +300,10 @@ __device__ void execute_output_stage(T *__restrict output,
  uint32_t chunk_len;
  const T *smem_src = nullptr;

-  // BITUNPACK chunks the block into shared-memory-sized tiles, so this
-  // advances by one tile per iteration. All other source ops process the
-  // entire block in a single iteration (chunk_len = block_len).
+  // BITUNPACK uses smem scratch, so the outer loop advances one
+  // chunk at a time. LOAD, SEQUENCE, and RUNEND need no smem
+  // scratch, so chunk_len = block_len (single outer iteration);
+  // tiling happens in the inner tile_idx loop.
  if (src.op_code == SourceOp::BITUNPACK) {
    chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
    T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -323,8 +341,12 @@ __device__ void execute_output_stage(T *__restrict output,

#pragma unroll
    for (uint32_t j = 0; j < VALUES_PER_TILE; ++j) {
-      // __stcs bypasses L1 and invalidates the L2 line after write,
-      // avoiding eviction of input-side data from the cache.
+      // st.cs (cache streaming): marks this line for earliest
+      // eviction in L1 and L2. Output data is written once and
+      // never read again by this kernel, so keeping it cached
+      // would only compete with the packed input buffers and
+      // smem-resident dict/runend data that the next tiles still
+      // need to read. Evict-first lets those stay resident.
      __stcs(&output[tile_start + j * blockDim.x + threadIdx.x], values[j]);
    }
  }
@@ -338,8 +360,6 @@ __device__ void execute_output_stage(T *__restrict output,
    for (uint8_t op = 0; op < stage.num_scalar_ops; ++op) {
      scalar_op<T, 1>(&val, stage.scalar_ops[op], smem);
    }
-    // __stcs bypasses L1 and invalidates the L2 line after write,
-    // avoiding eviction of input-side data from the cache.
    __stcs(&output[gpos], val);
  }

@@ -359,6 +379,12 @@ __device__ void execute_output_stage(T *__restrict output,
/// Decode one input stage (dict values, run-end endpoints, etc.) into its
/// shared memory region so the output stage can reference it later.
/// Applies any scalar ops in-place before returning.
+///
+/// Unlike execute_output_stage, this does not tile — the entire stage is
+/// decoded in one pass. The output stage needs random access into these
+/// smem regions (e.g. DICT gathers by arbitrary code value), so the data
+/// must be fully resident. The smem limit check in the Rust plan builder
+/// ensures the stage fits; if it doesn't, the plan falls back to Unfused.
template <typename T>
__device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
  T *smem_out = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
@@ -367,6 +393,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
  if (src.op_code == SourceOp::BITUNPACK) {
    bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
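+    // Skip the alignment prefix: the unpacker starts at an SMEM_TILE_SIZE-aligned
+    // element, so this stage's first value lands element_offset % SMEM_TILE_SIZE
+    // into the scratch region (hence the pointer bump below).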
    smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
+    // Write barrier: cooperative bitunpack finished, safe to read
+    // decoded elements in the scalar-op loop below.
    __syncthreads();

    if (stage.num_scalar_ops > 0) {
@@ -377,10 +405,16 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
        }
        smem_out[i] = val;
      }
+      // Write barrier: scalar ops applied in-place, smem region is
+      // now fully populated for subsequent stages to read.
      __syncthreads();
    }
  } else {
    if (src.op_code == SourceOp::RUNEND) {
+      // Seed each thread's cursor with the run containing its first
+      // strided position. The RUNEND arm in source_op advances the
+      // cursor monotonically, so this avoids a full binary search on
+      // every element.
      const T *ends = reinterpret_cast<const T *>(smem + src.params.runend.ends_smem_byte_offset);
      runend_cursors[threadIdx.x] =
          upper_bound(ends, src.params.runend.num_runs, threadIdx.x + src.params.runend.offset);
@@ -394,6 +428,8 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
      }
      smem_out[i] = val;
    }
+    // Write barrier: smem region is fully populated for subsequent
+    // stages to read.
    __syncthreads();
  }
}
@@ -428,6 +464,13 @@ dynamic_dispatch(T *__restrict output, uint64_t array_len, const uint8_t *__rest
      static_cast<uint32_t>(block_end - block_start));
}

+// Kernels are instantiated only for unsigned integer types. Signed and
+// floating-point arrays reuse the unsigned kernel of the same width —
+// the data is bit-identical under reinterpretation, and all arithmetic
+// in the pipeline (FoR add, ZigZag decode, ALP decode, DICT gather) is
+// correct on the unsigned representation. The one place where signedness
+// matters is load_element(), which dispatches on the per-op PTypeTag to
+// sign-extend or zero-extend when widening a narrow source to T.
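+// e.g. an i64 array runs on the 64-bit unsigned kernel; a PTYPE_I16 source
+// in its plan is still sign-extended correctly by load_element().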
#define GENERATE_KERNEL(suffix, Type) \
  extern "C" __global__ void dynamic_dispatch_##suffix(Type *__restrict output, \
                                                       uint64_t array_len, \