Skip to content

Commit 1b97331

Browse files
committed
refactor: merge upload_patches, uniform match pattern for source/scalar patches
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 40df862 commit 1b97331

3 files changed

Lines changed: 40 additions & 74 deletions

File tree

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ use super::ptype_to_tag;
4848
use super::tag_to_ptype;
4949
use crate::CudaBufferExt;
5050
use crate::CudaExecutionCtx;
51-
use crate::kernel::DevicePatches;
52-
use crate::kernel::load_patches_sync;
53-
use crate::kernel::upload_gpu_patches;
51+
use crate::kernel::upload_patches;
5452

5553
/// A plan whose source buffers have been copied to the device, ready for kernel launch.
5654
pub struct MaterializedPlan {
@@ -386,37 +384,26 @@ impl FusedPlan {
386384
for (stage, smem_byte_offset, len) in &self.stages {
387385
let mut source = stage.source;
388386

389-
// Upload BitPacked patches as a GPUPatches struct if present.
387+
// Upload source patches (e.g. BitPacked exceptions).
390388
if let Some(patches) = &stage.source_patches {
391-
let device_patches = load_patches_sync(patches, ctx)?;
392-
let (gpu_buf, ptr) = upload_gpu_patches(&device_patches, ctx)?;
393-
source.params.bitunpack.patches_ptr = ptr;
394-
// Keep the underlying data buffers and the GPUPatches struct alive.
395-
let DevicePatches {
396-
chunk_offsets,
397-
indices,
398-
values,
399-
..
400-
} = device_patches;
401-
device_buffers.extend([chunk_offsets, indices, values, gpu_buf]);
389+
let (ptr, bufs) = upload_patches(patches, ctx)?;
390+
match source.op_code {
391+
SourceOp_SourceOpCode_BITUNPACK => source.params.bitunpack.patches_ptr = ptr,
392+
_ => unreachable!("patches on unsupported source op"),
393+
}
394+
device_buffers.extend(bufs);
402395
}
403396

404397
// Upload patches for each scalar op that carries them.
405398
let mut scalar_ops: Vec<ScalarOp> = Vec::with_capacity(stage.scalar_ops.len());
406399
for (mut op, patches) in stage.scalar_ops.clone() {
407400
if let Some(patches) = &patches {
408-
let device_patches = load_patches_sync(patches, ctx)?;
409-
let (gpu_buf, ptr) = upload_gpu_patches(&device_patches, ctx)?;
410-
if op.op_code == ScalarOp_ScalarOpCode_ALP {
411-
op.params.alp.patches_ptr = ptr;
401+
let (ptr, bufs) = upload_patches(&patches, ctx)?;
402+
match op.op_code {
403+
ScalarOp_ScalarOpCode_ALP => op.params.alp.patches_ptr = ptr,
404+
_ => unreachable!("patches on unsupported scalar op"),
412405
}
413-
let DevicePatches {
414-
chunk_offsets,
415-
indices,
416-
values,
417-
..
418-
} = device_patches;
419-
device_buffers.extend([chunk_offsets, indices, values, gpu_buf]);
406+
device_buffers.extend(bufs);
420407
}
421408
scalar_ops.push(op);
422409
}

vortex-cuda/src/kernel/mod.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,7 @@ pub use encodings::ZstdKernelPrep;
3434
pub use encodings::zstd_kernel_prepare;
3535
pub(crate) use encodings::*;
3636
pub(crate) use filter::FilterExecutor;
37-
pub(crate) use patches::types::DevicePatches;
38-
pub(crate) use patches::types::load_patches_sync;
39-
pub(crate) use patches::types::upload_gpu_patches;
37+
pub(crate) use patches::types::upload_patches;
4038
pub(crate) use slice::SliceExecutor;
4139

4240
use crate::CudaKernelEvents;

vortex-cuda/src/kernel/patches/types.rs

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,19 @@ pub(crate) fn ptype_to_chunk_offset_type(ptype: PType) -> VortexResult<ChunkOffs
137137
///
138138
/// Canonicalization is done on the CPU via [`LEGACY_SESSION`], then the
139139
/// resulting host buffers are uploaded to the device.
140+
/// Canonicalize patches on the CPU, upload data buffers and a [`GPUPatches`]
141+
/// struct to the device in one step. Returns the device pointer to the
142+
/// `GPUPatches` struct and a vec of buffer handles that must be kept alive
143+
/// for the duration of the kernel launch.
140144
///
141145
/// # Errors
142146
///
143147
/// If the patches do not have `chunk_offsets`.
144148
#[allow(deprecated)]
145-
pub(crate) fn load_patches_sync(
149+
pub(crate) fn upload_patches(
146150
patches: &Patches,
147151
ctx: &CudaExecutionCtx,
148-
) -> VortexResult<DevicePatches> {
152+
) -> VortexResult<(u64, Vec<BufferHandle>)> {
149153
let offset = patches.offset();
150154
let offset_within_chunk = patches.offset_within_chunk().unwrap_or_default();
151155

@@ -159,6 +163,7 @@ pub(crate) fn load_patches_sync(
159163
// Canonicalize chunk_offsets on the CPU
160164
let co_canonical = co.clone().execute::<PrimitiveArray>(&mut exec_ctx)?;
161165
let chunk_offset_ptype = co_canonical.ptype();
166+
let n_chunks = co_canonical.len();
162167
let chunk_offsets = co_canonical.buffer_handle().clone();
163168

164169
// Canonicalize indices and convert to u32
@@ -187,52 +192,29 @@ pub(crate) fn load_patches_sync(
187192
.execute::<PrimitiveArray>(&mut exec_ctx)?;
188193
let values = values_prim.buffer_handle().clone();
189194

190-
// Upload all buffers to the device
195+
// Upload data buffers to the device
191196
let chunk_offsets = ctx.ensure_on_device_sync(chunk_offsets)?;
192197
let indices = ctx.ensure_on_device_sync(indices)?;
193198
let values = ctx.ensure_on_device_sync(values)?;
194199

195-
let num_patches = patches.num_patches();
196-
// n_chunks must match the chunk_offsets array length, not array_len / 1024.
197-
// When patches are sliced, chunk_offsets is sliced to only include chunks
198-
// overlapping the slice range — matching the CPU's patch_chunk which uses
199-
// chunk_offsets_slice.len().
200-
let n_chunks = co_canonical.len();
201-
202-
Ok(DevicePatches {
203-
chunk_offsets,
204-
chunk_offset_ptype,
205-
indices,
206-
values,
207-
offset,
208-
offset_within_chunk,
209-
num_patches,
210-
n_chunks,
211-
})
212-
}
213-
214-
/// Upload a [`GPUPatches`] struct to the device, returning the buffer handle
215-
/// (which must be kept alive) and the device pointer to the struct.
216-
///
217-
/// The caller must also keep the [`DevicePatches`] alive for the duration of
218-
/// the kernel launch, since the `GPUPatches` struct contains device pointers
219-
/// into the individual buffers owned by `DevicePatches`.
220-
pub(crate) fn upload_gpu_patches(
221-
device_patches: &DevicePatches,
222-
ctx: &CudaExecutionCtx,
223-
) -> VortexResult<(BufferHandle, u64)> {
200+
// Build the GPUPatches C struct from device pointers.
224201
// Zero-initialize to avoid uninitialized padding bytes.
225202
let mut gpu_patches: GPUPatches = unsafe { std::mem::zeroed() };
226-
gpu_patches.chunk_offsets = device_patches.chunk_offsets.cuda_device_ptr()? as _;
227-
gpu_patches.chunk_offset_type = ptype_to_chunk_offset_type(device_patches.chunk_offset_ptype)?;
228-
gpu_patches.indices = device_patches.indices.cuda_device_ptr()? as _;
229-
gpu_patches.values = device_patches.values.cuda_device_ptr()? as _;
203+
gpu_patches.chunk_offsets = chunk_offsets.cuda_device_ptr()? as _;
204+
gpu_patches.chunk_offset_type = ptype_to_chunk_offset_type(chunk_offset_ptype)?;
205+
gpu_patches.indices = indices.cuda_device_ptr()? as _;
206+
gpu_patches.values = values.cuda_device_ptr()? as _;
207+
let num_patches = patches.num_patches();
230208
#[expect(clippy::cast_possible_truncation)]
231209
{
232-
gpu_patches.offset = device_patches.offset as u32;
233-
gpu_patches.offset_within_chunk = device_patches.offset_within_chunk as u32;
234-
gpu_patches.num_patches = device_patches.num_patches as u32;
235-
gpu_patches.n_chunks = device_patches.n_chunks as u32;
210+
gpu_patches.offset = offset as u32;
211+
gpu_patches.offset_within_chunk = offset_within_chunk as u32;
212+
gpu_patches.num_patches = num_patches as u32;
213+
// n_chunks must match the chunk_offsets array length, not array_len / 1024.
214+
// When patches are sliced, chunk_offsets is sliced to only include chunks
215+
// overlapping the slice range — matching the CPU's patch_chunk which uses
216+
// chunk_offsets_slice.len().
217+
gpu_patches.n_chunks = n_chunks as u32;
236218
}
237219

238220
// Serialize the repr(C) struct to bytes and upload to the device.
@@ -242,14 +224,13 @@ pub(crate) fn upload_gpu_patches(
242224
size_of::<GPUPatches>(),
243225
)
244226
};
245-
246227
let mut buf =
247228
ByteBufferMut::with_capacity_aligned(size_of::<GPUPatches>(), Alignment::of::<u64>());
248229
buf.extend_from_slice(bytes);
249-
let host_buf = BufferHandle::new_host(buf.freeze());
250-
let device_buf = ctx.ensure_on_device_sync(host_buf)?;
251-
let ptr = device_buf.cuda_device_ptr()?;
252-
Ok((device_buf, ptr))
230+
let gpu_buf = ctx.ensure_on_device_sync(BufferHandle::new_host(buf.freeze()))?;
231+
let ptr = gpu_buf.cuda_device_ptr()?;
232+
233+
Ok((ptr, vec![chunk_offsets, indices, values, gpu_buf]))
253234
}
254235

255236
#[cfg(test)]

0 commit comments

Comments
 (0)