Skip to content

Commit 0beb0fb

Browse files
committed
fix: dyn dispatch requires same bit width types
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent ec2c602 commit 0beb0fb

2 files changed

Lines changed: 76 additions & 2 deletions

File tree

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,54 @@ mod tests {
803803
Ok(())
804804
}
805805

806+
/// Dict with mismatched code/value ptypes (u8 codes, u32 values) must be
807+
/// rejected by the plan builder so the single-kernel `DictExecutor` handles
808+
/// it instead. The dynamic dispatch kernel operates on a single type T, so
809+
/// loading narrower codes as wider elements would read out of bounds.
810+
#[crate::test]
811+
fn test_dict_mismatched_ptypes_rejected() -> VortexResult<()> {
812+
let dict_values: Vec<u32> = vec![100, 200, 300, 400];
813+
let len = 3000;
814+
let codes: Vec<u8> = (0..len).map(|i| (i % dict_values.len()) as u8).collect();
815+
816+
let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable);
817+
let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable);
818+
let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?;
819+
820+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
821+
// build_plan should fail because u8 codes != u32 values in byte width.
822+
assert!(
823+
build_plan(&dict.into_array(), &cuda_ctx).is_err(),
824+
"Dict with mismatched code/value ptypes should not be fused"
825+
);
826+
827+
Ok(())
828+
}
829+
830+
/// RunEnd with mismatched ends/values ptypes (u64 ends, i32 values) must be
831+
/// rejected by the plan builder so the single-kernel `RunEndExecutor`
832+
/// handles it instead. The dynamic dispatch kernel loads both ends and
833+
/// values into shared memory as the single uniform type T; mismatched byte
834+
/// widths would misinterpret the raw data.
835+
#[crate::test]
836+
fn test_runend_mismatched_ptypes_rejected() -> VortexResult<()> {
837+
let ends: Vec<u64> = vec![1000, 2000, 3000];
838+
let values: Vec<i32> = vec![10, 20, 30];
839+
840+
let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array();
841+
let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array();
842+
let re = RunEndArray::new(ends_arr, values_arr);
843+
844+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
845+
// build_plan should fail because u64 ends != i32 values in byte width.
846+
assert!(
847+
build_plan(&re.into_array(), &cuda_ctx).is_err(),
848+
"RunEnd with mismatched ends/values ptypes should not be fused"
849+
);
850+
851+
Ok(())
852+
}
853+
806854
#[rstest]
807855
#[case(0, 1024)]
808856
#[case(0, 3000)]

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,36 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool {
178178
}
179179
return false;
180180
}
181+
if id == Dict::ID {
182+
if let Ok(a) = array.clone().try_into::<Dict>() {
183+
// As of now the dict dyn dispatch kernel requires
184+
// codes and values to have the same byte width.
185+
return match (
186+
PType::try_from(a.values().dtype()),
187+
PType::try_from(a.codes().dtype()),
188+
) {
189+
(Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(),
190+
_ => false,
191+
};
192+
}
193+
return false;
194+
}
195+
if id == RunEnd::ID {
196+
if let Ok(a) = array.clone().try_into::<RunEnd>() {
197+
// As of now the run-end dyn dispatch kernel requires
198+
// ends and values to have the same byte width.
199+
return match (
200+
PType::try_from(a.ends().dtype()),
201+
PType::try_from(a.values().dtype()),
202+
) {
203+
(Ok(e), Ok(v)) => e.byte_width() == v.byte_width(),
204+
_ => false,
205+
};
206+
}
207+
return false;
208+
}
181209
id == FoR::ID
182210
|| id == ZigZag::ID
183-
|| id == Dict::ID
184-
|| id == RunEnd::ID
185211
|| id == Primitive::ID
186212
|| id == Slice::ID
187213
|| id == Sequence::ID

0 commit comments

Comments
 (0)