Skip to content

Commit 7852ece

Browse files
authored
fix: dyn dispatch requires same bit width types (#7164)
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent b66cfcb commit 7852ece

2 files changed

Lines changed: 68 additions & 2 deletions

File tree

vortex-cuda/src/dynamic_dispatch/mod.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,39 @@ mod tests {
803803
Ok(())
804804
}
805805

806+
#[crate::test]
807+
fn test_dict_mismatched_ptypes_rejected() -> VortexResult<()> {
808+
let dict_values: Vec<u32> = vec![100, 200, 300, 400];
809+
let len = 3000;
810+
let codes: Vec<u8> = (0..len).map(|i| (i % dict_values.len()) as u8).collect();
811+
812+
let codes_prim = PrimitiveArray::new(Buffer::from(codes), NonNullable);
813+
let values_prim = PrimitiveArray::new(Buffer::from(dict_values), NonNullable);
814+
let dict = DictArray::try_new(codes_prim.into_array(), values_prim.into_array())?;
815+
816+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
817+
// build_plan should fail because u8 codes != u32 values in byte width.
818+
assert!(build_plan(&dict.into_array(), &cuda_ctx).is_err());
819+
820+
Ok(())
821+
}
822+
823+
#[crate::test]
824+
fn test_runend_mismatched_ptypes_rejected() -> VortexResult<()> {
825+
let ends: Vec<u64> = vec![1000, 2000, 3000];
826+
let values: Vec<i32> = vec![10, 20, 30];
827+
828+
let ends_arr = PrimitiveArray::new(Buffer::from(ends), NonNullable).into_array();
829+
let values_arr = PrimitiveArray::new(Buffer::from(values), NonNullable).into_array();
830+
let re = RunEndArray::new(ends_arr, values_arr);
831+
832+
let cuda_ctx = CudaSession::create_execution_ctx(&VortexSession::empty())?;
833+
// build_plan should fail because u64 ends != i32 values in byte width.
834+
assert!(build_plan(&re.into_array(), &cuda_ctx).is_err());
835+
836+
Ok(())
837+
}
838+
806839
#[rstest]
807840
#[case(0, 1024)]
808841
#[case(0, 3000)]

vortex-cuda/src/dynamic_dispatch/plan_builder.rs

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,36 @@ fn is_dyn_dispatch_compatible(array: &ArrayRef) -> bool {
178178
}
179179
return false;
180180
}
181+
if id == Dict::ID {
182+
if let Ok(a) = array.clone().try_into::<Dict>() {
183+
// As of now the dict dyn dispatch kernel requires
184+
// codes and values to have the same byte width.
185+
return match (
186+
PType::try_from(a.values().dtype()),
187+
PType::try_from(a.codes().dtype()),
188+
) {
189+
(Ok(values), Ok(codes)) => values.byte_width() == codes.byte_width(),
190+
_ => false,
191+
};
192+
}
193+
return false;
194+
}
195+
if id == RunEnd::ID {
196+
if let Ok(a) = array.clone().try_into::<RunEnd>() {
197+
// As of now the run-end dyn dispatch kernel requires
198+
// ends and values to have the same byte width.
199+
return match (
200+
PType::try_from(a.ends().dtype()),
201+
PType::try_from(a.values().dtype()),
202+
) {
203+
(Ok(e), Ok(v)) => e.byte_width() == v.byte_width(),
204+
_ => false,
205+
};
206+
}
207+
return false;
208+
}
181209
id == FoR::ID
182210
|| id == ZigZag::ID
183-
|| id == Dict::ID
184-
|| id == RunEnd::ID
185211
|| id == Primitive::ID
186212
|| id == Slice::ID
187213
|| id == Sequence::ID
@@ -264,6 +290,13 @@ impl PlanBuilderState<'_> {
264290
return Ok(pipeline);
265291
}
266292

293+
if !is_dyn_dispatch_compatible(&array) {
294+
vortex_bail!(
295+
"Encoding {:?} is not compatible with the dynamic dispatch plan builder",
296+
array.encoding_id()
297+
);
298+
}
299+
267300
let id = array.encoding_id();
268301

269302
if id == BitPacked::ID {

0 commit comments

Comments
 (0)