Skip to content

Commit 173f664

Browse files
authored
bool -> Primitive casts (#8621)
Support of bool -> Primitive casts. This is a subtask for pushing expressions of form "WHERE prefix(...) > 0" to Duckdb. Signed-off-by: Mikhail Kot <mikhail@spiraldb.com>
1 parent 00c39e8 commit 173f664

1 file changed

Lines changed: 47 additions & 6 deletions

File tree

  • vortex-array/src/arrays/bool/compute

vortex-array/src/arrays/bool/compute/cast.rs

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use num_traits::One;
5+
use num_traits::Zero;
6+
use vortex_buffer::BufferMut;
47
use vortex_error::VortexResult;
58

69
use crate::ArrayRef;
@@ -9,8 +12,10 @@ use crate::IntoArray;
912
use crate::array::ArrayView;
1013
use crate::arrays::Bool;
1114
use crate::arrays::BoolArray;
15+
use crate::arrays::PrimitiveArray;
1216
use crate::arrays::bool::BoolArrayExt;
1317
use crate::dtype::DType;
18+
use crate::match_each_native_ptype;
1419
use crate::scalar_fn::fns::cast::CastKernel;
1520
use crate::scalar_fn::fns::cast::CastReduce;
1621

@@ -38,17 +43,34 @@ impl CastKernel for Bool {
3843
dtype: &DType,
3944
ctx: &mut ExecutionCtx,
4045
) -> VortexResult<Option<ArrayRef>> {
41-
if !dtype.is_boolean() {
42-
return Ok(None);
46+
if dtype.is_boolean() {
47+
let new_validity =
48+
array
49+
.validity()?
50+
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
51+
return Ok(Some(
52+
BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
53+
));
4354
}
4455

56+
let DType::Primitive(new_ptype, new_nullability) = dtype else {
57+
return Ok(None);
58+
};
59+
4560
let new_validity =
4661
array
4762
.validity()?
48-
.cast_nullability(dtype.nullability(), array.len(), ctx)?;
49-
Ok(Some(
50-
BoolArray::new(array.to_bit_buffer(), new_validity).into_array(),
51-
))
63+
.cast_nullability(*new_nullability, array.len(), ctx)?;
64+
65+
let bits = array.to_bit_buffer();
66+
let len = bits.len();
67+
68+
Ok(Some(match_each_native_ptype!(*new_ptype, |T| {
69+
let (one, zero) = (<T as One>::one(), <T as Zero>::zero());
70+
let mut buffer = BufferMut::<T>::with_capacity(len);
71+
buffer.extend(bits.iter().map(|v| if v { one } else { zero }));
72+
PrimitiveArray::new(buffer.freeze(), new_validity).into_array()
73+
})))
5274
}
5375
}
5476

@@ -67,6 +89,7 @@ mod tests {
6789
use crate::compute::conformance::cast::test_cast_conformance;
6890
use crate::dtype::DType;
6991
use crate::dtype::Nullability;
92+
use crate::dtype::PType;
7093

7194
static SESSION: LazyLock<VortexSession> = LazyLock::new(crate::array_session);
7295

@@ -102,4 +125,22 @@ mod tests {
102125
fn test_cast_bool_conformance(#[case] array: BoolArray) {
103126
test_cast_conformance(&array.into_array());
104127
}
128+
129+
#[rstest]
130+
#[case(PType::I8)]
131+
#[case(PType::I32)]
132+
#[case(PType::I64)]
133+
#[case(PType::U8)]
134+
#[case(PType::U64)]
135+
#[case(PType::F32)]
136+
#[case(PType::F64)]
137+
fn cast_bool_to_primitive(#[case] target: PType) {
138+
let mut ctx = SESSION.create_execution_ctx();
139+
let arr = BoolArray::from_iter(vec![true, false, true]).into_array();
140+
let out = arr
141+
.cast(DType::Primitive(target, Nullability::NonNullable))
142+
.unwrap();
143+
let out = out.execute::<Canonical>(&mut ctx).unwrap().into_array();
144+
assert_eq!(out.len(), 3);
145+
}
105146
}

0 commit comments

Comments
 (0)