Skip to content

Commit 7f541db

Browse files
authored
feat(cuda): hybrid GPU dispatch - fuse dyn + standalone kernels (#7127)
Add a hybrid_dispatch module that integrates subtrees requiring standalone kernel dispatch with dynamic-dispatch kernels. Subtrees with unsupported encodings (e.g. Zstd) are executed as separate kernels, and their resulting device buffers are fed back into the fused plan as `LOAD` ops. Note that this implicitly enables filtering via the CUDA CUB filter implementation. Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 1c8667c commit 7f541db

6 files changed

Lines changed: 596 additions & 1 deletion

File tree

vortex-cuda/benches/dynamic_dispatch_cuda.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ const BENCH_ARGS: &[(usize, &str)] = &[
5050
];
5151

5252
/// Launch the dynamic_dispatch kernel and return GPU-timed duration.
53+
///
54+
/// This deliberately does not use `DynamicDispatchPlan::execute` because the
55+
/// benchmark pre-allocates the output buffer and device plan once, then reuses
56+
/// them across iterations.
5357
fn run_timed(
5458
cuda_ctx: &mut CudaExecutionCtx,
5559
array_len: usize,

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,27 @@
1919
#![allow(non_snake_case)]
2020
#![allow(clippy::cast_possible_truncation)]
2121

22-
mod plan_builder;
22+
use std::sync::Arc;
23+
24+
use cudarc::driver::DevicePtr;
25+
use cudarc::driver::LaunchConfig;
26+
use cudarc::driver::PushKernelArg;
27+
use vortex::array::Canonical;
28+
use vortex::array::arrays::PrimitiveArray;
29+
use vortex::array::buffer::BufferHandle;
30+
use vortex::array::buffer::DeviceBufferExt;
31+
use vortex::array::match_each_unsigned_integer_ptype;
32+
use vortex::array::validity::Validity;
33+
use vortex::dtype::Nullability;
34+
use vortex::dtype::PType;
35+
use vortex::error::VortexResult;
36+
use vortex::error::vortex_bail;
37+
use vortex::error::vortex_err;
38+
39+
use crate::CudaDeviceBuffer;
40+
use crate::executor::CudaExecutionCtx;
41+
42+
pub(crate) mod plan_builder;
2343
pub use plan_builder::build_plan;
2444

2545
include!(concat!(env!("OUT_DIR"), "/dynamic_dispatch.rs"));
@@ -201,6 +221,85 @@ impl DynamicDispatchPlan {
201221
}
202222
max_end * elem_size
203223
}
224+
225+
/// Allocate output, upload the plan to the device, and launch the
226+
/// `dynamic_dispatch` kernel.
227+
///
228+
/// The CUDA kernels are instantiated for unsigned types only.
229+
/// Encoding transforms (FoR, ZigZag, ALP) are bit-identical
230+
/// regardless of signedness.
231+
///
232+
/// `CudaSlice::drop` enqueues `free` on the stream after kernel execution.
233+
pub fn execute(
234+
self,
235+
output_ptype: PType,
236+
len: usize,
237+
device_buffers: Vec<BufferHandle>,
238+
ctx: &mut CudaExecutionCtx,
239+
) -> VortexResult<Canonical> {
240+
let unsigned_ptype = match output_ptype {
241+
PType::U8 | PType::I8 => PType::U8,
242+
PType::U16 | PType::I16 => PType::U16,
243+
PType::U32 | PType::I32 | PType::F32 => PType::U32,
244+
PType::U64 | PType::I64 => PType::U64,
245+
other => vortex_bail!("dynamic dispatch does not support PType {:?}", other),
246+
};
247+
match_each_unsigned_integer_ptype!(unsigned_ptype, |T| {
248+
self.execute_typed::<T>(output_ptype, len, device_buffers, ctx)
249+
})
250+
}
251+
252+
fn execute_typed<T>(
253+
self,
254+
output_ptype: PType,
255+
len: usize,
256+
device_buffers: Vec<BufferHandle>,
257+
ctx: &mut CudaExecutionCtx,
258+
) -> VortexResult<Canonical>
259+
where
260+
T: cudarc::driver::DeviceRepr + vortex::dtype::NativePType,
261+
{
262+
if len == 0 {
263+
return Ok(Canonical::Primitive(PrimitiveArray::empty::<T>(
264+
Nullability::NonNullable,
265+
)));
266+
}
267+
268+
let output_buf = CudaDeviceBuffer::new(ctx.device_alloc::<T>(len.next_multiple_of(1024))?);
269+
let device_plan = Arc::new(
270+
ctx.stream()
271+
.clone_htod(std::slice::from_ref(&self))
272+
.map_err(|e| vortex_err!("copy plan to device: {e}"))?,
273+
);
274+
275+
let shared_mem_bytes = self.shared_mem_bytes::<T>();
276+
let cuda_function = ctx.load_function("dynamic_dispatch", &[T::PTYPE])?;
277+
let num_blocks = u32::try_from(len.div_ceil(2048))?;
278+
let config = LaunchConfig {
279+
grid_dim: (num_blocks, 1, 1),
280+
block_dim: (64, 1, 1),
281+
shared_mem_bytes,
282+
};
283+
284+
let output_ptr = output_buf.offset_ptr();
285+
let plan_ptr = device_plan.device_ptr(ctx.stream()).0;
286+
let array_len_u64 = len as u64;
287+
288+
ctx.launch_kernel_config(&cuda_function, config, len, |args| {
289+
args.arg(&output_ptr);
290+
args.arg(&array_len_u64);
291+
args.arg(&plan_ptr);
292+
})?;
293+
294+
drop(device_buffers);
295+
drop(device_plan);
296+
297+
Ok(Canonical::Primitive(PrimitiveArray::from_buffer_handle(
298+
BufferHandle::new_device(output_buf.slice_typed::<T>(0..len)),
299+
output_ptype,
300+
Validity::NonNullable,
301+
)))
302+
}
204303
}
205304

