|
39 | 39 | // |
40 | 40 | // ## Mixed-width support |
41 | 41 | // |
42 | | -// LOAD sources from pending subtrees may have a narrower type than the |
43 | | -// output (e.g. u8 dict codes in a u32 plan). load_element() widens |
44 | | -// to T via static_cast — no separate widen kernel or smem intermediate. |
| 42 | +// Dict codes, RunEnd ends, and other child arrays may have a narrower |
| 43 | +// element type than the output T. Two mechanisms handle this: |
| 44 | +// |
| 45 | +// LOAD load_element() dispatches on the per-stage PTypeTag to |
| 46 | +// read at the source's native width and static_cast to T. |
| 47 | +// BITUNPACK bitunpack_typed() unpacks at the source's native width, |
| 48 | +// then widens to T in-place via a backward scan |
| 49 | +// (widen_inplace). The smem region is pre-allocated at |
| 50 | +// max(source_width, T) bytes per element by the Rust plan |
| 51 | +// builder, so the widen never overflows. |
45 | 52 |
|
46 | 53 | #include <assert.h> |
47 | 54 | #include <cuda.h> |
@@ -203,6 +210,20 @@ scatter_patches_chunk(const GPUPatches &patches, T *__restrict out, uint32_t chu |
203 | 210 | // Source ops |
204 | 211 | // ═══════════════════════════════════════════════════════════════════════════ |
205 | 212 |
|
| 213 | +/// Widen U-sized elements in shared memory to T-sized, in-place. |
| 214 | +/// Backward scan ensures no unread element is overwritten since |
| 215 | +/// sizeof(T) >= sizeof(U) guarantees the write at index i touches |
| 216 | +/// only bytes beyond those of src[i]. |
| 217 | +template <typename T, typename U> |
| 218 | +__device__ inline void widen_inplace(T *dst, uint32_t len) { |
| 219 | + if constexpr (sizeof(T) <= sizeof(U)) return; |
| 220 | + const U *src = reinterpret_cast<const U *>(dst); |
| 221 | + for (int32_t i = static_cast<int32_t>(len) - 1 - threadIdx.x; i >= 0; i -= blockDim.x) { |
| 222 | + dst[i] = static_cast<T>(src[i]); |
| 223 | + } |
| 224 | + __syncthreads(); |
| 225 | +} |
| 226 | + |
206 | 227 | /// FastLanes cooperative unpack — all threads in the block scatter-write |
207 | 228 | /// decoded elements into `dst`. Caller must issue __syncthreads() before |
208 | 229 | /// any thread reads from `dst`. |
@@ -236,6 +257,68 @@ __device__ inline void bitunpack(const T *__restrict packed, |
236 | 257 | } |
237 | 258 | } |
238 | 259 |
|
| 260 | +/// Dispatch bitunpack at the source's native element width, then widen |
| 261 | +/// to T in-place so all downstream scalar ops and smem consumers see |
| 262 | +/// T-sized elements. Falls back to the direct `bitunpack<T>` path when |
| 263 | +/// the source ptype already matches T. Issues __syncthreads() before |
| 264 | +/// returning on all paths. |
| 265 | +/// |
| 266 | +/// Accepts explicit chunk_start / chunk_len so it works for both input |
| 267 | +/// stages (full decode with chunk_start=0, chunk_len=stage.len) and |
| 268 | +/// the output stage (tiled with varying chunk_start / chunk_len). |
| 269 | +template <typename T> |
| 270 | +__device__ inline void bitunpack_typed(T *__restrict dst, |
| 271 | + const void *__restrict packed, |
| 272 | + uint64_t chunk_start, |
| 273 | + uint32_t chunk_len, |
| 274 | + const struct SourceOp &src, |
| 275 | + PTypeTag source_ptype) { |
| 276 | + // Fast path: source width matches T — no widening needed. |
| 277 | + if (ptype_byte_width(source_ptype) == sizeof(T)) { |
| 278 | + bitunpack<T>(reinterpret_cast<const T *>(packed), dst, chunk_start, chunk_len, src); |
| 279 | + __syncthreads(); |
| 280 | + return; |
| 281 | + } |
| 282 | + |
| 283 | + // Compute total elements written by bitunpack (including alignment |
| 284 | + // padding) so widen_inplace covers the full scratch region. |
| 285 | + const uint32_t elem_off = src.params.bitunpack.element_offset; |
| 286 | + const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK; |
| 287 | + const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK; |
| 288 | + const uint32_t total_elems = n_chunks * FL_CHUNK; |
| 289 | + |
| 290 | + // Narrow source: unpack at native width, then widen to T. |
| 291 | + switch (source_ptype) { |
| 292 | + case PTYPE_U8: |
| 293 | + case PTYPE_I8: { |
| 294 | + auto *narrow = reinterpret_cast<uint8_t *>(dst); |
| 295 | + bitunpack<uint8_t>(reinterpret_cast<const uint8_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 296 | + __syncthreads(); |
| 297 | + widen_inplace<T, uint8_t>(dst, total_elems); |
| 298 | + break; |
| 299 | + } |
| 300 | + case PTYPE_U16: |
| 301 | + case PTYPE_I16: { |
| 302 | + auto *narrow = reinterpret_cast<uint16_t *>(dst); |
| 303 | + bitunpack<uint16_t>(reinterpret_cast<const uint16_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 304 | + __syncthreads(); |
| 305 | + widen_inplace<T, uint16_t>(dst, total_elems); |
| 306 | + break; |
| 307 | + } |
| 308 | + case PTYPE_U32: |
| 309 | + case PTYPE_I32: |
| 310 | + case PTYPE_F32: { |
| 311 | + auto *narrow = reinterpret_cast<uint32_t *>(dst); |
| 312 | + bitunpack<uint32_t>(reinterpret_cast<const uint32_t *>(packed), narrow, chunk_start, chunk_len, src); |
| 313 | + __syncthreads(); |
| 314 | + widen_inplace<T, uint32_t>(dst, total_elems); |
| 315 | + break; |
| 316 | + } |
| 317 | + default: |
| 318 | + __builtin_unreachable(); |
| 319 | + } |
| 320 | +} |
| 321 | + |
239 | 322 | /// Read N values from a source op into `out`. |
240 | 323 | /// |
241 | 324 | /// Dispatches on `src.op_code` to handle each encoding: |
@@ -354,16 +437,14 @@ __device__ void execute_output_stage(T *__restrict output, |
354 | 437 | if (src.op_code == SourceOp::BITUNPACK) { |
355 | 438 | chunk_len = bitunpack_tile_len(stage, block_len, elem_idx); |
356 | 439 | T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset); |
357 | | - bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), |
358 | | - scratch, |
359 | | - block_start + elem_idx, |
360 | | - chunk_len, |
361 | | - src); |
| 440 | + bitunpack_typed<T>(scratch, |
| 441 | + reinterpret_cast<const void *>(stage.input_ptr), |
| 442 | + block_start + elem_idx, |
| 443 | + chunk_len, |
| 444 | + src, |
| 445 | + ptype); |
362 | 446 | const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK; |
363 | 447 | smem_src = scratch + align; |
364 | | - // Write barrier: all threads finished bitunpack (and any |
365 | | - // patches), safe to read from scratch. |
366 | | - __syncthreads(); |
367 | 448 | } else { |
368 | 449 | chunk_len = block_len; |
369 | 450 | } |
@@ -438,11 +519,12 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) { |
438 | 519 | const auto &src = stage.source; |
439 | 520 |
|
440 | 521 | if (src.op_code == SourceOp::BITUNPACK) { |
441 | | - T *raw_smem = smem_out; |
442 | | - bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src); |
443 | | - // Write barrier: cooperative bitunpack finished, safe to read |
444 | | - // decoded elements below. |
445 | | - __syncthreads(); |
| 522 | + bitunpack_typed<T>(smem_out, |
| 523 | + reinterpret_cast<const void *>(stage.input_ptr), |
| 524 | + 0, |
| 525 | + stage.len, |
| 526 | + src, |
| 527 | + stage.source_ptype); |
446 | 528 |
|
447 | 529 | smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE; |
448 | 530 |
|
|
0 commit comments