Skip to content

Commit bd93340

Browse files
authored
GPU kernel for sorted patches with chunk_offsets (#7440)
1 parent ad087b6 commit bd93340

File tree

14 files changed

+409
-449
lines changed

14 files changed

+409
-449
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vortex-cuda/Cargo.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ cudarc = { workspace = true, features = ["f16"] }
3030
futures = { workspace = true, features = ["executor"] }
3131
itertools = { workspace = true }
3232
kanal = { workspace = true }
33+
num-traits = { workspace = true }
3334
object_store = { workspace = true, features = ["fs"] }
3435
parking_lot = { workspace = true }
3536
prost = { workspace = true }
@@ -89,7 +90,3 @@ harness = false
8990
[[bench]]
9091
name = "throughput_cuda"
9192
harness = false
92-
93-
[[bench]]
94-
name = "transpose_patches"
95-
harness = false

vortex-cuda/benches/transpose_patches.rs

Lines changed: 0 additions & 81 deletions
This file was deleted.

vortex-cuda/kernels/src/bit_unpack_16.cu

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,28 @@
55
template <int BW>
66
__device__ void _bit_unpack_16_device(const uint16_t *__restrict in, uint16_t *__restrict out, uint16_t reference, int thread_idx, GPUPatches& patches) {
77
__shared__ uint16_t shared_out[1024];
8+
9+
// Step 1: Unpack into shared memory
810
#pragma unroll
911
for (int i = 0; i < 2; i++) {
1012
_bit_unpack_16_lane<BW>(in, shared_out, reference, thread_idx * 2 + i);
1113
}
1214
__syncwarp();
15+
16+
// Step 2: Apply patches to shared memory in parallel
1317
PatchesCursor<uint16_t> cursor(patches, blockIdx.x, thread_idx, 32);
1418
auto patch = cursor.next();
19+
while (patch.index != 1024) {
20+
shared_out[patch.index] = patch.value;
21+
patch = cursor.next();
22+
}
23+
__syncwarp();
24+
25+
// Step 3: Copy to global memory
26+
#pragma unroll
1527
for (int i = 0; i < 32; i++) {
1628
auto idx = i * 32 + thread_idx;
17-
if (idx == patch.index) {
18-
out[idx] = patch.value;
19-
patch = cursor.next();
20-
} else {
21-
out[idx] = shared_out[idx];
22-
}
29+
out[idx] = shared_out[idx];
2330
}
2431
}
2532

vortex-cuda/kernels/src/bit_unpack_32.cu

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,28 @@
55
template <int BW>
66
__device__ void _bit_unpack_32_device(const uint32_t *__restrict in, uint32_t *__restrict out, uint32_t reference, int thread_idx, GPUPatches& patches) {
77
__shared__ uint32_t shared_out[1024];
8+
9+
// Step 1: Unpack into shared memory
810
#pragma unroll
911
for (int i = 0; i < 1; i++) {
1012
_bit_unpack_32_lane<BW>(in, shared_out, reference, thread_idx * 1 + i);
1113
}
1214
__syncwarp();
15+
16+
// Step 2: Apply patches to shared memory in parallel
1317
PatchesCursor<uint32_t> cursor(patches, blockIdx.x, thread_idx, 32);
1418
auto patch = cursor.next();
19+
while (patch.index != 1024) {
20+
shared_out[patch.index] = patch.value;
21+
patch = cursor.next();
22+
}
23+
__syncwarp();
24+
25+
// Step 3: Copy to global memory
26+
#pragma unroll
1527
for (int i = 0; i < 32; i++) {
1628
auto idx = i * 32 + thread_idx;
17-
if (idx == patch.index) {
18-
out[idx] = patch.value;
19-
patch = cursor.next();
20-
} else {
21-
out[idx] = shared_out[idx];
22-
}
29+
out[idx] = shared_out[idx];
2330
}
2431
}
2532

vortex-cuda/kernels/src/bit_unpack_64.cu

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,28 @@
55
template <int BW>
66
__device__ void _bit_unpack_64_device(const uint64_t *__restrict in, uint64_t *__restrict out, uint64_t reference, int thread_idx, GPUPatches& patches) {
77
__shared__ uint64_t shared_out[1024];
8+
9+
// Step 1: Unpack into shared memory
810
#pragma unroll
911
for (int i = 0; i < 1; i++) {
1012
_bit_unpack_64_lane<BW>(in, shared_out, reference, thread_idx * 1 + i);
1113
}
1214
__syncwarp();
15+
16+
// Step 2: Apply patches to shared memory in parallel
1317
PatchesCursor<uint64_t> cursor(patches, blockIdx.x, thread_idx, 16);
1418
auto patch = cursor.next();
19+
while (patch.index != 1024) {
20+
shared_out[patch.index] = patch.value;
21+
patch = cursor.next();
22+
}
23+
__syncwarp();
24+
25+
// Step 3: Copy to global memory
26+
#pragma unroll
1527
for (int i = 0; i < 64; i++) {
1628
auto idx = i * 16 + thread_idx;
17-
if (idx == patch.index) {
18-
out[idx] = patch.value;
19-
patch = cursor.next();
20-
} else {
21-
out[idx] = shared_out[idx];
22-
}
29+
out[idx] = shared_out[idx];
2330
}
2431
}
2532

vortex-cuda/kernels/src/bit_unpack_8.cu

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,28 @@
55
template <int BW>
66
__device__ void _bit_unpack_8_device(const uint8_t *__restrict in, uint8_t *__restrict out, uint8_t reference, int thread_idx, GPUPatches& patches) {
77
__shared__ uint8_t shared_out[1024];
8+
9+
// Step 1: Unpack into shared memory
810
#pragma unroll
911
for (int i = 0; i < 4; i++) {
1012
_bit_unpack_8_lane<BW>(in, shared_out, reference, thread_idx * 4 + i);
1113
}
1214
__syncwarp();
15+
16+
// Step 2: Apply patches to shared memory in parallel
1317
PatchesCursor<uint8_t> cursor(patches, blockIdx.x, thread_idx, 32);
1418
auto patch = cursor.next();
19+
while (patch.index != 1024) {
20+
shared_out[patch.index] = patch.value;
21+
patch = cursor.next();
22+
}
23+
__syncwarp();
24+
25+
// Step 3: Copy to global memory
26+
#pragma unroll
1527
for (int i = 0; i < 32; i++) {
1628
auto idx = i * 32 + thread_idx;
17-
if (idx == patch.index) {
18-
out[idx] = patch.value;
19-
patch = cursor.next();
20-
} else {
21-
out[idx] = shared_out[idx];
22-
}
29+
out[idx] = shared_out[idx];
2330
}
2431
}
2532

vortex-cuda/kernels/src/patches.cuh

Lines changed: 69 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@
55

66
#include "patches.h"
77

8+
/// Load a chunk offset value, dispatching on the runtime type.
9+
__device__ inline uint32_t load_chunk_offset(const GPUPatches &patches, uint32_t idx) {
10+
switch (patches.chunk_offset_type) {
11+
case CO_U8:
12+
return reinterpret_cast<const uint8_t *>(patches.chunk_offsets)[idx];
13+
case CO_U16:
14+
return reinterpret_cast<const uint16_t *>(patches.chunk_offsets)[idx];
15+
case CO_U32:
16+
return reinterpret_cast<const uint32_t *>(patches.chunk_offsets)[idx];
17+
case CO_U64:
18+
return static_cast<uint32_t>(reinterpret_cast<const uint64_t *>(patches.chunk_offsets)[idx]);
19+
}
20+
return 0;
21+
}
22+
823
/// A single patch: a within-chunk index and its replacement value.
924
/// A sentinel patch has index == 1024, which can never match a valid
1025
/// within-chunk position (0–1023).
@@ -14,54 +29,87 @@ struct Patch {
1429
T value;
1530
};
1631

17-
/// Cursor for iterating over a single lane's patches within a chunk.
32+
/// Cursor for iterating over a thread's portion of patches within a chunk.
1833
///
19-
/// Usage in the generated merge-loop:
34+
/// Patches are divided evenly among threads. Each thread applies its patches
35+
/// to shared memory, then all threads sync and copy to global memory.
36+
///
37+
/// Usage in the generated kernel:
2038
///
2139
/// PatchesCursor<uint32_t> cursor(patches, blockIdx.x, thread_idx, 32);
2240
/// auto patch = cursor.next();
23-
/// for (int i = 0; i < 32; i++) {
24-
/// auto idx = i * 32 + thread_idx;
25-
/// if (idx == patch.index) {
26-
/// out[idx] = patch.value;
27-
/// patch = cursor.next();
28-
/// } else {
29-
/// out[idx] = shared_out[idx];
30-
/// }
41+
/// while (patch.index != 1024) {
42+
/// shared_out[patch.index] = patch.value;
43+
/// patch = cursor.next();
3144
/// }
3245
template <typename T>
3346
class PatchesCursor {
3447
public:
35-
/// Construct a cursor positioned at the patches for the given (chunk, lane).
36-
/// n_lanes is a compile-time constant emitted by the code generator (16 or 32).
37-
__device__ PatchesCursor(const GPUPatches &patches, uint32_t chunk, uint32_t lane, uint32_t n_lanes) {
38-
if (patches.lane_offsets == nullptr) {
48+
/// Construct a cursor for this thread's portion of patches in the chunk.
49+
__device__
50+
PatchesCursor(const GPUPatches &patches, uint32_t chunk, uint32_t thread_idx, uint32_t n_threads) {
51+
if (patches.chunk_offsets == nullptr) {
3952
indices = nullptr;
4053
values = nullptr;
4154
remaining = 0;
4255
return;
4356
}
44-
auto slot = chunk * n_lanes + lane;
45-
auto start = patches.lane_offsets[slot];
46-
remaining = patches.lane_offsets[slot + 1] - start;
57+
58+
// mirrors the logic from vortex-array/src/arrays/primitive/array/patch.rs
59+
60+
// Compute base_offset from the first chunk offset.
61+
uint32_t base_offset = load_chunk_offset(patches, 0);
62+
63+
uint32_t patches_start_idx = load_chunk_offset(patches, chunk) - base_offset;
64+
patches_start_idx -= min(patches_start_idx, patches.offset_within_chunk);
65+
66+
// calculate the ending index.
67+
uint32_t patches_end_idx;
68+
if ((chunk + 1) < patches.n_chunks) {
69+
patches_end_idx = load_chunk_offset(patches, chunk + 1) - base_offset;
70+
// clamp the subtraction at zero so it cannot underflow when the slice offset exceeds this chunk's end index
71+
patches_end_idx -= min(patches_end_idx, patches.offset_within_chunk);
72+
} else {
73+
patches_end_idx = patches.num_patches;
74+
}
75+
76+
// calculate how many patches are in the chunk
77+
uint32_t num_patches = patches_end_idx - patches_start_idx;
78+
79+
// Divide patches among threads (ceil division)
80+
uint32_t patches_per_thread = (num_patches + n_threads - 1) / n_threads;
81+
uint32_t my_start = min(thread_idx * patches_per_thread, num_patches);
82+
uint32_t my_end = min((thread_idx + 1) * patches_per_thread, num_patches);
83+
84+
uint32_t start = patches_start_idx + my_start;
85+
remaining = my_end - my_start;
4786
indices = patches.indices + start;
4887
values = reinterpret_cast<const T *>(patches.values) + start;
88+
89+
// The iterator returns indices relative to the start of the chunk.
90+
// `chunk_base` is the global index of the first element of this chunk, accounting
91+
// for the slice offset.
92+
chunk_base = chunk * 1024 + patches.offset;
93+
chunk_base -= min(chunk_base, patches.offset % 1024);
4994
}
5095

51-
/// Return the current patch and advance, or a sentinel {1024, 0} if exhausted.
96+
/// Return the current patch (with within-chunk index) and advance,
97+
/// or a sentinel {1024, 0} if exhausted.
5298
__device__ Patch<T> next() {
5399
if (remaining == 0) {
54100
return {1024, T {}};
55101
}
56-
Patch<T> patch = {*indices, *values};
102+
uint16_t within_chunk = static_cast<uint16_t>(*indices - chunk_base);
103+
Patch<T> patch = {within_chunk, *values};
57104
indices++;
58105
values++;
59106
remaining--;
60107
return patch;
61108
}
62109

63110
private:
64-
const uint16_t *indices;
111+
const uint32_t *indices;
65112
const T *values;
66113
uint8_t remaining;
67-
};
114+
uint32_t chunk_base;
115+
};

0 commit comments

Comments
 (0)