206305
#[cfg(test)]

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
//! to the device, computes shared memory offsets, and produces a plan that the
88
//! dynamic dispatch kernel can execute in a single launch.
99
10+
use std::sync::Arc;
11+
1012
use futures::executor::block_on;
1113
use vortex::array::ArrayRef;
1214
use vortex::array::DynArray;
@@ -102,11 +104,30 @@ pub fn build_plan(
102104
array: &ArrayRef,
103105
ctx: &CudaExecutionCtx,
104106
) -> VortexResult<(DynamicDispatchPlan, Vec<BufferHandle>)> {
107+
build_plan_with_subtrees(array, ctx, &[])
108+
}
109+
110+
/// Build a [`DynamicDispatchPlan`] with subtrees run as separate
111+
/// kernels, whose output device buffers are integrated into the plan as `LOAD` inputs.
112+
pub fn build_plan_with_subtrees(
113+
array: &ArrayRef,
114+
ctx: &CudaExecutionCtx,
115+
subtree_inputs: &[(ArrayRef, BufferHandle)],
116+
) -> VortexResult<(DynamicDispatchPlan, Vec<BufferHandle>)> {
117+
let sub_map = subtree_inputs
118+
.iter()
119+
.map(|(arr, handle)| {
120+
let ptr = handle.cuda_device_ptr()?;
121+
Ok((Arc::as_ptr(arr) as *const () as usize, ptr))
122+
})
123+
.collect::<VortexResult<Vec<_>>>()?;
124+
105125
let mut state = PlanBuilderState {
106126
ctx,
107127
stages: Vec::new(),
108128
smem_cursor: 0,
109129
device_buffers: Vec::new(),
130+
subtree_inputs: sub_map,
110131
};
111132

112133
let pipeline = state.walk(array.clone())?;
@@ -129,6 +150,88 @@ pub fn build_plan(
129150
Ok((DynamicDispatchPlan::new(state.stages), state.device_buffers))
130151
}
131152

153+
/// Walk the encoding tree and find subtrees that cannot be fused into a
154+
/// dynamic-dispatch plan. The root of each returned subtree is a node that cannot
155+
/// be fused.
156+
///
157+
/// Returns an empty vec if the root itself cannot be fused.
158+
pub fn find_subtrees(array: &ArrayRef) -> Vec<ArrayRef> {
159+
if !is_dyn_dispatch_compatible(array) {
160+
return Vec::new();
161+
}
162+
let mut out = Vec::new();
163+
collect_subtrees(array, &mut out);
164+
out
165+
}
166+
167+
/// Checks whether the encoding of an array can be fused into a dynamic-dispatch plan.
168+
fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool {
169+
let id = array.encoding_id();
170+
if id == ALP::ID {
171+
if let Ok(a) = array.clone().try_into::<ALP>() {
172+
return a.patches().is_none() && a.dtype().as_ptype() == PType::F32;
173+
}
174+
return false;
175+
}
176+
if id == BitPacked::ID {
177+
if let Ok(a) = array.clone().try_into::<BitPacked>() {
178+
return a.patches().is_none();
179+
}
180+
return false;
181+
}
182+
id == FoR::ID
183+
|| id == ZigZag::ID
184+
|| id == Dict::ID
185+
|| id == RunEnd::ID
186+
|| id == Primitive::ID
187+
|| id == Slice::ID
188+
|| id == Sequence::ID
189+
}
190+
191+
/// Walk the children of a dynamic dispatch compatible root node. Any child
192+
/// that is not dyn dispatch compatible is recorded as a subtree that must be
193+
/// executed separately.
194+
fn collect_subtrees(array: &ArrayRef, out: &mut Vec<ArrayRef>) {
195+
let id = array.encoding_id();
196+
197+
fn visit_child(child: &ArrayRef, out: &mut Vec<ArrayRef>) {
198+
if is_dyn_dispatch_compatible(child) {
199+
collect_subtrees(child, out);
200+
} else {
201+
out.push(child.clone());
202+
}
203+
}
204+
205+
if id == FoR::ID {
206+
if let Ok(a) = array.clone().try_into::<FoR>() {
207+
visit_child(a.encoded(), out);
208+
}
209+
} else if id == ZigZag::ID {
210+
if let Ok(a) = array.clone().try_into::<ZigZag>() {
211+
visit_child(a.encoded(), out);
212+
}
213+
} else if id == ALP::ID {
214+
if let Ok(a) = array.clone().try_into::<ALP>() {
215+
visit_child(a.encoded(), out);
216+
}
217+
} else if id == Slice::ID {
218+
if let Some(a) = array.as_opt::<Slice>() {
219+
visit_child(a.child(), out);
220+
}
221+
} else if id == Dict::ID
222+
&& let Ok(a) = array.clone().try_into::<Dict>()
223+
{
224+
visit_child(a.values(), out);
225+
visit_child(a.codes(), out);
226+
} else if id == RunEnd::ID
227+
&& let Ok(a) = array.clone().try_into::<RunEnd>()
228+
{
229+
visit_child(a.ends(), out);
230+
visit_child(a.values(), out);
231+
}
232+
// BitPacked, Primitive, Sequence — leaves, no children.
233+
}
234+
132235
/// Internal mutable state for the recursive tree walk.
133236
struct PlanBuilderState<'a> {
134237
ctx: &'a CudaExecutionCtx,
@@ -138,11 +241,30 @@ struct PlanBuilderState<'a> {
138241
smem_cursor: u32,
139242
/// Device buffers to keep alive.
140243
device_buffers: Vec<BufferHandle>,
244+
/// Pre-executed subtree outputs injected as `LOAD` sources: `(identity, device_ptr)`.
245+
subtree_inputs: Vec<(usize, u64)>,
141246
}
142247

