Skip to content

Commit 56b0731

Browse files
authored
Revert #7603 (#7613)
This reverts commit 09ae129. ## Summary Reverts #7603: restores the mixed-width bitunpack path (`bitunpack_typed` / `widen_inplace` / `ptype_byte_width`) removal, replacing it with direct `bitunpack<T>` calls where LOAD widens narrow sources via `static_cast` in `load_element()`.
1 parent 09ae129 commit 56b0731

4 files changed

Lines changed: 133 additions & 387 deletions

File tree

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 16 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,9 @@
3939
//
4040
// ## Mixed-width support
4141
//
42-
// Dict codes, RunEnd ends, and other child arrays may have a narrower
43-
// element type than the output T. Two mechanisms handle this:
44-
//
45-
// LOAD load_element() dispatches on the per-stage PTypeTag to
46-
// read at the source's native width and static_cast to T.
47-
// BITUNPACK bitunpack_typed() unpacks at the source's native width,
48-
// then widens to T in-place via a backward scan
49-
// (widen_inplace). The smem region is pre-allocated at
50-
// max(source_width, T) bytes per element by the Rust plan
51-
// builder, so the widen never overflows.
42+
// LOAD sources from pending subtrees may have a narrower type than the
43+
// output (e.g. u8 dict codes in a u32 plan). load_element() widens
44+
// to T via static_cast — no separate widen kernel or smem intermediate.
5245

5346
#include <assert.h>
5447
#include <cuda.h>
@@ -210,22 +203,6 @@ scatter_patches_chunk(const GPUPatches &patches, T *__restrict out, uint32_t chu
210203
// Source ops
211204
// ═══════════════════════════════════════════════════════════════════════════
212205

213-
/// Widen U-sized elements to T-sized elements in-place within the same
/// shared-memory region. A no-op instantiation (no loop, no barrier) when
/// T is not strictly wider than U.
///
/// The scan runs from the highest index downward so that the wider write
/// at index i lands only on bytes at or beyond those of src[i]
/// (sizeof(T) >= sizeof(U) implies i*sizeof(T) >= i*sizeof(U)).
/// NOTE(review): the no-overwrite argument is per-index; whether it also
/// holds across concurrently-progressing threads of the block (no barrier
/// between per-thread strides) should be confirmed against the callers'
/// element layout.
template <typename T, typename U>
__device__ inline void widen_inplace(T *dst, uint32_t len) {
  if constexpr (sizeof(T) > sizeof(U)) {
    const U *narrow = reinterpret_cast<const U *>(dst);
    // Each thread walks its own stride of indices, highest first.
    for (int32_t idx = static_cast<int32_t>(len) - 1 - static_cast<int32_t>(threadIdx.x);
         idx >= 0;
         idx -= static_cast<int32_t>(blockDim.x)) {
      dst[idx] = static_cast<T>(narrow[idx]);
    }
    // All widened elements visible to the block before anyone reads them.
    __syncthreads();
  }
}
228-
229206
/// FastLanes cooperative unpack — all threads in the block scatter-write
230207
/// decoded elements into `dst`. Caller must issue __syncthreads() before
231208
/// any thread reads from `dst`.
@@ -259,68 +236,6 @@ __device__ inline void bitunpack(const T *__restrict packed,
259236
}
260237
}
261238

