
#include "patches.h"

/// Load a chunk-offset value, dispatching on the runtime storage width.
///
/// Offsets stored as u64 are narrowed to 32 bits via static_cast, exactly as
/// the cursor arithmetic below expects. An unknown type tag yields 0.
__device__ inline uint32_t load_chunk_offset(const GPUPatches &patches, uint32_t idx) {
    switch (patches.chunk_offset_type) {
        case CO_U8:  return reinterpret_cast<const uint8_t *>(patches.chunk_offsets)[idx];
        case CO_U16: return reinterpret_cast<const uint16_t *>(patches.chunk_offsets)[idx];
        case CO_U32: return reinterpret_cast<const uint32_t *>(patches.chunk_offsets)[idx];
        case CO_U64: return static_cast<uint32_t>(reinterpret_cast<const uint64_t *>(patches.chunk_offsets)[idx]);
        default:     return 0; // unreachable for a well-formed chunk_offset_type
    }
}
22+
/// A single patch: a within-chunk index and its replacement value.
/// A sentinel patch has index == 1024, which can never match a valid
/// within-chunk position (0–1023).
@@ -14,54 +29,87 @@ struct Patch {
1429 T value;
1530};
1631
/// Cursor for iterating over a thread's portion of patches within a chunk.
///
/// Patches are divided evenly among threads. Each thread applies its patches
/// to shared memory, then all threads sync and copy to global memory.
///
/// Usage in the generated kernel:
///
///     PatchesCursor<uint32_t> cursor(patches, blockIdx.x, thread_idx, 32);
///     auto patch = cursor.next();
///     while (patch.index != 1024) {
///         shared_out[patch.index] = patch.value;
///         patch = cursor.next();
///     }
template <typename T>
class PatchesCursor {
public:
    /// Construct a cursor over this thread's share of the patches falling
    /// within `chunk`.
    ///
    /// The chunk's patches are split into `n_threads` contiguous runs of
    /// (roughly) equal size; this cursor yields the run belonging to
    /// `thread_idx`. A null `chunk_offsets` pointer means "no patches" and
    /// produces an immediately-exhausted cursor.
    __device__
    PatchesCursor(const GPUPatches &patches, uint32_t chunk, uint32_t thread_idx, uint32_t n_threads) {
        if (patches.chunk_offsets == nullptr) {
            indices = nullptr;
            values = nullptr;
            remaining = 0;
            return;
        }

        // Mirrors the logic in vortex-array/src/arrays/primitive/array/patch.rs.

        // chunk_offsets holds absolute positions; subtract the first entry so
        // everything below is relative to the start of the patch arrays.
        uint32_t base_offset = load_chunk_offset(patches, 0);

        // First patch belonging to this chunk, adjusted (clamped at 0 via min)
        // for a slice that starts partway into the first chunk.
        uint32_t patches_start_idx = load_chunk_offset(patches, chunk) - base_offset;
        patches_start_idx -= min(patches_start_idx, patches.offset_within_chunk);

        // One past the last patch belonging to this chunk. The final chunk has
        // no successor offset, so it ends at num_patches.
        uint32_t patches_end_idx;
        if ((chunk + 1) < patches.n_chunks) {
            patches_end_idx = load_chunk_offset(patches, chunk + 1) - base_offset;
            patches_end_idx -= min(patches_end_idx, patches.offset_within_chunk);
        } else {
            patches_end_idx = patches.num_patches;
        }

        // How many patches land in this chunk.
        uint32_t num_patches = patches_end_idx - patches_start_idx;

        // Split evenly among threads (ceiling division); clamp so threads past
        // the end receive an empty range.
        uint32_t patches_per_thread = (num_patches + n_threads - 1) / n_threads;
        uint32_t my_start = min(thread_idx * patches_per_thread, num_patches);
        uint32_t my_end = min((thread_idx + 1) * patches_per_thread, num_patches);

        uint32_t start = patches_start_idx + my_start;
        // `remaining` is uint32_t: a uint8_t here would silently truncate
        // counts above 255, which is reachable whenever one thread's share of
        // a chunk's patches exceeds 255 (e.g. small n_threads).
        remaining = my_end - my_start;
        indices = patches.indices + start;
        values = reinterpret_cast<const T *>(patches.values) + start;

        // `next()` returns indices relative to the start of the chunk.
        // `chunk_base` is the global index of the chunk's first element,
        // accounting for the slice offset (chunks are 1024 elements wide).
        chunk_base = chunk * 1024 + patches.offset;
        chunk_base -= min(chunk_base, patches.offset % 1024);
    }

    /// Return the current patch with its within-chunk index (0–1023) and
    /// advance, or the sentinel {1024, T{}} once exhausted.
    __device__ Patch<T> next() {
        if (remaining == 0) {
            return {1024, T{}};
        }
        // Stored patch indices are global; translate to a within-chunk position.
        uint16_t within_chunk = static_cast<uint16_t>(*indices - chunk_base);
        Patch<T> patch = {within_chunk, *values};
        indices++;
        values++;
        remaining--;
        return patch;
    }

private:
    const uint32_t *indices;  // next global patch index to emit
    const T *values;          // replacement values, parallel to `indices`
    uint32_t remaining;       // patches left for this thread (widened from uint8_t to avoid truncation)
    uint32_t chunk_base;      // global index of this chunk's first element
};