143248
impl PlanBuilderState<'_> {
249+
/// If `array` matches a pre-executed subtree input, return a `LOAD` pipeline pointing at its device buffer.
250+
fn find_subtree(&self, array: &ArrayRef) -> Option<Pipeline> {
251+
let subtree_id = Arc::as_ptr(array) as *const () as usize;
252+
self.subtree_inputs
253+
.iter()
254+
.find(|(id, _)| *id == subtree_id)
255+
.map(|(_, ptr)| Pipeline {
256+
source: SourceOp::load(),
257+
scalar_ops: vec![],
258+
input_ptr: *ptr,
259+
})
260+
}
261+
144262
/// Recursively walk the encoding tree.
145263
fn walk(&mut self, array: ArrayRef) -> VortexResult<Pipeline> {
264+
if let Some(pipeline) = self.find_subtree(&array) {
265+
return Ok(pipeline);
266+
}
267+
146268
let id = array.encoding_id();
147269

148270
if id == BitPacked::ID {

vortex-cuda/src/executor.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use vortex::error::vortex_err;
3030

3131
use crate::CudaSession;
3232
use crate::ExportDeviceArray;
33+
use crate::hybrid_dispatch;
3334
use crate::kernel::DefaultLaunchStrategy;
3435
use crate::kernel::LaunchStrategy;
3536
use crate::kernel::LaunchStrategyExt;
@@ -265,6 +266,11 @@ impl CudaExecutionCtx {
265266
self.ctx.session()
266267
}
267268

269+
/// Returns a reference to the CUDA session.
270+
pub(crate) fn cuda_session(&self) -> &CudaSession {
271+
&self.cuda_session
272+
}
273+
268274
/// Get a handle to the exporter that can convert arrays into `ArrowDeviceArray`.
269275
pub fn exporter(&self) -> &Arc<dyn ExportDeviceArray> {
270276
self.cuda_session.export_device_array()
@@ -364,6 +370,19 @@ impl CudaArrayExt for ArrayRef {
364370
return self.execute(&mut ctx.ctx);
365371
}
366372

373+
// Try to fuse the encoding tree (or parts of it) into dynamic-dispatch
374+
// kernel launches. See hybrid_dispatch module docs for details.
375+
match hybrid_dispatch::try_dyn_dispatch(&self, ctx).await {
376+
Ok(canonical) => return Ok(canonical),
377+
Err(e) => {
378+
trace!(
379+
encoding = %self.encoding_id(),
380+
error = %e,
381+
"Hybrid dispatch not applicable, trying registered single kernel"
382+
);
383+
}
384+
}
385+
367386
let Some(support) = ctx.cuda_session.kernel(&self.encoding_id()) else {
368387
debug!(
369388
encoding = %self.encoding_id(),

0 commit comments

Comments
 (0)