Skip to content

Commit 799fdf1

Browse files
committed
fix[gpu]: retain device buffers for dyn dispatch kernel
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent d45538e commit 799fdf1

2 files changed

Lines changed: 22 additions & 5 deletions

File tree

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use vortex::error::VortexResult;
4343
use vortex::error::vortex_bail;
4444
use vortex::error::vortex_err;
4545

46+
use crate::CudaBufferExt;
4647
use crate::CudaDeviceBuffer;
4748
use crate::executor::CudaExecutionCtx;
4849

@@ -479,8 +480,23 @@ impl MaterializedPlan {
479480
shared_mem_bytes: self.shared_mem_bytes,
480481
};
481482

483+
// The packed dispatch plan stores raw input/patch pointers, so those buffers are not
484+
// passed through `LaunchArgs` as `CudaView`s. Record reads explicitly so the buffers' drops are
485+
// ordered after this kernel launch on `stream`. The read records borrow from the views,
486+
// so keep both alive until after the kernel is enqueued.
487+
let device_buffer_views = self
488+
.device_buffers
489+
.iter()
490+
.map(|buffer| buffer.cuda_view::<u8>())
491+
.collect::<VortexResult<Vec<_>>>()?;
492+
let stream = ctx.stream().clone();
493+
let device_buffer_read_records = device_buffer_views
494+
.iter()
495+
.map(|view| view.device_ptr(&stream).1)
496+
.collect::<Vec<_>>();
497+
482498
let output_ptr = output_buf.offset_ptr();
483-
let plan_ptr = device_plan.device_ptr(ctx.stream()).0;
499+
let (plan_ptr, plan_read_record) = device_plan.device_ptr(&stream);
484500
let array_len_u64 = len as u64;
485501

486502
ctx.launch_kernel_config(&cuda_function, config, len, |args| {
@@ -489,6 +505,8 @@ impl MaterializedPlan {
489505
args.arg(&plan_ptr);
490506
})?;
491507

508+
drop((device_buffer_read_records, plan_read_record));
509+
492510
Ok(Canonical::Primitive(PrimitiveArray::from_buffer_handle(
493511
BufferHandle::new_device(output_buf.slice_typed::<T>(0..len)),
494512
output_ptype,
@@ -1980,7 +1998,7 @@ mod tests {
19801998

19811999
#[crate::test]
19822000
async fn alp_slice_device_patches() -> VortexResult<()> {
1983-
// Regression test for https://github.com/vortex-data/vortex/issues/7838.
2001+
// Regression test for <https://github.com/vortex-data/vortex/issues/7838#issuecomment-4452796116>.
19842002
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
19852003
let len = 4096;
19862004
let exponents = Exponents { e: 0, f: 0 };

vortex-cuda/src/kernel/patches/types.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,7 @@ pub(crate) fn slice_device_patches(
220220

221221
#[cfg(test)]
222222
mod tests {
223+
use vortex::array::buffer::BufferHandle;
223224
use vortex::array::validity::Validity::NonNullable;
224225
use vortex::buffer::Buffer;
225226
use vortex::session::VortexSession;
@@ -329,9 +330,7 @@ mod tests {
329330
slice_device_patches(&patches, 1024..3000, &mut device_patches);
330331

331332
let actual = PrimitiveArray::from_buffer_handle(
332-
vortex::array::buffer::BufferHandle::new_host(
333-
device_patches.chunk_offsets.to_host().await,
334-
),
333+
BufferHandle::new_host(device_patches.chunk_offsets.to_host().await),
335334
device_patches.chunk_offset_ptype,
336335
NonNullable,
337336
)

0 commit comments

Comments
 (0)