Skip to content

Commit ee29169

Browse files
authored
cuda ctx bind to thread before unsafe (#6681)
set the cuda context thread locals before unsafe cudarc calls. These lower level calls do not ensure the thread local cuda context some calls expect. cudarc's safe methods do this internally. This PR also wraps the zstd buffers to use launch_external, matching other kernels --------- Signed-off-by: Onur Satici <onur@spiraldb.com>
1 parent a13a0ed commit ee29169

File tree

3 files changed

+39
-16
lines changed

3 files changed

+39
-16
lines changed

vortex-cuda/src/device_buffer.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,11 @@ impl DeviceBuffer for CudaDeviceBuffer {
244244
ByteBufferMut::with_capacity_aligned(self.len, alignment);
245245
let len = self.len;
246246

247+
stream
248+
.context()
249+
.bind_to_thread()
250+
.map_err(|e| vortex_err!("Failed to bind CUDA context: {}", e))?;
251+
247252
// SAFETY: We pass a valid pointer to a buffer with sufficient capacity.
248253
// `cuMemcpyDtoHAsync_v2` fully initializes the memory.
249254
unsafe {

vortex-cuda/src/kernel/encodings/zstd_buffers.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,6 @@ async fn decode_zstd_buffers(
106106
let mut device_statuses: CudaSlice<nvcompStatus_t> = ctx.device_alloc(plan.num_frames())?;
107107
let mut nvcomp_temp_buffer: CudaSlice<u8> = ctx.device_alloc(nvcomp_temp_buffer_size)?;
108108
let stream = ctx.stream();
109-
110109
let frame_ptrs_view = frame_ptrs_handle.cuda_view::<u64>()?;
111110
let frame_sizes_view = frame_sizes_handle.cuda_view::<usize>()?;
112111
let output_sizes_view = output_sizes_handle.cuda_view::<usize>()?;
@@ -123,21 +122,25 @@ async fn decode_zstd_buffers(
123122
let (device_actual_sizes_ptr, record_actual_sizes) = device_actual_sizes.device_ptr_mut(stream);
124123
let (nvcomp_temp_buffer_ptr, record_temp) = nvcomp_temp_buffer.device_ptr_mut(stream);
125124
let (device_statuses_ptr, record_statuses) = device_statuses.device_ptr_mut(stream);
126-
unsafe {
127-
nvcomp_zstd::decompress_async(
128-
frame_ptrs_ptr as _,
129-
frame_sizes_ptr as _,
130-
output_sizes_ptr as _,
131-
device_actual_sizes_ptr as _,
132-
plan.num_frames(),
133-
nvcomp_temp_buffer_ptr as _,
134-
nvcomp_temp_buffer_size,
135-
output_ptrs_ptr as _,
136-
device_statuses_ptr as _,
137-
stream.cu_stream().cast(),
138-
)
139-
.map_err(|e| vortex_err!("nvcomp decompress_async failed: {}", e))?;
140-
}
125+
126+
ctx.launch_external(plan.output_size_total(), || {
127+
// SAFETY: Pointer and size parameters are derived from validated decode plan inputs.
128+
unsafe {
129+
nvcomp_zstd::decompress_async(
130+
frame_ptrs_ptr as _,
131+
frame_sizes_ptr as _,
132+
output_sizes_ptr as _,
133+
device_actual_sizes_ptr as _,
134+
plan.num_frames(),
135+
nvcomp_temp_buffer_ptr as _,
136+
nvcomp_temp_buffer_size,
137+
output_ptrs_ptr as _,
138+
device_statuses_ptr as _,
139+
stream.cu_stream().cast(),
140+
)
141+
.map_err(|e| vortex_err!("nvcomp decompress_async failed: {}", e))
142+
}
143+
})?;
141144
drop(frame_ptr_records);
142145
drop(frame_views);
143146
drop((

vortex-cuda/src/stream.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ impl VortexCudaStream {
6868
let mut cuda_slice: CudaSlice<T> = self.device_alloc(host_slice.len())?;
6969
let (device_ptr, record_write) = cuda_slice.device_ptr_mut(&self.0);
7070

71+
// calling the unsafe memcpy_htod_async expects the cuda context thread local
72+
// to be set. To avoid invalid context error from the cuda call we set it
73+
// explicitly here.
74+
// TODO(os): wrap calling unsafe cudarc functions with something that binds always
75+
// so we don't forget
76+
self.0
77+
.context()
78+
.bind_to_thread()
79+
.map_err(|e| vortex_err!("Failed to bind CUDA context: {}", e))?;
80+
7181
unsafe {
7282
memcpy_htod_async(device_ptr, host_slice, self.0.cu_stream())
7383
.map_err(|e| vortex_err!("Failed to schedule async copy to device: {}", e))?;
@@ -127,6 +137,11 @@ fn register_stream_callback(stream: &CudaStream) -> VortexResult<kanal::AsyncRec
127137

128138
let tx_ptr = Box::into_raw(Box::new(tx));
129139

140+
stream
141+
.context()
142+
.bind_to_thread()
143+
.map_err(|e| vortex_err!("Failed to bind CUDA context: {}", e))?;
144+
130145
/// Called from CUDA driver thread when all preceding work on the stream completes.
131146
unsafe extern "C" fn callback(user_data: *mut std::ffi::c_void) {
132147
// SAFETY: The memory of `tx` is manually managed has not been freed

0 commit comments

Comments
 (0)