262-
/// Dispatch bitunpack at the source's native element width, then widen
/// to T in-place so all downstream scalar ops and smem consumers see
/// T-sized elements. Falls back to the direct `bitunpack<T>` path when
/// the source ptype already matches T. Issues __syncthreads() before
/// returning on all paths, so callers may read `dst` immediately.
///
/// Accepts explicit chunk_start / chunk_len so it works for both input
/// stages (full decode with chunk_start=0, chunk_len=stage.len) and
/// the output stage (tiled with varying chunk_start / chunk_len).
///
/// Precondition: `dst` must hold max(source width, sizeof(T)) bytes per
/// element — presumably guaranteed by the plan builder's smem sizing;
/// TODO confirm against the Rust side.
template <typename T>
__device__ inline void bitunpack_typed(T *__restrict dst,
                                       const void *__restrict packed,
                                       uint64_t chunk_start,
                                       uint32_t chunk_len,
                                       const struct SourceOp &src,
                                       PTypeTag source_ptype) {
  // Fast path: source width matches T — no widening needed.
  if (ptype_byte_width(source_ptype) == sizeof(T)) {
    bitunpack<T>(reinterpret_cast<const T *>(packed), dst, chunk_start, chunk_len, src);
    __syncthreads();
    return;
  }

  // Compute total elements written by bitunpack (including alignment
  // padding) so widen_inplace covers the full scratch region: the decode
  // is rounded out to whole FL_CHUNK-sized chunks around the requested
  // [chunk_start, chunk_start + chunk_len) window.
  const uint32_t elem_off = src.params.bitunpack.element_offset;
  const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
  const uint32_t n_chunks = (chunk_len + dst_off + FL_CHUNK - 1) / FL_CHUNK;
  const uint32_t total_elems = n_chunks * FL_CHUNK;

  // Narrow source: unpack at native width, then widen to T. Each case
  // syncs between the cooperative unpack (write) and the widen (read of
  // the same smem bytes); widen_inplace issues its own trailing barrier.
  switch (source_ptype) {
  case PTYPE_U8:
  case PTYPE_I8: {
    auto *narrow = reinterpret_cast<uint8_t *>(dst);
    bitunpack<uint8_t>(reinterpret_cast<const uint8_t *>(packed), narrow, chunk_start, chunk_len, src);
    __syncthreads();
    widen_inplace<T, uint8_t>(dst, total_elems);
    break;
  }
  case PTYPE_U16:
  case PTYPE_I16: {
    auto *narrow = reinterpret_cast<uint16_t *>(dst);
    bitunpack<uint16_t>(reinterpret_cast<const uint16_t *>(packed), narrow, chunk_start, chunk_len, src);
    __syncthreads();
    widen_inplace<T, uint16_t>(dst, total_elems);
    break;
  }
  case PTYPE_U32:
  case PTYPE_I32:
  case PTYPE_F32: {
    auto *narrow = reinterpret_cast<uint32_t *>(dst);
    bitunpack<uint32_t>(reinterpret_cast<const uint32_t *>(packed), narrow, chunk_start, chunk_len, src);
    __syncthreads();
    widen_inplace<T, uint32_t>(dst, total_elems);
    break;
  }
  default:
    // Width mismatch with a 64-bit (or invalid) source cannot occur:
    // the fast path above handles equal widths, and T is at least as
    // wide as the source by construction of the plan.
    __builtin_unreachable();
  }
}
323-
324239
/// Read N values from a source op into `out`.
325240
///
326241
/// Dispatches on `src.op_code` to handle each encoding:
@@ -439,14 +354,16 @@ __device__ void execute_output_stage(T *__restrict output,
439354
if (src.op_code == SourceOp::BITUNPACK) {
440355
chunk_len = bitunpack_tile_len(stage, block_len, elem_idx);
441356
T *scratch = reinterpret_cast<T *>(smem + stage.smem_byte_offset);
442-
bitunpack_typed<T>(scratch,
443-
reinterpret_cast<const void *>(stage.input_ptr),
444-
block_start + elem_idx,
445-
chunk_len,
446-
src,
447-
ptype);
357+
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr),
358+
scratch,
359+
block_start + elem_idx,
360+
chunk_len,
361+
src);
448362
const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK;
449363
smem_src = scratch + align;
364+
// Write barrier: all threads finished bitunpack (and any
365+
// patches), safe to read from scratch.
366+
__syncthreads();
450367
} else {
451368
chunk_len = block_len;
452369
}
@@ -521,12 +438,11 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
521438
const auto &src = stage.source;
522439

523440
if (src.op_code == SourceOp::BITUNPACK) {
524-
bitunpack_typed<T>(smem_out,
525-
reinterpret_cast<const void *>(stage.input_ptr),
526-
0,
527-
stage.len,
528-
src,
529-
stage.source_ptype);
441+
T *raw_smem = smem_out;
442+
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
443+
// Write barrier: cooperative bitunpack finished, safe to read
444+
// decoded elements below.
445+
__syncthreads();
530446

531447
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
532448

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,6 @@ PTYPE_HOST_DEVICE constexpr PTypeTag ptype_to_unsigned(PTypeTag tag) {
7878
return tag;
7979
}
8080
}
81-
82-
/// Byte width of the physical element type identified by `tag`.
/// Returns 0 for tags outside the known integer/float set.
PTYPE_HOST_DEVICE constexpr uint8_t ptype_byte_width(PTypeTag tag) {
  if (tag == PTYPE_U8 || tag == PTYPE_I8) {
    return 1;
  }
  if (tag == PTYPE_U16 || tag == PTYPE_I16) {
    return 2;
  }
  if (tag == PTYPE_U32 || tag == PTYPE_I32 || tag == PTYPE_F32) {
    return 4;
  }
  if (tag == PTYPE_U64 || tag == PTYPE_I64 || tag == PTYPE_F64) {
    return 8;
  }
  return 0;
}
10281
#endif
10382

10483
/// Number of threads per CUDA block.

0 commit comments

Comments
 (0)