Skip to content

Commit 6c7cd94

Browse files
authored
keep DropOnSync alive until after dispatching for cuda slices (#6673)
## Summary cuda device pointers come with `SyncOnDrop` that is used to synchronise reads and writes to the underlying cuda buffer. We should keep those alive until after dispatching the read or write work on them. For our single stream case this is fine, but if we were to have multiple streams accessing the same buffer these would be a problem --------- Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent ea7804c commit 6c7cd94

9 files changed

Lines changed: 169 additions & 83 deletions

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,14 +51,16 @@ const BENCH_ARGS: &[(usize, &str)] = &[
5151
/// Launch the dynamic_dispatch kernel and return GPU-timed duration.
5252
fn run_timed(
5353
cuda_ctx: &mut CudaExecutionCtx,
54-
output_ptr: u64,
5554
array_len: usize,
55+
output_buf: &CudaDeviceBuffer,
5656
device_plan: &Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
5757
shared_mem_bytes: u32,
5858
) -> VortexResult<Duration> {
5959
let cuda_function = cuda_ctx.load_function("dynamic_dispatch", &["u32"])?;
6060
let array_len_u64 = array_len as u64;
61-
let plan_ptr = device_plan.device_ptr(cuda_ctx.stream()).0;
61+
let output_view = output_buf.as_view::<u32>();
62+
let (output_ptr, record_output) = output_view.device_ptr(cuda_ctx.stream());
63+
let (plan_ptr, record_plan) = device_plan.device_ptr(cuda_ctx.stream());
6264

6365
let stream = cuda_ctx.stream();
6466
let ctx = stream.context();
@@ -86,6 +88,7 @@ fn run_timed(
8688
.launch(config)
8789
.map_err(|e| vortex_err!("kernel launch failed: {e}"))?;
8890
}
91+
drop((record_output, record_plan));
8992

9093
let stream = cuda_ctx.stream();
9194
let ctx = stream.context();
@@ -105,11 +108,10 @@ fn run_timed(
105108
struct BenchRunner {
106109
_plan: DynamicDispatchPlan,
107110
smem_bytes: u32,
108-
output_ptr: u64,
109111
len: usize,
110112
// Keep alive
111-
_device_plan: Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
112-
_output_buf: CudaDeviceBuffer,
113+
device_plan: Arc<cudarc::driver::CudaSlice<DynamicDispatchPlan>>,
114+
output_buf: CudaDeviceBuffer,
113115
_plan_buffers: Vec<vortex::array::buffer::BufferHandle>,
114116
}
115117

@@ -130,15 +132,13 @@ impl BenchRunner {
130132
.device_alloc::<u32>(len.next_multiple_of(1024))
131133
.expect("alloc output");
132134
let output_buf = CudaDeviceBuffer::new(output_slice);
133-
let output_ptr = output_buf.as_view::<u32>().device_ptr(cuda_ctx.stream()).0;
134135

135136
Self {
136137
_plan: plan,
137138
smem_bytes,
138-
output_ptr,
139139
len,
140-
_device_plan: device_plan,
141-
_output_buf: output_buf,
140+
device_plan,
141+
output_buf,
142142
_plan_buffers: plan_buffers,
143143
}
144144
}
@@ -147,9 +147,9 @@ impl BenchRunner {
147147
cuda_ctx.stream().synchronize().unwrap();
148148
run_timed(
149149
cuda_ctx,
150-
self.output_ptr,
151150
self.len,
152-
&self._device_plan,
151+
&self.output_buf,
152+
&self.device_plan,
153153
self.smem_bytes,
154154
)
155155
.unwrap()

vortex-cuda/benches/filter_cuda.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,33 @@ async fn run_filter_timed<T: CubFilterable + cudarc::driver::DeviceRepr>(
8686

8787
// Get raw pointers
8888
let stream_ptr = stream.cu_stream() as cudaStream_t;
89-
let d_input_ptr = d_input.device_ptr(stream).0 as *const T;
90-
let d_bitmask_ptr = d_bitmask.device_ptr(stream).0 as *const u8;
91-
let d_output_ptr = d_output.device_ptr_mut(stream).0 as *mut T;
92-
let d_temp_ptr = d_temp.device_ptr_mut(stream).0 as *mut c_void;
93-
let d_num_selected_ptr = d_num_selected.device_ptr_mut(stream).0 as *mut i64;
89+
let (d_input_ptr, record_d_input) = d_input.device_ptr(stream);
90+
let (d_bitmask_ptr, record_d_bitmask) = d_bitmask.device_ptr(stream);
91+
let (d_output_ptr, record_d_output) = d_output.device_ptr_mut(stream);
92+
let (d_temp_ptr, record_d_temp) = d_temp.device_ptr_mut(stream);
93+
let (d_num_selected_ptr, record_d_num_selected) = d_num_selected.device_ptr_mut(stream);
9494

9595
unsafe {
9696
T::filter_bitmask(
97-
d_temp_ptr,
97+
d_temp_ptr as *mut c_void,
9898
temp_bytes,
99-
d_input_ptr,
100-
d_bitmask_ptr,
99+
d_input_ptr as *const T,
100+
d_bitmask_ptr as *const u8,
101101
0, // bit_offset
102-
d_output_ptr,
103-
d_num_selected_ptr,
102+
d_output_ptr as *mut T,
103+
d_num_selected_ptr as *mut i64,
104104
num_items,
105105
stream_ptr,
106106
)
107107
.map_err(|e| vortex_err!("Filter kernel execution failed: {}", e))?;
108108
}
109+
drop((
110+
record_d_input,
111+
record_d_bitmask,
112+
record_d_output,
113+
record_d_temp,
114+
record_d_num_selected,
115+
));
109116

110117
let end_event = ctx
111118
.new_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))

vortex-cuda/benches/throughput_cuda.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ fn transfer_mix_timed(
5151
.device_alloc::<u32>((output_bytes / size_of::<u32>()).max(1))
5252
.unwrap();
5353

54-
let src_ptr = dtod_src.device_ptr(&in_stream).0;
55-
let dst_ptr = dtod_dst.device_ptr_mut(&in_stream).0;
56-
let memset_ptr = memset_dst.device_ptr_mut(&out_stream).0;
54+
let (src_ptr, record_src) = dtod_src.device_ptr(&in_stream);
55+
let (dst_ptr, record_dst) = dtod_dst.device_ptr_mut(&in_stream);
56+
let (memset_ptr, record_memset) = memset_dst.device_ptr_mut(&out_stream);
5757

5858
in_stream.synchronize().unwrap();
5959
out_stream.synchronize().unwrap();
@@ -76,6 +76,7 @@ fn transfer_mix_timed(
7676
.unwrap();
7777
}
7878
}
79+
drop((record_src, record_dst, record_memset));
7980

8081
let end_in = in_stream
8182
.record_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))

vortex-cuda/benches/zstd_cuda.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,28 @@ async fn execute_zstd_kernel(
7878
.record(stream)
7979
.map_err(|e| vortex_err!("Failed to record start event: {:?}", e))?;
8080

81+
let (device_actual_sizes_ptr, record_actual_sizes) =
82+
exec.device_actual_sizes.device_ptr_mut(stream);
83+
let (nvcomp_temp_buffer_ptr, record_temp) = exec.nvcomp_temp_buffer.device_ptr_mut(stream);
84+
let (device_statuses_ptr, record_statuses) = exec.device_statuses.device_ptr_mut(stream);
85+
8186
// Launch the kernel
8287
unsafe {
8388
nvcomp_zstd::decompress_async(
8489
exec.frame_ptrs_ptr as _,
8590
exec.frame_sizes_ptr as _,
8691
exec.output_sizes_ptr as _,
87-
exec.device_actual_sizes.device_ptr_mut(stream).0 as _,
92+
device_actual_sizes_ptr as _,
8893
exec.num_frames,
89-
exec.nvcomp_temp_buffer.device_ptr_mut(stream).0 as _,
94+
nvcomp_temp_buffer_ptr as _,
9095
exec.nvcomp_temp_buffer_size,
9196
exec.output_ptrs_ptr as _,
92-
exec.device_statuses.device_ptr_mut(stream).0 as _,
97+
device_statuses_ptr as _,
9398
stream.cu_stream().cast(),
9499
)
95100
.map_err(|e| vortex_err!("nvcomp decompress_async failed: {}", e))?;
96101
}
102+
drop((record_actual_sizes, record_temp, record_statuses));
97103

98104
let end_event = ctx
99105
.new_event(Some(CUevent_flags::CU_EVENT_BLOCKING_SYNC))

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ mod tests {
316316
data: &[u32],
317317
) -> VortexResult<(u64, Arc<cudarc::driver::CudaSlice<u32>>)> {
318318
let device_buf = Arc::new(cuda_ctx.stream().clone_htod(data).expect("htod"));
319-
let ptr = device_buf.device_ptr(cuda_ctx.stream()).0;
319+
let (ptr, _) = device_buf.device_ptr(cuda_ctx.stream());
320320
Ok((ptr, device_buf))
321321
}
322322

@@ -372,15 +372,16 @@ mod tests {
372372
.device_alloc::<u32>(output_len)
373373
.vortex_expect("alloc output");
374374
let output_buf = CudaDeviceBuffer::new(output_slice);
375-
let output_ptr = output_buf.as_view::<u32>().device_ptr(cuda_ctx.stream()).0;
375+
let output_view = output_buf.as_view::<u32>();
376+
let (output_ptr, record_output) = output_view.device_ptr(cuda_ctx.stream());
376377

377378
let device_plan = Arc::new(
378379
cuda_ctx
379380
.stream()
380381
.clone_htod(std::slice::from_ref(plan))
381382
.expect("copy plan to device"),
382383
);
383-
let plan_ptr = device_plan.device_ptr(cuda_ctx.stream()).0;
384+
let (plan_ptr, record_plan) = device_plan.device_ptr(cuda_ctx.stream());
384385
let array_len_u64 = output_len as u64;
385386

386387
cuda_ctx.stream().synchronize().expect("sync");
@@ -402,6 +403,7 @@ mod tests {
402403
unsafe {
403404
launch_builder.launch(config).expect("kernel launch");
404405
}
406+
drop((record_output, record_plan));
405407

406408
Ok(cuda_ctx
407409
.stream()

vortex-cuda/src/kernel/encodings/zstd.rs

Lines changed: 56 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,15 @@ pub async fn zstd_kernel_prepare(
122122
// Device pointers for all compressed frames.
123123
let frame_ptrs = device_frame_handles
124124
.iter()
125-
.map(|handle| {
126-
handle
127-
.cuda_view::<u8>()
128-
.map(|view| view.device_ptr(ctx.stream()).0)
129-
})
125+
.map(|handle| handle.cuda_device_ptr())
130126
.collect::<VortexResult<Vec<_>>>()?;
131127

132128
// Build output_ptrs from output base pointer + offsets.
133129
let output_ptrs = {
134-
let base_ptr = device_output.device_ptr(ctx.stream()).0;
130+
// We only need the allocation address here to build pointer metadata.
131+
// The actual device write is tracked by `record_device_output` around
132+
// `decompress_async`, so this guard can be dropped immediately.
133+
let (base_ptr, _) = device_output.device_ptr(ctx.stream());
135134
output_sizes
136135
.iter()
137136
.scan(0u64, |offset, &size| {
@@ -155,16 +154,10 @@ pub async fn zstd_kernel_prepare(
155154
let device_statuses: CudaSlice<nvcompStatus_t> = ctx.device_alloc(num_frames)?;
156155
let nvcomp_temp_buffer: CudaSlice<u8> = ctx.device_alloc(nvcomp_temp_buffer_size)?;
157156

158-
macro_rules! device_ptr {
159-
($handle:expr, $type:ty) => {
160-
$handle.cuda_view::<$type>()?.device_ptr(ctx.stream()).0
161-
};
162-
}
163-
164-
let frame_ptrs_ptr = device_ptr!(frame_ptrs_handle, u64);
165-
let frame_sizes_ptr = device_ptr!(frame_sizes_handle, usize);
166-
let output_sizes_ptr = device_ptr!(output_sizes_handle, usize);
167-
let output_ptrs_ptr = device_ptr!(output_ptrs_handle, u64);
157+
let frame_ptrs_ptr = frame_ptrs_handle.cuda_device_ptr()?;
158+
let frame_sizes_ptr = frame_sizes_handle.cuda_device_ptr()?;
159+
let output_sizes_ptr = output_sizes_handle.cuda_device_ptr()?;
160+
let output_ptrs_ptr = output_ptrs_handle.cuda_device_ptr()?;
168161

169162
// Return device pointers and handles to keep device memory alive
170163
Ok(ZstdKernelPrep {
@@ -252,25 +245,65 @@ async fn decode_zstd(array: ZstdArray, ctx: &mut CudaExecutionCtx) -> VortexResu
252245
let mut exec = zstd_kernel_prepare(frames, &metadata, ctx).await?;
253246

254247
let stream = ctx.stream();
248+
let frame_views = exec
249+
.device_frame_handles
250+
.iter()
251+
.map(|handle| handle.cuda_view::<u8>())
252+
.collect::<VortexResult<Vec<_>>>()?;
253+
let mut frame_ptr_records = Vec::with_capacity(frame_views.len());
254+
for view in &frame_views {
255+
let (_frame_ptr, record_frame_ptr) = view.device_ptr(stream);
256+
frame_ptr_records.push(record_frame_ptr);
257+
}
258+
259+
let frame_ptrs_view = exec.frame_ptrs_handle.cuda_view::<u64>()?;
260+
let frame_sizes_view = exec.frame_sizes_handle.cuda_view::<usize>()?;
261+
let output_sizes_view = exec.output_sizes_handle.cuda_view::<usize>()?;
262+
let output_ptrs_view = exec.output_ptrs_handle.cuda_view::<u64>()?;
263+
264+
let (frame_ptrs_ptr, record_frame_ptrs) = frame_ptrs_view.device_ptr(stream);
265+
let (frame_sizes_ptr, record_frame_sizes) = frame_sizes_view.device_ptr(stream);
266+
let (output_sizes_ptr, record_output_sizes) = output_sizes_view.device_ptr(stream);
267+
let (output_ptrs_ptr, record_output_ptrs) = output_ptrs_view.device_ptr(stream);
268+
269+
// Track writes to the output allocation at the actual enqueue point.
270+
// This guard intentionally outlives the pointer-metadata construction above.
271+
let (_device_output_ptr, record_device_output) = exec.device_output.device_ptr_mut(stream);
272+
let (device_actual_sizes_ptr, record_actual_sizes) =
273+
exec.device_actual_sizes.device_ptr_mut(stream);
274+
let (nvcomp_temp_buffer_ptr, record_temp) = exec.nvcomp_temp_buffer.device_ptr_mut(stream);
275+
let (device_statuses_ptr, record_statuses) = exec.device_statuses.device_ptr_mut(stream);
255276

256277
ctx.launch_external(n_rows, || {
257278
// SAFETY: zstd_kernel_prepare makes sure to return valid kernel params.
258279
unsafe {
259280
nvcomp_zstd::decompress_async(
260-
exec.frame_ptrs_ptr as _,
261-
exec.frame_sizes_ptr as _,
262-
exec.output_sizes_ptr as _,
263-
exec.device_actual_sizes.device_ptr_mut(stream).0 as _,
281+
frame_ptrs_ptr as _,
282+
frame_sizes_ptr as _,
283+
output_sizes_ptr as _,
284+
device_actual_sizes_ptr as _,
264285
exec.num_frames,
265-
exec.nvcomp_temp_buffer.device_ptr_mut(stream).0 as _,
286+
nvcomp_temp_buffer_ptr as _,
266287
exec.nvcomp_temp_buffer_size,
267-
exec.output_ptrs_ptr as _,
268-
exec.device_statuses.device_ptr_mut(stream).0 as _,
288+
output_ptrs_ptr as _,
289+
device_statuses_ptr as _,
269290
stream.cu_stream().cast(),
270291
)
271292
.map_err(|e| vortex_err!("nvcomp decompress_async failed: {}", e))
272293
}
273294
})?;
295+
drop(frame_ptr_records);
296+
drop(frame_views);
297+
drop((
298+
record_frame_ptrs,
299+
record_frame_sizes,
300+
record_output_sizes,
301+
record_output_ptrs,
302+
record_device_output,
303+
record_actual_sizes,
304+
record_temp,
305+
record_statuses,
306+
));
274307

275308
// Unconditionally copy back to the host as Zstd arrays are fully
276309
// self-contained. They neither have any parent or child encodings.

0 commit comments

Comments
 (0)