39 | 39 | // |
40 | 40 | // ## Mixed-width support |
41 | 41 | // |
42 | | -// LOAD sources from pending subtrees may have a narrower type than the |
43 | | -// output (e.g. u8 dict codes in a u32 plan). load_element() widens |
44 | | -// to T via static_cast — no separate widen kernel or smem intermediate. |
| 42 | +// Dict codes, RunEnd ends, and other child arrays may have a narrower |
| 43 | +// element type than the output T. Two mechanisms handle this: |
| 44 | +// |
| 45 | +// LOAD load_element() dispatches on the per-stage PTypeTag to |
| 46 | +// read at the source's native width and static_cast to T. |
| 47 | +// BITUNPACK bitunpack_typed() unpacks at the source's native width, |
| 48 | +// then widens to T in-place via a backward scan |
| 49 | +// (widen_inplace). The smem region is pre-allocated at |
| 50 | +// max(source_width, sizeof(T)) bytes per element by the Rust |
| 51 | +// plan builder, so the widen never overflows. |
45 | 52 |
46 | 53 | #include <assert.h> |
47 | 54 | #include <cuda.h> |
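To make the sizing rule in the comment above concrete, here is a minimal host-side sketch of how the per-stage scratch allocation could be computed. It is illustrative only: kFlChunk stands in for the FastLanes chunk length (assumed to be 1024 here), and stage_scratch_bytes is a hypothetical helper name, not one taken from the Rust plan builder.

#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the plan builder's rule: reserve
// max(source_width, sizeof(T)) bytes per element so that widening the
// narrow source type to T in place can never overflow the scratch region.
constexpr uint32_t kFlChunk = 1024;  // assumed FastLanes chunk length

constexpr size_t stage_scratch_bytes(size_t source_width, size_t out_width,
                                     uint32_t n_chunks) {
    const size_t per_elem = source_width > out_width ? source_width : out_width;
    return static_cast<size_t>(n_chunks) * kFlChunk * per_elem;
}

// u8 dict codes feeding a u32 plan need as much scratch as a native u32
// source: 4 bytes per element, 4 KiB per chunk.
static_assert(stage_scratch_bytes(1, 4, 1) == 4096, "u8 widened to u32");
static_assert(stage_scratch_bytes(4, 4, 1) == 4096, "native u32");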
@@ -203,6 +210,22 @@ scatter_patches_chunk(const GPUPatches &patches, T *__restrict out, uint32_t chu |
203 | 210 | // Source ops |
204 | 211 | // ═══════════════════════════════════════════════════════════════════════════ |
205 | 212 |
| 213 | +/// Widen U-sized elements in shared memory to T-sized, in-place. |
| 214 | +/// Backward scan: because sizeof(T) >= sizeof(U), the write of dst[i] |
| 215 | +/// starts at byte i*sizeof(T) >= i*sizeof(U), so it can only clobber |
| 216 | +/// source elements at index >= i, which the scan has already read. |
| 217 | +template <typename T, typename U> |
| 218 | +__device__ inline void widen_inplace(T *dst, uint32_t len) { |
| 219 | + if constexpr (sizeof(T) <= sizeof(U)) { |
| 220 | + return; |
| 221 | + } |
| 222 | + const U *src = reinterpret_cast<const U *>(dst); |
| 223 | + for (int32_t i = static_cast<int32_t>(len) - 1 - threadIdx.x; i >= 0; i -= blockDim.x) { |
| 224 | + dst[i] = static_cast<T>(src[i]); |
| 225 | + } |
| 226 | + __syncthreads(); |
| 227 | +} |
| 228 | + |
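For reference, a single-threaded host version of the same widen makes the ordering argument easy to check in isolation. This is an illustrative sketch, not part of the kernel: it mirrors the device loop but drops the thread striding and the barrier.

#include <cstdint>

// Host-side reference for widen_inplace (illustrative only). Because
// sizeof(T) >= sizeof(U), dst[i] begins at byte i*sizeof(T), at or beyond
// the start of src[i], so the backward scan only overwrites source
// elements that have already been consumed.
template <typename T, typename U>
void widen_inplace_ref(T *dst, uint32_t len) {
    static_assert(sizeof(T) >= sizeof(U), "widen only, never narrow");
    const U *src = reinterpret_cast<const U *>(dst);
    for (int32_t i = static_cast<int32_t>(len) - 1; i >= 0; --i) {
        dst[i] = static_cast<T>(src[i]);
    }
}

For example, widen_inplace_ref<uint32_t, uint8_t>(buf, n) on a buffer whose first n bytes hold u8 codes leaves the same n values stored as u32, in place.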
206 | 229 | /// FastLanes cooperative unpack — all threads in the block scatter-write |
207 | 230 | /// decoded elements into `dst`. Caller must issue __syncthreads() before |
208 | 231 | /// any thread reads from `dst`. |
@@ -236,6 +259,68 @@ __device__ inline void bitunpack(const T *__restrict packed, |
236 | 259 | } |
237 | 260 | } |
238 | 261 |
| 262 | +/// Dispatches bitunpack at the source's native element width, then |
| 263 | +/// widens to T in-place so all downstream scalar ops and smem consumers |
| 264 | +/// see T-sized elements. Takes the direct `bitunpack<T>` path when the |
| 265 | +/// source ptype's width already matches sizeof(T). Issues |
| 266 | +/// __syncthreads() before returning on all paths. |
| 267 | +/// |
| 268 | +/// Accepts explicit chunk_start / chunk_len so it works for both input |
| 269 | +/// stages (full decode with chunk_start=0, chunk_len=stage.len) and |
| 270 | +/// the output stage (tiled with varying chunk_start / chunk_len). |
| 271 | +template <typename T> |
| 272 | +__device__ inline void bitunpack_typed(T *__restrict dst, |
| 273 | + const void *__restrict packed, |
| 274 | + uint64_t chunk_start, |
| 275 | + uint32_t chunk_len, |
| 276 | + const struct SourceOp &src, |
| 277 | + PTypeTag source_ptype) { |
| 278 | + // Fast path: source width matches T — no widening needed. |
| 279 | + if (ptype_byte_width(source_ptype) == sizeof(T)) { |
| 280 | + bitunpack<T>(reinterpret_cast<const T *>(packed), dst, chunk_start, chunk_len, src); |
| 281 | + __syncthreads(); |
| 282 | + return; |
| 283 | + } |
| 284 | + |
| 285 | + // Compute total elements written by bitunpack (including alignment |
| 286 | + // padding) so widen_inplace covers the full scratch region. |
| 287 | + const uint32_t elem_off = src.params.bitunpack.element_offset; |
| 288 | + const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK; |
| 289 | + const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK; |
| 290 | + const uint32_t total_elems = n_chunks * FL_CHUNK; |
| 291 | + |
| 292 | + // Narrow source: unpack at native width, then widen to T. |
| 293 | + switch (source_ptype) { |
| 294 | + case PTYPE_U8: |
| 295 | + case PTYPE_I8: { |
| 296 | + auto *narrow = reinterpret_cast<uint8_t *>(dst); |
| 297 | + bitunpack<uint8_t>(reinterpret_cast<const uint8_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 298 | + __syncthreads(); |
| 299 | + widen_inplace<T, uint8_t>(dst, total_elems); |
| 300 | + break; |
| 301 | + } |
| 302 | + case PTYPE_U16: |
| 303 | + case PTYPE_I16: { |
| 304 | + auto *narrow = reinterpret_cast<uint16_t *>(dst); |
| 305 | + bitunpack<uint16_t>(reinterpret_cast<const uint16_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 306 | + __syncthreads(); |
| 307 | + widen_inplace<T, uint16_t>(dst, total_elems); |
| 308 | + break; |
| 309 | + } |
| 310 | + case PTYPE_U32: |
| 311 | + case PTYPE_I32: |
| 312 | + case PTYPE_F32: { |
| 313 | + auto *narrow = reinterpret_cast<uint32_t *>(dst); |
| 314 | + bitunpack<uint32_t>(reinterpret_cast<const uint32_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 315 | + __syncthreads(); |
| 316 | + widen_inplace<T, uint32_t>(dst, total_elems); |
| 317 | + break; |
| 318 | + } |
| 319 | + default: |
| 320 | + __builtin_unreachable(); |
| 321 | + } |
| 322 | +} |
| 323 | + |
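The padding math above is easiest to follow with concrete numbers. A worked example, assuming FL_CHUNK is the FastLanes chunk length of 1024; the values are illustrative, not taken from a real plan:

// chunk_start = 0, chunk_len = 1024, element_offset = 100, FL_CHUNK = 1024
//   dst_off     = (0 + 100) % 1024               = 100
//   n_chunks    = (1024 + 100 + 1024 - 1) / 1024 = 2
//   total_elems = 2 * 1024                       = 2048
// The tile's aligned start falls 100 elements into a chunk, so bitunpack
// writes two full chunks of scratch; widen_inplace must cover all 2048
// slots, not just the 1024 requested output elements.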
239 | 324 | /// Read N values from a source op into `out`. |
240 | 325 | /// |
241 | 326 | /// Dispatches on `src.op_code` to handle each encoding: |
@@ -354,16 +439,14 @@ __device__ void execute_output_stage(T *__restrict output, |
354 | 439 | if (src.op_code == SourceOp::BITUNPACK) { |
355 | 440 | chunk_len = bitunpack_tile_len(stage, block_len, elem_idx); |
356 | 441 | T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset); |
357 | | - bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), |
358 | | - scratch, |
359 | | - block_start + elem_idx, |
360 | | - chunk_len, |
361 | | - src); |
| 442 | + bitunpack_typed<T>(scratch, |
| 443 | + reinterpret_cast<const void *>(stage.input_ptr), |
| 444 | + block_start + elem_idx, |
| 445 | + chunk_len, |
| 446 | + src, |
| 447 | + ptype); |
362 | 448 | const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK; |
363 | 449 | smem_src = scratch + align; |
364 | | - // Write barrier: all threads finished bitunpack (and any |
365 | | - // patches), safe to read from scratch. |
366 | | - __syncthreads(); |
367 | 450 | } else { |
368 | 451 | chunk_len = block_len; |
369 | 452 | } |
@@ -438,11 +521,12 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) { |
438 | 521 | const auto &src = stage.source; |
439 | 522 |
440 | 523 | if (src.op_code == SourceOp::BITUNPACK) { |
441 | | - T *raw_smem = smem_out; |
442 | | - bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src); |
443 | | - // Write barrier: cooperative bitunpack finished, safe to read |
444 | | - // decoded elements below. |
445 | | - __syncthreads(); |
| 524 | + bitunpack_typed<T>(smem_out, |
| 525 | + reinterpret_cast<const void *>(stage.input_ptr), |
| 526 | + 0, |
| 527 | + stage.len, |
| 528 | + src, |
| 529 | + stage.source_ptype); |
446 | 530 |
447 | 531 | smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE; |
448 | 532 |