33
44//! Host interface for dynamic CUDA kernel dispatch.
55//!
6- //! An [`UnmaterializedPlan`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7- //! and flattens it into a linear sequence of stages. Call
8- //! [`materialize`](UnmaterializedPlan::materialize) to copy source buffers to
9- //! the device, producing a [`MaterializedPlan`] ready for kernel launch.
6+ //! [`UnmaterializedPlan::new`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7+ //! in a single pass and returns one of three variants:
108//!
11- //! For partially-fusable trees, [`find_unfusable_nodes`] identifies nodes
12- //! that need separate kernels, and [`UnmaterializedPlan::new_with_subtree_inputs`] builds a plan
13- //! that incorporates their pre-executed arrays.
14- //!
15- //! Shared memory is dynamically sized at launch time via
16- //! [`UnmaterializedPlan::shared_mem_bytes`].
9+ //! - [`Fused`](UnmaterializedPlan::Fused) — call [`FusedPlan::materialize`].
10+ //! - [`PartiallyFused`](UnmaterializedPlan::PartiallyFused) — execute pending
11+ //! subtrees, then call [`FusedPlan::materialize_with_subtrees`].
12+ //! - [`Unfusable`](UnmaterializedPlan::Unfusable) — fall back to single-kernel dispatch.
1713
1814#![ allow( non_upper_case_globals) ]
1915#![ allow( non_camel_case_types) ]
@@ -47,9 +43,9 @@ use crate::CudaDeviceBuffer;
4743use crate :: executor:: CudaExecutionCtx ;
4844
4945pub ( crate ) mod plan_builder;
46+ pub use plan_builder:: DispatchPlan ;
47+ pub use plan_builder:: FusedPlan ;
5048pub use plan_builder:: MaterializedPlan ;
51- pub use plan_builder:: UnmaterializedPlan ;
52- pub use plan_builder:: find_unfusable_nodes;
5349
5450include ! ( concat!( env!( "OUT_DIR" ) , "/dynamic_dispatch.rs" ) ) ;
5551
@@ -422,6 +418,7 @@ impl MaterializedPlan {
422418
423419#[ cfg( test) ]
424420mod tests {
421+ use super :: * ;
425422 use std:: sync:: Arc ;
426423
427424 use cudarc:: driver:: DevicePtr ;
@@ -446,20 +443,21 @@ mod tests {
446443 use vortex:: encodings:: zigzag:: ZigZagArray ;
447444 use vortex:: error:: VortexExpect ;
448445 use vortex:: error:: VortexResult ;
446+
449447 use vortex:: session:: VortexSession ;
450448
451449 use super :: CudaDispatchPlan ;
450+ use super :: DispatchPlan ;
452451 use super :: SMEM_TILE_SIZE ;
453452 use super :: ScalarOp ;
454453 use super :: SourceOp ;
455454 use super :: Stage ;
456- use super :: UnmaterializedPlan ;
457455 use crate :: CudaBufferExt ;
458456 use crate :: CudaDeviceBuffer ;
459457 use crate :: CudaExecutionCtx ;
460458 use crate :: session:: CudaSession ;
461459
462- fn make_bitpacked_array_u32 ( bit_width : u8 , len : usize ) -> BitPackedArray {
460+ fn bitpacked_array_u32 ( bit_width : u8 , len : usize ) -> BitPackedArray {
463461 let max_val = ( 1u64 << bit_width) . saturating_sub ( 1 ) ;
464462 let values: Vec < u32 > = ( 0 ..len)
465463 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 )
@@ -469,6 +467,16 @@ mod tests {
469467 . vortex_expect ( "failed to create BitPacked array" )
470468 }
471469
470+ fn dispatch_plan (
471+ array : & vortex:: array:: ArrayRef ,
472+ ctx : & CudaExecutionCtx ,
473+ ) -> VortexResult < MaterializedPlan > {
474+ match DispatchPlan :: new ( array) ? {
475+ DispatchPlan :: Fused ( plan) => plan. materialize ( ctx) ,
476+ _ => vortex_bail ! ( "array encoding not fusable" ) ,
477+ }
478+ }
479+
472480 #[ crate :: test]
473481 fn test_max_scalar_ops ( ) -> VortexResult < ( ) > {
474482 let bit_width: u8 = 6 ;
@@ -481,7 +489,7 @@ mod tests {
481489 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 + total_reference)
482490 . collect ( ) ;
483491
484- let bitpacked = make_bitpacked_array_u32 ( bit_width, len) ;
492+ let bitpacked = bitpacked_array_u32 ( bit_width, len) ;
485493 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
486494 let packed = bitpacked. packed ( ) . clone ( ) ;
487495 let device_input = futures:: executor:: block_on ( cuda_ctx. ensure_on_device ( packed) ) ?;
@@ -669,9 +677,9 @@ mod tests {
669677 . map ( |i| ( ( i as u64 ) % ( max_val + 1 ) ) as u32 )
670678 . collect ( ) ;
671679
672- let bp = make_bitpacked_array_u32 ( bit_width, len) ;
680+ let bp = bitpacked_array_u32 ( bit_width, len) ;
673681 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
674- let plan = UnmaterializedPlan :: new ( & bp. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
682+ let plan = dispatch_plan ( & bp. into_array ( ) , & cuda_ctx) ?;
675683
676684 let actual =
677685 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -692,11 +700,11 @@ mod tests {
692700 . collect ( ) ;
693701 let expected: Vec < u32 > = raw. iter ( ) . map ( |& v| v + reference) . collect ( ) ;
694702
695- let bp = make_bitpacked_array_u32 ( bit_width, len) ;
703+ let bp = bitpacked_array_u32 ( bit_width, len) ;
696704 let for_arr = FoRArray :: try_new ( bp. into_array ( ) , Scalar :: from ( reference) ) ?;
697705
698706 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
699- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
707+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
700708
701709 let actual =
702710 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -722,7 +730,7 @@ mod tests {
722730 let re = RunEndArray :: new ( ends_arr, values_arr) ;
723731
724732 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
725- let plan = UnmaterializedPlan :: new ( & re. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
733+ let plan = dispatch_plan ( & re. into_array ( ) , & cuda_ctx) ?;
726734
727735 let actual =
728736 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -755,7 +763,7 @@ mod tests {
755763 let dict = DictArray :: try_new ( codes_bp. into_array ( ) , dict_for. into_array ( ) ) ?;
756764
757765 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
758- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
766+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
759767
760768 let actual =
761769 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -787,7 +795,7 @@ mod tests {
787795 ) ;
788796
789797 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
790- let plan = UnmaterializedPlan :: new ( & tree. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
798+ let plan = dispatch_plan ( & tree. into_array ( ) , & cuda_ctx) ?;
791799
792800 let actual =
793801 run_dispatch_plan_f32 ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -816,7 +824,7 @@ mod tests {
816824 let zz = ZigZagArray :: try_new ( bp. into_array ( ) ) ?;
817825
818826 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
819- let plan = UnmaterializedPlan :: new ( & zz. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
827+ let plan = dispatch_plan ( & zz. into_array ( ) , & cuda_ctx) ?;
820828
821829 let actual =
822830 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -845,7 +853,7 @@ mod tests {
845853 let for_arr = FoRArray :: try_new ( re. into_array ( ) , Scalar :: from ( reference) ) ?;
846854
847855 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
848- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
856+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
849857
850858 let actual =
851859 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -874,7 +882,7 @@ mod tests {
874882 let for_arr = FoRArray :: try_new ( dict. into_array ( ) , Scalar :: from ( reference) ) ?;
875883
876884 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
877- let plan = UnmaterializedPlan :: new ( & for_arr. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
885+ let plan = dispatch_plan ( & for_arr. into_array ( ) , & cuda_ctx) ?;
878886
879887 let actual =
880888 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -902,7 +910,7 @@ mod tests {
902910 let dict = DictArray :: try_new ( codes_for. into_array ( ) , values_prim. into_array ( ) ) ?;
903911
904912 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
905- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
913+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
906914
907915 let actual =
908916 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -927,7 +935,7 @@ mod tests {
927935 let dict = DictArray :: try_new ( codes_bp. into_array ( ) , values_prim. into_array ( ) ) ?;
928936
929937 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
930- let plan = UnmaterializedPlan :: new ( & dict. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
938+ let plan = dispatch_plan ( & dict. into_array ( ) , & cuda_ctx) ?;
931939
932940 let actual =
933941 run_dynamic_dispatch_plan ( & cuda_ctx, len, & plan. dispatch_plan , plan. shared_mem_bytes ) ?;
@@ -946,8 +954,11 @@ mod tests {
946954 let values_prim = PrimitiveArray :: new ( Buffer :: from ( dict_values) , NonNullable ) ;
947955 let dict = DictArray :: try_new ( codes_prim. into_array ( ) , values_prim. into_array ( ) ) ?;
948956
949- // UnmaterializedPlan::new should fail because u8 codes != u32 values in byte width.
950- assert ! ( UnmaterializedPlan :: new( & dict. into_array( ) ) . is_err( ) ) ;
957+ // UnmaterializedPlan::new should return Unfusable because u8 codes != u32 values in byte width.
958+ assert ! ( matches!(
959+ DispatchPlan :: new( & dict. into_array( ) ) ?,
960+ DispatchPlan :: Unfused
961+ ) ) ;
951962
952963 Ok ( ( ) )
953964 }
@@ -961,8 +972,11 @@ mod tests {
961972 let values_arr = PrimitiveArray :: new ( Buffer :: from ( values) , NonNullable ) . into_array ( ) ;
962973 let re = RunEndArray :: new ( ends_arr, values_arr) ;
963974
964- // UnmaterializedPlan::new should fail because u64 ends != i32 values in byte width.
965- assert ! ( UnmaterializedPlan :: new( & re. into_array( ) ) . is_err( ) ) ;
975+ // UnmaterializedPlan::new should return Unfusable because u64 ends != i32 values in byte width.
976+ assert ! ( matches!(
977+ DispatchPlan :: new( & re. into_array( ) ) ?,
978+ DispatchPlan :: Unfused
979+ ) ) ;
966980
967981 Ok ( ( ) )
968982 }
@@ -997,7 +1011,7 @@ mod tests {
9971011 let expected: Vec < u32 > = data[ slice_start..slice_end] . to_vec ( ) ;
9981012
9991013 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1000- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1014+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
10011015
10021016 let actual = run_dynamic_dispatch_plan (
10031017 & cuda_ctx,
@@ -1048,7 +1062,7 @@ mod tests {
10481062 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
10491063
10501064 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1051- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1065+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
10521066
10531067 let actual = run_dynamic_dispatch_plan (
10541068 & cuda_ctx,
@@ -1098,7 +1112,7 @@ mod tests {
10981112 . collect ( ) ;
10991113
11001114 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1101- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1115+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11021116
11031117 let actual = run_dynamic_dispatch_plan (
11041118 & cuda_ctx,
@@ -1143,7 +1157,7 @@ mod tests {
11431157 let expected: Vec < u32 > = data[ slice_start..slice_end] . to_vec ( ) ;
11441158
11451159 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1146- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1160+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11471161
11481162 let actual = run_dynamic_dispatch_plan (
11491163 & cuda_ctx,
@@ -1192,7 +1206,7 @@ mod tests {
11921206 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
11931207
11941208 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1195- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1209+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
11961210
11971211 let actual = run_dynamic_dispatch_plan (
11981212 & cuda_ctx,
@@ -1244,7 +1258,7 @@ mod tests {
12441258 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
12451259
12461260 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1247- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1261+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
12481262
12491263 let actual = run_dynamic_dispatch_plan (
12501264 & cuda_ctx,
@@ -1301,7 +1315,7 @@ mod tests {
13011315 let expected: Vec < u32 > = all_decoded[ slice_start..slice_end] . to_vec ( ) ;
13021316
13031317 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1304- let plan = UnmaterializedPlan :: new ( & sliced) ? . materialize ( & cuda_ctx) ?;
1318+ let plan = dispatch_plan ( & sliced, & cuda_ctx) ?;
13051319
13061320 let actual = run_dynamic_dispatch_plan (
13071321 & cuda_ctx,
@@ -1333,7 +1347,7 @@ mod tests {
13331347 let seq = SequenceArray :: try_new_typed ( base, multiplier, Nullability :: NonNullable , len) ?;
13341348
13351349 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1336- let plan = UnmaterializedPlan :: new ( & seq. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
1350+ let plan = dispatch_plan ( & seq. into_array ( ) , & cuda_ctx) ?;
13371351
13381352 let actual = run_dynamic_dispatch_plan (
13391353 & cuda_ctx,
@@ -1366,7 +1380,7 @@ mod tests {
13661380 let seq = SequenceArray :: try_new_typed ( base, multiplier, Nullability :: NonNullable , len) ?;
13671381
13681382 let cuda_ctx = CudaSession :: create_execution_ctx ( & VortexSession :: empty ( ) ) ?;
1369- let plan = UnmaterializedPlan :: new ( & seq. into_array ( ) ) ? . materialize ( & cuda_ctx) ?;
1383+ let plan = dispatch_plan ( & seq. into_array ( ) , & cuda_ctx) ?;
13701384
13711385 let actual_u32 = run_dynamic_dispatch_plan (
13721386 & cuda_ctx,
0 commit comments