Skip to content

Commit 1e6e8a5

Browse files
committed
chore: add patches_ptr to BitunpackParams and AlpParams
Structural plumbing for per-op exception patches in the fused dynamic dispatch kernel. Adds PackedPatchesHeader and kernel helpers (patch_fl_chunk, patch_all_fl_chunks) but does not yet populate patches_ptr; all constructors initialize it to 0.

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 4135209 commit 1e6e8a5

5 files changed

Lines changed: 131 additions & 9 deletions

File tree

vortex-cuda/kernels/src/dynamic_dispatch.cu

Lines changed: 94 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,8 @@ __device__ inline void bitunpack(const T *__restrict packed,
162162
uint64_t chunk_start,
163163
uint32_t chunk_len,
164164
const struct SourceOp &src) {
165-
constexpr uint32_t T_BITS = sizeof(T) * 8;
166-
constexpr uint32_t FL_CHUNK = 1024;
167-
constexpr uint32_t LANES = FL_CHUNK / T_BITS;
168165
const uint32_t bw = src.params.bitunpack.bit_width;
169-
const uint32_t words_per_block = LANES * bw;
166+
const uint32_t words_per_block = FL_LANES<T> * bw;
170167
const uint32_t elem_off = src.params.bitunpack.element_offset;
171168
const uint32_t dst_off = (chunk_start + elem_off) % FL_CHUNK;
172169
const uint64_t first_block = (chunk_start + elem_off) / FL_CHUNK;
@@ -177,12 +174,86 @@ __device__ inline void bitunpack(const T *__restrict packed,
177174
for (uint32_t c = 0; c < n_chunks; ++c) {
178175
const T *src_chunk = packed + (first_block + c) * words_per_block;
179176
T *chunk_dst = dst + c * FL_CHUNK;
180-
for (uint32_t lane = threadIdx.x; lane < LANES; lane += blockDim.x) {
177+
for (uint32_t lane = threadIdx.x; lane < FL_LANES<T>; lane += blockDim.x) {
181178
bit_unpack_lane<T>(src_chunk, chunk_dst, 0, lane, bw);
182179
}
183180
}
184181
}
185182

183+
// ═══════════════════════════════════════════════════════════════════════════
184+
// Patches
185+
// ═══════════════════════════════════════════════════════════════════════════
186+
187+
/// Parsed view into a packed patches buffer (the fused-dispatch counterpart
188+
/// of GPUPatches, which is used by the standalone per-bitwidth kernels).
189+
/// Each op with patches gets its own contiguous device allocation holding
190+
/// lane_offsets, indices, and values, referenced by a single uint64_t pointer
191+
/// (patches_ptr in BitunpackParams / AlpParams); see PackedPatchesHeader in
192+
/// patches.h for the layout.
193+
template <typename T>
194+
struct PackedPatchesView {
195+
const uint32_t *lane_offsets;
196+
uint32_t num_lane_offsets;
197+
const uint16_t *indices;
198+
const T *values;
199+
};
200+
201+
/// Resolve a packed patches buffer into typed pointers to its three arrays.
///
/// `patches_ptr` is the address of a buffer that starts with a
/// PackedPatchesHeader. lane_offsets begin directly after the header and run
/// up to `indices_byte_offset`; indices and values begin at the byte offsets
/// stored in the header. Nothing is copied; the returned view aliases the
/// buffer.
///
/// Preconditions (checked with assert in debug builds):
///   - `patches_ptr` is non-zero (0 means "no patches"; callers test for it
///     before calling),
///   - the header offsets are ordered header <= indices <= values,
///   - the lane_offsets region is a whole number of uint32_t entries,
///   - the indices and values pointers are suitably aligned for their types.
template <typename T>
__device__ inline PackedPatchesView<T> parse_patches(uint64_t patches_ptr) {
    assert(patches_ptr != 0);

    const uint8_t *base = reinterpret_cast<const uint8_t *>(patches_ptr);
    const auto *header = reinterpret_cast<const PackedPatchesHeader *>(base);

    const uint32_t indices_off = header->indices_byte_offset;
    const uint32_t values_off = header->values_byte_offset;

    // A header with indices_off < sizeof(header) would wrap the unsigned
    // subtraction below and yield a huge num_lane_offsets, which in turn
    // would make every later bounds assert on lane_offsets pass vacuously.
    assert(indices_off >= sizeof(PackedPatchesHeader));
    assert((indices_off - sizeof(PackedPatchesHeader)) % sizeof(uint32_t) == 0);
    assert(values_off >= indices_off);

    // Misaligned typed loads are invalid on the device; catch them up front.
    assert((patches_ptr + indices_off) % alignof(uint16_t) == 0);
    assert((patches_ptr + values_off) % alignof(T) == 0);

    PackedPatchesView<T> view;
    view.lane_offsets = reinterpret_cast<const uint32_t *>(base + sizeof(PackedPatchesHeader));
    view.num_lane_offsets =
        static_cast<uint32_t>((indices_off - sizeof(PackedPatchesHeader)) / sizeof(uint32_t));
    view.indices = reinterpret_cast<const uint16_t *>(base + indices_off);
    view.values = reinterpret_cast<const T *>(base + values_off);
    return view;
}
213+
214+
/// Apply source patches for a single FL chunk.
215+
///
216+
/// Overwrites patched positions in `out` and issues __syncthreads().
217+
template <typename T>
218+
__device__ inline void patch_fl_chunk(uint64_t patches_ptr, T *__restrict out, uint32_t fl_chunk) {
219+
const auto patches = parse_patches<T>(patches_ptr);
220+
221+
for (uint32_t lane = threadIdx.x; lane < FL_LANES<T>; lane += blockDim.x) {
222+
auto slot = fl_chunk * FL_LANES<T> + lane;
223+
assert(slot + 1 < patches.num_lane_offsets);
224+
auto start = patches.lane_offsets[slot];
225+
auto end = patches.lane_offsets[slot + 1];
226+
for (auto i = start; i < end; ++i) {
227+
out[patches.indices[i]] = patches.values[i];
228+
}
229+
}
230+
__syncthreads();
231+
}
232+
233+
/// Apply source patches for all FL chunks in a contiguous region.
234+
/// Overwrites patched positions in `out` and issues __syncthreads().
235+
template <typename T>
236+
__device__ inline void
237+
patch_all_fl_chunks(uint64_t patches_ptr, T *__restrict out, uint32_t stage_len, uint32_t element_offset) {
238+
const auto patches = parse_patches<T>(patches_ptr);
239+
240+
const uint32_t first_chunk = element_offset / FL_CHUNK;
241+
const uint32_t n_chunks = (stage_len + (element_offset % FL_CHUNK) + FL_CHUNK - 1) / FL_CHUNK;
242+
for (uint32_t c = 0; c < n_chunks; ++c) {
243+
T *chunk_base = out + c * FL_CHUNK;
244+
for (uint32_t lane = threadIdx.x; lane < FL_LANES<T>; lane += blockDim.x) {
245+
auto slot = (first_chunk + c) * FL_LANES<T> + lane;
246+
assert(slot + 1 < patches.num_lane_offsets);
247+
auto start = patches.lane_offsets[slot];
248+
auto end = patches.lane_offsets[slot + 1];
249+
for (auto i = start; i < end; ++i) {
250+
chunk_base[patches.indices[i]] = patches.values[i];
251+
}
252+
}
253+
}
254+
__syncthreads();
255+
}
256+
186257
/// Read N values from a source op into `out`.
187258
///
188259
/// Dispatches on `src.op_code` to handle each encoding:
@@ -313,11 +384,17 @@ __device__ void execute_output_stage(T *__restrict output,
313384
block_start + elem_idx,
314385
chunk_len,
315386
src);
316-
constexpr uint32_t FL_CHUNK = 1024; // FastLanes chunk size
317387
const uint32_t align = (block_start + elem_idx + src.params.bitunpack.element_offset) % FL_CHUNK;
318388
smem_src = scratch + align;
319389
// Write barrier: all threads finished bitunpack, safe to read from scratch.
320390
__syncthreads();
391+
392+
// Overwrite patched positions in the decoded scratch buffer.
393+
if (src.params.bitunpack.patches_ptr != 0) {
394+
const uint32_t fl_chunk = static_cast<uint32_t>(
395+
(block_start + elem_idx + src.params.bitunpack.element_offset) / FL_CHUNK);
396+
patch_fl_chunk<T>(src.params.bitunpack.patches_ptr, scratch, fl_chunk);
397+
}
321398
} else {
322399
chunk_len = block_len;
323400
}
@@ -392,12 +469,22 @@ __device__ void execute_input_stage(const Stage &stage, char *__restrict smem) {
392469
const auto &src = stage.source;
393470

394471
if (src.op_code == SourceOp::BITUNPACK) {
472+
T *raw_smem = smem_out;
395473
bitunpack<T>(reinterpret_cast<const T *>(stage.input_ptr), smem_out, 0, stage.len, src);
396-
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
397474
// Write barrier: cooperative bitunpack finished, safe to read
398475
// decoded elements in the scalar-op loop below.
399476
__syncthreads();
400477

478+
// Overwrite exception positions in the decoded buffer with patch values.
479+
if (src.params.bitunpack.patches_ptr != 0) {
480+
patch_all_fl_chunks<T>(src.params.bitunpack.patches_ptr,
481+
raw_smem,
482+
stage.len,
483+
src.params.bitunpack.element_offset);
484+
}
485+
486+
smem_out += src.params.bitunpack.element_offset % SMEM_TILE_SIZE;
487+
401488
if (stage.num_scalar_ops > 0) {
402489
for (uint32_t i = threadIdx.x; i < stage.len; i += blockDim.x) {
403490
T val = smem_out[i];

vortex-cuda/kernels/src/dynamic_dispatch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#pragma once
3131

3232
#include <stdint.h>
33+
#include "patches.h"
3334

3435
/// Compact tag identifying a Vortex PType for GPU dispatch.
3536
///
@@ -108,6 +109,7 @@ union SourceParams {
108109
struct BitunpackParams {
109110
uint8_t bit_width;
110111
uint32_t element_offset; // Sub-byte offset
112+
uint64_t patches_ptr; // device pointer to packed patches buffer (0 = none)
111113
} bitunpack;
112114

113115
/// Copy from global to shared memory.
@@ -157,6 +159,7 @@ union ScalarParams {
157159
struct AlpParams {
158160
float f;
159161
float e;
162+
uint64_t patches_ptr; // device pointer to packed patches buffer (0 = none)
160163
} alp;
161164

162165
/// Dictionary gather: use current value as index into decoded values in smem.

vortex-cuda/kernels/src/fastlanes_common.cuh

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,25 @@
88
// FastLanes ordering array
99
__constant__ int FL_ORDER[] = {0, 4, 2, 6, 1, 5, 3, 7};
1010

11+
// FastLanes organises every 1024-element vector into a transposed layout
12+
// of FL_LANES columns × (1024 / FL_LANES) rows. Each column is a "lane"
13+
// that can be processed independently of every other lane, which is what
14+
// makes all FastLanes encodings (FFOR, DELTA, RLE, ALP, …) fully
15+
// data-parallel. One CUDA thread or one CPU SIMD lane handles one
16+
// FastLanes lane.
17+
//
18+
// Paper: https://ir.cwi.nl/pub/35881/35881.pdf
19+
// Repo: https://github.com/cwida/FastLanes
20+
21+
/// FastLanes chunk size in elements.
22+
constexpr uint32_t FL_CHUNK = 1024;
23+
24+
/// Number of FastLanes lanes for element type T (32 for ≤32-bit, 16 for 64-bit).
25+
template <typename T>
26+
constexpr uint32_t FL_LANES = (sizeof(T) < 8) ? 32 : 16;
27+
1128
// Compute the index in the FastLanes layout.
// Arguments are parenthesised so that expression arguments such as
// INDEX(r + 1, l) expand with the intended precedence. `row` is evaluated
// twice, so do not pass an expression with side effects.
#define INDEX(row, lane) (FL_ORDER[(row) / 8] * 16 + ((row) % 8) * 128 + (lane))
1330

1431
// Create a mask with 'width' bits set
15-
#define MASK(T, width) (((T)1 << width) - 1)
32+
// `width` is parenthesised so that expression arguments (e.g. a conditional)
// are shifted as a whole. `width` must be smaller than the bit width of T.
#define MASK(T, width) (((T)1 << (width)) - 1)

vortex-cuda/kernels/src/patches.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,16 @@
99
extern "C" {
1010
#endif
1111

12+
/// Fixed-size header that begins every packed patches buffer.
///
/// Buffer layout, in order:
///   [PackedPatchesHeader | lane_offsets (u32, N+1 sentinel) | indices (u16) | padding | values (V)]
///
/// Both offsets below are absolute: they are measured in bytes from the
/// first byte of the buffer, i.e. from the header itself.
///
/// A `patches_ptr` of 0 signals no patches.
struct PackedPatchesHeader {
    /// Byte offset from the buffer start to the first index.
    uint32_t indices_byte_offset;
    /// Byte offset from the buffer start to the first patch value.
    uint32_t values_byte_offset;
};
21+
1222
/// Type tag for chunk_offsets pointer.
1323
typedef enum { CO_U8 = 0, CO_U16 = 1, CO_U32 = 2, CO_U64 = 3 } ChunkOffsetType;
1424

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ impl SourceOp {
313313
bitunpack: SourceParams_BitunpackParams {
314314
bit_width,
315315
element_offset: u32::from(element_offset),
316+
patches_ptr: 0,
316317
},
317318
},
318319
}
@@ -393,7 +394,11 @@ impl ScalarOp {
393394
op_code: ScalarOp_ScalarOpCode_ALP,
394395
output_ptype: PTypeTag_PTYPE_F32,
395396
params: ScalarParams {
396-
alp: ScalarParams_AlpParams { f, e },
397+
alp: ScalarParams_AlpParams {
398+
f,
399+
e,
400+
patches_ptr: 0,
401+
},
397402
},
398403
}
399404
}

0 commit comments

Comments
 (0)