@@ -137,15 +137,19 @@ pub(crate) fn ptype_to_chunk_offset_type(ptype: PType) -> VortexResult<ChunkOffs
137137///
138138/// Canonicalization is done on the CPU via [`LEGACY_SESSION`], then the
139139/// resulting host buffers are uploaded to the device.
140+ /// Canonicalize patches on the CPU, then upload the data buffers and a
141+ /// [`GPUPatches`] struct to the device in one step. Returns the device pointer
142+ /// to the `GPUPatches` struct plus the buffer handles (data buffers and the
143+ /// struct itself), which must be kept alive for the duration of the kernel launch.
140144///
141145/// # Errors
142146///
143147/// If the patches do not have `chunk_offsets`.
144148#[ allow( deprecated) ]
145- pub ( crate ) fn load_patches_sync (
149+ pub ( crate ) fn upload_patches (
146150 patches : & Patches ,
147151 ctx : & CudaExecutionCtx ,
148- ) -> VortexResult < DevicePatches > {
152+ ) -> VortexResult < ( u64 , Vec < BufferHandle > ) > {
149153 let offset = patches. offset ( ) ;
150154 let offset_within_chunk = patches. offset_within_chunk ( ) . unwrap_or_default ( ) ;
151155
@@ -159,6 +163,7 @@ pub(crate) fn load_patches_sync(
159163 // Canonicalize chunk_offsets on the CPU
160164 let co_canonical = co. clone ( ) . execute :: < PrimitiveArray > ( & mut exec_ctx) ?;
161165 let chunk_offset_ptype = co_canonical. ptype ( ) ;
166+ let n_chunks = co_canonical. len ( ) ;
162167 let chunk_offsets = co_canonical. buffer_handle ( ) . clone ( ) ;
163168
164169 // Canonicalize indices and convert to u32
@@ -187,52 +192,29 @@ pub(crate) fn load_patches_sync(
187192 . execute :: < PrimitiveArray > ( & mut exec_ctx) ?;
188193 let values = values_prim. buffer_handle ( ) . clone ( ) ;
189194
190- // Upload all buffers to the device
195+ // Upload data buffers to the device
191196 let chunk_offsets = ctx. ensure_on_device_sync ( chunk_offsets) ?;
192197 let indices = ctx. ensure_on_device_sync ( indices) ?;
193198 let values = ctx. ensure_on_device_sync ( values) ?;
194199
195- let num_patches = patches. num_patches ( ) ;
196- // n_chunks must match the chunk_offsets array length, not array_len / 1024.
197- // When patches are sliced, chunk_offsets is sliced to only include chunks
198- // overlapping the slice range — matching the CPU's patch_chunk which uses
199- // chunk_offsets_slice.len().
200- let n_chunks = co_canonical. len ( ) ;
201-
202- Ok ( DevicePatches {
203- chunk_offsets,
204- chunk_offset_ptype,
205- indices,
206- values,
207- offset,
208- offset_within_chunk,
209- num_patches,
210- n_chunks,
211- } )
212- }
213-
214- /// Upload a [`GPUPatches`] struct to the device, returning the buffer handle
215- /// (which must be kept alive) and the device pointer to the struct.
216- ///
217- /// The caller must also keep the [`DevicePatches`] alive for the duration of
218- /// the kernel launch, since the `GPUPatches` struct contains device pointers
219- /// into the individual buffers owned by `DevicePatches`.
220- pub ( crate ) fn upload_gpu_patches (
221- device_patches : & DevicePatches ,
222- ctx : & CudaExecutionCtx ,
223- ) -> VortexResult < ( BufferHandle , u64 ) > {
200+ // Build the GPUPatches C struct from device pointers.
224201 // Zero-initialize to avoid uninitialized padding bytes.
225202 let mut gpu_patches: GPUPatches = unsafe { std:: mem:: zeroed ( ) } ;
226- gpu_patches. chunk_offsets = device_patches. chunk_offsets . cuda_device_ptr ( ) ? as _ ;
227- gpu_patches. chunk_offset_type = ptype_to_chunk_offset_type ( device_patches. chunk_offset_ptype ) ?;
228- gpu_patches. indices = device_patches. indices . cuda_device_ptr ( ) ? as _ ;
229- gpu_patches. values = device_patches. values . cuda_device_ptr ( ) ? as _ ;
203+ gpu_patches. chunk_offsets = chunk_offsets. cuda_device_ptr ( ) ? as _ ;
204+ gpu_patches. chunk_offset_type = ptype_to_chunk_offset_type ( chunk_offset_ptype) ?;
205+ gpu_patches. indices = indices. cuda_device_ptr ( ) ? as _ ;
206+ gpu_patches. values = values. cuda_device_ptr ( ) ? as _ ;
207+ let num_patches = patches. num_patches ( ) ;
230208 #[ expect( clippy:: cast_possible_truncation) ]
231209 {
232- gpu_patches. offset = device_patches. offset as u32 ;
233- gpu_patches. offset_within_chunk = device_patches. offset_within_chunk as u32 ;
234- gpu_patches. num_patches = device_patches. num_patches as u32 ;
235- gpu_patches. n_chunks = device_patches. n_chunks as u32 ;
210+ gpu_patches. offset = offset as u32 ;
211+ gpu_patches. offset_within_chunk = offset_within_chunk as u32 ;
212+ gpu_patches. num_patches = num_patches as u32 ;
213+ // n_chunks must match the chunk_offsets array length, not array_len / 1024.
214+ // When patches are sliced, chunk_offsets is sliced to only include chunks
215+ // overlapping the slice range — matching the CPU's patch_chunk which uses
216+ // chunk_offsets_slice.len().
217+ gpu_patches. n_chunks = n_chunks as u32 ;
236218 }
237219
238220 // Serialize the repr(C) struct to bytes and upload to the device.
@@ -242,14 +224,13 @@ pub(crate) fn upload_gpu_patches(
242224 size_of :: < GPUPatches > ( ) ,
243225 )
244226 } ;
245-
246227 let mut buf =
247228 ByteBufferMut :: with_capacity_aligned ( size_of :: < GPUPatches > ( ) , Alignment :: of :: < u64 > ( ) ) ;
248229 buf. extend_from_slice ( bytes) ;
249- let host_buf = BufferHandle :: new_host ( buf. freeze ( ) ) ;
250- let device_buf = ctx . ensure_on_device_sync ( host_buf ) ?;
251- let ptr = device_buf . cuda_device_ptr ( ) ? ;
252- Ok ( ( device_buf , ptr ) )
230+ let gpu_buf = ctx . ensure_on_device_sync ( BufferHandle :: new_host ( buf. freeze ( ) ) ) ? ;
231+ let ptr = gpu_buf . cuda_device_ptr ( ) ?;
232+
233+ Ok ( ( ptr , vec ! [ chunk_offsets , indices , values , gpu_buf ] ) )
253234}
254235
255236#[ cfg( test) ]
0 commit comments