Skip to content

Commit 9f7c099

Browse files
committed
refactor(cuda): single-pass DispatchPlan builder
Refactor the dynamic dispatch plan builder to walk the encoding tree exactly once, discovering unfusable subtrees and computing shared memory requirements in the same pass. The result is a 3-variant enum (`Fused`, `PartiallyFused`, `Unfusable`) that replaces the previous `Result<Option<>>` API and eliminates the separate `find_unfusable_nodes` traversal. Shared memory is now validated upfront in `DispatchPlan::new` — before any subtree kernels are executed — so we never pay GPU cost for a plan that will not fit. The plan stages are split into `smem_stages` (fully decoded into persistent shared memory) and `output_stage` (tiled through a scratch region), making the two-phase kernel execution model explicit in the host-side data structures. Shared memory allocation invariants are documented on `FusedPlan`. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 5369859 commit 9f7c099

4 files changed

Lines changed: 353 additions & 255 deletions

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ use vortex_cuda::CudaDeviceBuffer;
4040
use vortex_cuda::CudaExecutionCtx;
4141
use vortex_cuda::CudaSession;
4242
use vortex_cuda::dynamic_dispatch::CudaDispatchPlan;
43+
use vortex_cuda::dynamic_dispatch::DispatchPlan;
4344
use vortex_cuda::dynamic_dispatch::MaterializedPlan;
44-
use vortex_cuda::dynamic_dispatch::UnmaterializedPlan;
4545
use vortex_cuda_macros::cuda_available;
4646
use vortex_cuda_macros::cuda_not_available;
4747

