Skip to content

Commit bf1527e

Browse files
authored
fix[gpu]: retain device buffers for dyn dispatch kernel (#7980)
fixes: #7838 (comment)

Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 97f21d0 commit bf1527e

1 file changed

Lines changed: 20 additions & 2 deletions

File tree

  • vortex-cuda/src/dynamic_dispatch

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ use vortex::error::VortexResult;
4343
use vortex::error::vortex_bail;
4444
use vortex::error::vortex_err;
4545

46+
use crate::CudaBufferExt;
4647
use crate::CudaDeviceBuffer;
4748
use crate::executor::CudaExecutionCtx;
4849

@@ -479,8 +480,23 @@ impl MaterializedPlan {
479480
shared_mem_bytes: self.shared_mem_bytes,
480481
};
481482

483+
// The packed dispatch plan stores raw input/patch pointers, so those buffers are not
484+
// passed through `LaunchArgs` as `CudaView`s. Record reads explicitly so their drops are
485+
// ordered after this kernel launch on `stream`. The read records borrow from the views,
486+
// so keep both alive until after the kernel is enqueued.
487+
let device_buffer_views = self
488+
.device_buffers
489+
.iter()
490+
.map(|buffer| buffer.cuda_view::<u8>())
491+
.collect::<VortexResult<Vec<_>>>()?;
492+
let stream = ctx.stream().clone();
493+
let device_buffer_read_records = device_buffer_views
494+
.iter()
495+
.map(|view| view.device_ptr(&stream).1)
496+
.collect::<Vec<_>>();
497+
482498
let output_ptr = output_buf.offset_ptr();
483-
let plan_ptr = device_plan.device_ptr(ctx.stream()).0;
499+
let (plan_ptr, plan_read_record) = device_plan.device_ptr(&stream);
484500
let array_len_u64 = len as u64;
485501

486502
ctx.launch_kernel_config(&cuda_function, config, len, |args| {
@@ -489,6 +505,8 @@ impl MaterializedPlan {
489505
args.arg(&plan_ptr);
490506
})?;
491507

508+
drop((device_buffer_read_records, plan_read_record));
509+
492510
Ok(Canonical::Primitive(PrimitiveArray::from_buffer_handle(
493511
BufferHandle::new_device(output_buf.slice_typed::<T>(0..len)),
494512
output_ptype,
@@ -1980,7 +1998,7 @@ mod tests {
19801998

19811999
#[crate::test]
19822000
async fn alp_slice_device_patches() -> VortexResult<()> {
1983-
// Regression test for https://github.com/vortex-data/vortex/issues/7838.
2001+
// Regression test for https://github.com/vortex-data/vortex/issues/7838#issuecomment-4452796116.
19842002
let mut cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
19852003
let len = 4096;
19862004
let exponents = Exponents { e: 0, f: 0 };

0 commit comments

Comments
 (0)