22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
44use std:: fmt:: Debug ;
5- use std:: sync:: Arc ;
65
76use async_trait:: async_trait;
87use cudarc:: driver:: DeviceRepr ;
8+ use cudarc:: driver:: LaunchConfig ;
99use cudarc:: driver:: PushKernelArg ;
1010use tracing:: instrument;
1111use vortex:: array:: ArrayRef ;
1212use vortex:: array:: Canonical ;
1313use vortex:: array:: arrays:: PrimitiveArray ;
1414use vortex:: array:: arrays:: primitive:: PrimitiveDataParts ;
1515use vortex:: array:: buffer:: BufferHandle ;
16- use vortex:: array:: match_each_unsigned_integer_ptype ;
16+ use vortex:: array:: buffer :: DeviceBufferExt ;
1717use vortex:: dtype:: NativePType ;
1818use vortex:: encodings:: alp:: ALP ;
1919use vortex:: encodings:: alp:: ALPArray ;
@@ -30,7 +30,8 @@ use crate::CudaDeviceBuffer;
3030use crate :: executor:: CudaArrayExt ;
3131use crate :: executor:: CudaExecute ;
3232use crate :: executor:: CudaExecutionCtx ;
33- use crate :: kernel:: patches:: execute_patches;
33+ use crate :: kernel:: patches:: build_gpu_patches;
34+ use crate :: kernel:: patches:: types:: load_patches;
3435
3536/// CUDA decoder for ALP (Adaptive Lossless floating-Point) decompression.
3637#[ derive( Debug ) ]
@@ -54,6 +55,13 @@ impl CudaExecute for ALPExecutor {
5455 }
5556}
5657
/// Returns the number of threads per block used to launch the ALP kernel for
/// output element type `A`.
///
/// Matches the strategy used by `bit_unpack`: 16 threads (64 elements each)
/// for 64-bit output widths, otherwise 32.
///
/// The choice depends only on `size_of::<A>()`: any 8-byte `A` yields 16,
/// every other size yields 32.
const fn alp_thread_count<A>() -> u32 {
    // Fully qualified so this does not depend on `size_of` being in the
    // prelude (Rust 1.80+) or on a file-level import.
    if std::mem::size_of::<A>() == 8 {
        16
    } else {
        32
    }
}
63+
64+ #[ instrument( skip_all) ]
5765async fn decode_alp < A > ( array : ALPArray , ctx : & mut CudaExecutionCtx ) -> VortexResult < Canonical >
5866where
5967 A : ALPFloat + NativePType + DeviceRepr + Send + Sync + ' static ,
@@ -67,50 +75,69 @@ where
6775 let f: A = A :: F10 [ exponents. f as usize ] ;
6876 let e: A = A :: IF10 [ exponents. e as usize ] ;
6977
70- // Execute child and copy to device
78+ // Execute child and copy to device.
7179 let canonical = array. encoded ( ) . clone ( ) . execute_cuda ( ctx) . await ?;
7280 let primitive = canonical. into_primitive ( ) ;
7381 let PrimitiveDataParts {
7482 buffer, validity, ..
7583 } = primitive. into_data_parts ( ) ;
7684
7785 let device_input = ctx. ensure_on_device ( buffer) . await ?;
78-
79- // Get CUDA view of input
8086 let input_view = device_input. cuda_view :: < A :: ALPInt > ( ) ?;
8187
82- // Allocate output buffer
83- let output_slice = ctx. device_alloc :: < A > ( array_len) ?;
88+ // Allocate output rounded up to a full chunk: the fused kernel writes a
89+ // whole 1024-element chunk per block, and we slice off any padding below.
90+ let output_slice = ctx. device_alloc :: < A > ( array_len. next_multiple_of ( 1024 ) ) ?;
8491 let output_buf = CudaDeviceBuffer :: new ( output_slice) ;
8592 let output_view = output_buf. as_view :: < A > ( ) ;
8693
87- let array_len_u64 = array_len as u64 ;
88-
89- // Load kernel function
90- let kernel_ptypes = [ A :: ALPInt :: PTYPE , A :: PTYPE ] ;
91- let cuda_function = ctx. load_function ( "alp" , & kernel_ptypes) ?;
94+ // Patch validity does not need to be scattered: the ALP encoder strips null
95+ // positions from the exception list, so patches only exist at valid
96+ // positions. load_patches additionally rejects patches without
97+ // chunk_offsets (required by the fused kernel's PatchesCursor).
98+ let device_patches = if let Some ( patches) = array. patches ( ) {
99+ Some ( load_patches ( & patches, ctx) . await ?)
100+ } else {
101+ None
102+ } ;
103+ let patches_arg = build_gpu_patches ( device_patches. as_ref ( ) ) ?;
104+
105+ // Load the kernel: alp_{enc}_{float}_{threads}t
106+ let thread_count = alp_thread_count :: < A > ( ) ;
107+ let thread_suffix = format ! ( "{thread_count}t" ) ;
108+ let enc_suffix = A :: ALPInt :: PTYPE . to_string ( ) ;
109+ let float_suffix = A :: PTYPE . to_string ( ) ;
110+ let cuda_function = ctx. load_function_with_suffixes (
111+ "alp" ,
112+ & [
113+ enc_suffix. as_str ( ) ,
114+ float_suffix. as_str ( ) ,
115+ thread_suffix. as_str ( ) ,
116+ ] ,
117+ ) ?;
118+
119+ let num_blocks = u32:: try_from ( array_len. div_ceil ( 1024 ) ) ?;
120+ let config = LaunchConfig {
121+ grid_dim : ( num_blocks, 1 , 1 ) ,
122+ block_dim : ( thread_count, 1 , 1 ) ,
123+ shared_mem_bytes : 0 ,
124+ } ;
92125
93- ctx. launch_kernel ( & cuda_function, array_len, |args| {
126+ let array_len_u64 = array_len as u64 ;
127+ ctx. launch_kernel_config ( & cuda_function, config, array_len, |args| {
94128 args. arg ( & input_view)
95129 . arg ( & output_view)
96130 . arg ( & f)
97131 . arg ( & e)
98- . arg ( & array_len_u64) ;
132+ . arg ( & array_len_u64)
133+ . arg ( & patches_arg) ;
99134 } ) ?;
100135
101- // Check if there are any patches to decode here. Patch validity does not
102- // need to be scattered: the ALP encoder strips null positions from the
103- // exception list, so patches only exist at valid positions. execute_patches
104- // additionally guards against nullable patch values at runtime.
105- let output_buf = if let Some ( patches) = array. patches ( ) {
106- match_each_unsigned_integer_ptype ! ( patches. indices_ptype( ) ?, |I | {
107- execute_patches:: <A , I >( patches. clone( ) , output_buf, ctx) . await ?
108- } )
109- } else {
110- output_buf
111- } ;
136+ // Synchronize so the device patches buffers remain alive for the kernel.
137+ ctx. synchronize_stream ( ) ?;
138+ drop ( device_patches) ;
112139
113- let output_handle = BufferHandle :: new_device ( Arc :: new ( output_buf ) ) ;
140+ let output_handle = BufferHandle :: new_device ( output_buf . slice_typed :: < A > ( 0 ..array_len ) ) ;
114141 Ok ( Canonical :: Primitive ( PrimitiveArray :: from_buffer_handle (
115142 output_handle,
116143 A :: PTYPE ,
@@ -257,4 +284,88 @@ mod tests {
257284 assert_arrays_eq ! ( cpu_result, gpu_result) ;
258285 Ok ( ( ) )
259286 }
287+
288+ /// Multi-chunk ALP (> 1024 elements) with patches scattered across chunks.
289+ /// Exercises the fused kernel's per-block patches cursor math when more
290+ /// than one block is launched.
291+ #[ crate :: test]
292+ async fn test_cuda_alp_multi_chunk_with_patches ( ) -> VortexResult < ( ) > {
293+ let mut cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) )
294+ . vortex_expect ( "failed to create execution context" ) ;
295+
296+ // 3072 values (3 chunks). Inject exceptions (values ALP can't encode
297+ // losslessly) at a handful of positions spread across chunks.
298+ let mut values: Vec < f32 > = Vec :: with_capacity ( 3072 ) ;
299+ for i in 0u32 ..3072 {
300+ if matches ! ( i, 0 | 100 | 1023 | 1024 | 2000 | 3071 ) {
301+ values. push ( 1.0_f32 / 7.0 + i as f32 ) ;
302+ } else {
303+ values. push ( i as f32 ) ;
304+ }
305+ }
306+ let prim = PrimitiveArray :: new ( Buffer :: from ( values) , Validity :: NonNullable ) ;
307+ let alp_array = alp_encode (
308+ prim. as_view ( ) ,
309+ None ,
310+ & mut LEGACY_SESSION . create_execution_ctx ( ) ,
311+ ) ?;
312+ assert ! (
313+ alp_array. patches( ) . is_some( ) ,
314+ "expected patches from ALP exceptions"
315+ ) ;
316+
317+ let cpu_result = crate :: canonicalize_cpu ( alp_array. clone ( ) ) ?. into_array ( ) ;
318+
319+ let gpu_result = alp_array
320+ . into_array ( )
321+ . execute_cuda ( & mut cuda_ctx)
322+ . await ?
323+ . into_host ( )
324+ . await ?
325+ . into_array ( ) ;
326+
327+ assert_arrays_eq ! ( cpu_result, gpu_result) ;
328+ Ok ( ( ) )
329+ }
330+
331+ /// Tail-chunk bounds check: an array whose length is not a multiple of
332+ /// 1024 forces the kernel's tail-block path to bounds-check its decode
333+ /// loop. Includes a patch in the tail.
334+ #[ crate :: test]
335+ async fn test_cuda_alp_partial_tail_chunk ( ) -> VortexResult < ( ) > {
336+ let mut cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) )
337+ . vortex_expect ( "failed to create execution context" ) ;
338+
339+ let mut values: Vec < f64 > = Vec :: with_capacity ( 1500 ) ;
340+ for i in 0u32 ..1500 {
341+ if i == 1400 {
342+ values. push ( 1.0_f64 / 3.0 ) ;
343+ } else {
344+ values. push ( i as f64 ) ;
345+ }
346+ }
347+ let prim = PrimitiveArray :: new ( Buffer :: from ( values) , Validity :: NonNullable ) ;
348+ let alp_array = alp_encode (
349+ prim. as_view ( ) ,
350+ None ,
351+ & mut LEGACY_SESSION . create_execution_ctx ( ) ,
352+ ) ?;
353+ assert ! (
354+ alp_array. patches( ) . is_some( ) ,
355+ "expected patches from ALP exceptions"
356+ ) ;
357+
358+ let cpu_result = crate :: canonicalize_cpu ( alp_array. clone ( ) ) ?. into_array ( ) ;
359+
360+ let gpu_result = alp_array
361+ . into_array ( )
362+ . execute_cuda ( & mut cuda_ctx)
363+ . await ?
364+ . into_host ( )
365+ . await ?
366+ . into_array ( ) ;
367+
368+ assert_arrays_eq ! ( cpu_result, gpu_result) ;
369+ Ok ( ( ) )
370+ }
260371}
0 commit comments