@@ -123,13 +123,15 @@ struct BenchRunner {
123123

124124
impl BenchRunner {
125125
fn new(array: &vortex::array::ArrayRef, len: usize, cuda_ctx: &CudaExecutionCtx) -> Self {
126+
let plan = match DispatchPlan::new(array).vortex_expect("build_dyn_dispatch_plan") {
127+
DispatchPlan::Fused(plan) => plan,
128+
_ => panic!("encoding not fusable"),
129+
};
126130
let MaterializedPlan {
127131
dispatch_plan,
128132
device_buffers,
129133
shared_mem_bytes,
130-
} = UnmaterializedPlan::new(array)
131-
.and_then(|p| p.materialize(cuda_ctx))
132-
.vortex_expect("build_dyn_dispatch_plan");
134+
} = plan.materialize(cuda_ctx).vortex_expect("materialize plan");
133135

134136
let device_plan = Arc::new(
135137
cuda_ctx

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 54 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@
33

44
//! Host interface for dynamic CUDA kernel dispatch.
55
//!
6-
//! An [`UnmaterializedPlan`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7-
//! and flattens it into a linear sequence of stages. Call
8-
//! [`materialize`](UnmaterializedPlan::materialize) to copy source buffers to
9-
//! the device, producing a [`MaterializedPlan`] ready for kernel launch.
6+
//! [`DispatchPlan::new`] walks an encoding tree (e.g., `ALP(FoR(BitPacked))`)
7+
//! in a single pass and returns one of three variants:
108
//!
11-
//! For partially-fusable trees, [`find_unfusable_nodes`] identifies nodes
12-
//! that need separate kernels, and [`UnmaterializedPlan::new_with_subtree_inputs`] builds a plan
13-
//! that incorporates their pre-executed arrays.
14-
//!
15-
//! Shared memory is dynamically sized at launch time via
16-
//! [`UnmaterializedPlan::shared_mem_bytes`].
9+
//! - [`Fused`](DispatchPlan::Fused) — call [`FusedPlan::materialize`].
10+
//! - [`PartiallyFused`](DispatchPlan::PartiallyFused) — execute pending
11+
//! subtrees, then call [`FusedPlan::materialize_with_subtrees`].
12+
//! - [`Unfused`](DispatchPlan::Unfused) — fall back to single-kernel dispatch.
1713
1814
#![allow(non_upper_case_globals)]
1915
#![allow(non_camel_case_types)]
@@ -47,9 +43,9 @@ use crate::CudaDeviceBuffer;
4743
use crate::executor::CudaExecutionCtx;
4844

4945
pub(crate) mod plan_builder;
46+
pub use plan_builder::DispatchPlan;
47+
pub use plan_builder::FusedPlan;
5048
pub use plan_builder::MaterializedPlan;
51-
pub use plan_builder::UnmaterializedPlan;
52-
pub use plan_builder::find_unfusable_nodes;
5349

5450
include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs"));
5551

@@ -422,6 +418,7 @@ impl MaterializedPlan {
422418

423419
#[cfg(test)]
424420
mod tests {
421+
use super::*;
425422
use std::sync::Arc;
426423

427424
use cudarc::driver::DevicePtr;
@@ -446,20 +443,21 @@ mod tests {
446443
use vortex::encodings::zigzag::ZigZagArray;
447444
use vortex::error::VortexExpect;
448445
use vortex::error::VortexResult;
446+
449447
use vortex::session::VortexSession;
450448

451449
use super::CudaDispatchPlan;
450+
use super::DispatchPlan;
452451
use super::SMEM_TILE_SIZE;
453452
use super::ScalarOp;
454453
use super::SourceOp;
455454
use super::Stage;
456-
use super::UnmaterializedPlan;
457455
use crate::CudaBufferExt;
458456
use crate::CudaDeviceBuffer;
459457
use crate::CudaExecutionCtx;
460458
use crate::session::CudaSession;
461459

462-
fn make_bitpacked_array_u32(bit_width: u8, len: usize) -> BitPackedArray {
460+
fn bitpacked_array_u32(bit_width: u8, len: usize) -> BitPackedArray {
463461
let max_val = (1u64 << bit_width).saturating_sub(1);
464462
let values: Vec<u32> = (0..len)
465463
.map(|i| ((i as u64) % (max_val + 1)) as u32)
@@ -469,6 +467,16 @@ mod tests {
469467
.vortex_expect("failed to create BitPacked array")
470468
}
471469

470+
fn dispatch_plan(
471+
array: &vortex::array::ArrayRef,
472+
ctx: &CudaExecutionCtx,
473+
) -> VortexResult<MaterializedPlan> {
474+
match DispatchPlan::new(array)? {
475+
DispatchPlan::Fused(plan) => plan.materialize(ctx),
476+
_ => vortex_bail!("array encoding not fusable"),
477+
}
478+
}
479+
472480
#[crate::test]
473481
fn test_max_scalar_ops() -> VortexResult<()> {
474482
let bit_width: u8 = 6;
@@ -481,7 +489,7 @@ mod tests {
481489
.map(|i| ((i as u64) % (max_val + 1)) as u32 + total_reference)
482490
.collect();
483491

484-
let bitpacked = make_bitpacked_array_u32(bit_width, len);
492+
let bitpacked = bitpacked_array_u32(bit_width, len);
485493
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
486494
let packed = bitpacked.packed().clone();
487495
let device_input = futures::executor::block_on(cuda_ctx.ensure_on_device(packed))?;
@@ -669,9 +677,9 @@ mod tests {
669677
.map(|i| ((i as u64) % (max_val + 1)) as u32)
670678
.collect();
671679

672-
let bp = make_bitpacked_array_u32(bit_width, len);
680+
let bp = bitpacked_array_u32(bit_width, len);
673681
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
674-
let plan = UnmaterializedPlan::new(&bp.into_array())?.materialize(&cuda_ctx)?;
682+
let plan = dispatch_plan(&bp.into_array(), &cuda_ctx)?;
675683

676684
let actual =
677685
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -692,11 +700,11 @@ mod tests {
692700
.collect();
693701
let expected: Vec<u32> = raw.iter().map(|&v| v + reference).collect();
694702

695-
let bp = make_bitpacked_array_u32(bit_width, len);
703+
let bp = bitpacked_array_u32(bit_width, len);
696704
let for_arr = FoRArray::try_new(bp.into_array(), Scalar::from(reference))?;
697705

698706
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
699-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
707+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
700708

701709
let actual =
702710
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -722,7 +730,7 @@ mod tests {
722730
let re = RunEndArray::new(ends_arr, values_arr);
723731

724732
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
725-
let plan = UnmaterializedPlan::new(&re.into_array())?.materialize(&cuda_ctx)?;
733+
let plan = dispatch_plan(&re.into_array(), &cuda_ctx)?;
726734

727735
let actual =
728736
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -755,7 +763,7 @@ mod tests {
755763
let dict = DictArray::try_new(codes_bp.into_array(), dict_for.into_array())?;
756764

757765
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
758-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
766+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
759767

760768
let actual =
761769
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -787,7 +795,7 @@ mod tests {
787795
);
788796

789797
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
790-
let plan = UnmaterializedPlan::new(&tree.into_array())?.materialize(&cuda_ctx)?;
798+
let plan = dispatch_plan(&tree.into_array(), &cuda_ctx)?;
791799

792800
let actual =
793801
run_dispatch_plan_f32(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -816,7 +824,7 @@ mod tests {
816824
let zz = ZigZagArray::try_new(bp.into_array())?;
817825

818826
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
819-
let plan = UnmaterializedPlan::new(&zz.into_array())?.materialize(&cuda_ctx)?;
827+
let plan = dispatch_plan(&zz.into_array(), &cuda_ctx)?;
820828

821829
let actual =
822830
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -845,7 +853,7 @@ mod tests {
845853
let for_arr = FoRArray::try_new(re.into_array(), Scalar::from(reference))?;
846854

847855
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
848-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
856+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
849857

850858
let actual =
851859
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -874,7 +882,7 @@ mod tests {
874882
let for_arr = FoRArray::try_new(dict.into_array(), Scalar::from(reference))?;
875883

876884
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
877-
let plan = UnmaterializedPlan::new(&for_arr.into_array())?.materialize(&cuda_ctx)?;
885+
let plan = dispatch_plan(&for_arr.into_array(), &cuda_ctx)?;
878886

879887
let actual =
880888
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -902,7 +910,7 @@ mod tests {
902910
let dict = DictArray::try_new(codes_for.into_array(), values_prim.into_array())?;
903911

904912
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
905-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
913+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
906914

907915
let actual =
908916
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -927,7 +935,7 @@ mod tests {
927935
let dict = DictArray::try_new(codes_bp.into_array(), values_prim.into_array())?;
928936

929937
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
930-
let plan = UnmaterializedPlan::new(&dict.into_array())?.materialize(&cuda_ctx)?;
938+
let plan = dispatch_plan(&dict.into_array(), &cuda_ctx)?;
931939

932940
let actual =
933941
run_dynamic_dispatch_plan(&cuda_ctx, len, &plan.dispatch_plan, plan.shared_mem_bytes)?;
@@ -946,8 +954,11 @@ mod tests {
946954
let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable);
947955
let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?;
948956

949-
// UnmaterializedPlan::new should fail because u8 codes != u32 values in byte width.
950-
assert!(UnmaterializedPlan::new(&dict.into_array()).is_err());
957+
// DispatchPlan::new should return Unfused because the u8 codes and u32 values differ in byte width.
958+
assert!(matches!(
959+
DispatchPlan::new(&dict.into_array())?,
960+
DispatchPlan::Unfused
961+
));
951962

952963
Ok(())
953964
}
@@ -961,8 +972,11 @@ mod tests {
961972
let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array();
962973
let re = RunEndArray::new(ends_arr, values_arr);
963974

964-
// UnmaterializedPlan::new should fail because u64 ends != i32 values in byte width.
965-
assert!(UnmaterializedPlan::new(&re.into_array()).is_err());
975+
// DispatchPlan::new should return Unfused because the u64 ends and i32 values differ in byte width.
976+
assert!(matches!(
977+
DispatchPlan::new(&re.into_array())?,
978+
DispatchPlan::Unfused
979+
));
966980

967981
Ok(())
968982
}
@@ -997,7 +1011,7 @@ mod tests {
9971011
let expected: Vec<u32> = data[slice_start..slice_end].to_vec();
9981012

9991013
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1000-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1014+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
10011015

10021016
let actual = run_dynamic_dispatch_plan(
10031017
&cuda_ctx,
@@ -1048,7 +1062,7 @@ mod tests {
10481062
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
10491063

10501064
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1051-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1065+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
10521066

10531067
let actual = run_dynamic_dispatch_plan(
10541068
&cuda_ctx,
@@ -1098,7 +1112,7 @@ mod tests {
10981112
.collect();
10991113

11001114
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1101-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1115+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11021116

11031117
let actual = run_dynamic_dispatch_plan(
11041118
&cuda_ctx,
@@ -1143,7 +1157,7 @@ mod tests {
11431157
let expected: Vec<u32> = data[slice_start..slice_end].to_vec();
11441158

11451159
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1146-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1160+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11471161

11481162
let actual = run_dynamic_dispatch_plan(
11491163
&cuda_ctx,
@@ -1192,7 +1206,7 @@ mod tests {
11921206
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
11931207

11941208
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1195-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1209+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
11961210

11971211
let actual = run_dynamic_dispatch_plan(
11981212
&cuda_ctx,
@@ -1244,7 +1258,7 @@ mod tests {
12441258
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
12451259

12461260
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1247-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1261+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
12481262

12491263
let actual = run_dynamic_dispatch_plan(
12501264
&cuda_ctx,
@@ -1301,7 +1315,7 @@ mod tests {
13011315
let expected: Vec<u32> = all_decoded[slice_start..slice_end].to_vec();
13021316

13031317
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1304-
let plan = UnmaterializedPlan::new(&sliced)?.materialize(&cuda_ctx)?;
1318+
let plan = dispatch_plan(&sliced, &cuda_ctx)?;
13051319

13061320
let actual = run_dynamic_dispatch_plan(
13071321
&cuda_ctx,
@@ -1333,7 +1347,7 @@ mod tests {
13331347
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
13341348

13351349
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1336-
let plan = UnmaterializedPlan::new(&seq.into_array())?.materialize(&cuda_ctx)?;
1350+
let plan = dispatch_plan(&seq.into_array(), &cuda_ctx)?;
13371351

13381352
let actual = run_dynamic_dispatch_plan(
13391353
&cuda_ctx,
@@ -1366,7 +1380,7 @@ mod tests {
13661380
let seq = SequenceArray::try_new_typed(base, multiplier, Nullability::NonNullable, len)?;
13671381

13681382
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
1369-
let plan = UnmaterializedPlan::new(&seq.into_array())?.materialize(&cuda_ctx)?;
1383+
let plan = dispatch_plan(&seq.into_array(), &cuda_ctx)?;
13701384

13711385
let actual_u32 = run_dynamic_dispatch_plan(
13721386
&cuda_ctx,

0 commit comments

Comments
 (0)