11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: One ;
5+ use num_traits:: Zero ;
6+ use vortex_buffer:: BufferMut ;
47use vortex_error:: VortexResult ;
58
69use crate :: ArrayRef ;
@@ -9,8 +12,10 @@ use crate::IntoArray;
912use crate :: array:: ArrayView ;
1013use crate :: arrays:: Bool ;
1114use crate :: arrays:: BoolArray ;
15+ use crate :: arrays:: PrimitiveArray ;
1216use crate :: arrays:: bool:: BoolArrayExt ;
1317use crate :: dtype:: DType ;
18+ use crate :: match_each_native_ptype;
1419use crate :: scalar_fn:: fns:: cast:: CastKernel ;
1520use crate :: scalar_fn:: fns:: cast:: CastReduce ;
1621
@@ -38,17 +43,34 @@ impl CastKernel for Bool {
3843 dtype : & DType ,
3944 ctx : & mut ExecutionCtx ,
4045 ) -> VortexResult < Option < ArrayRef > > {
41- if !dtype. is_boolean ( ) {
42- return Ok ( None ) ;
46+ if dtype. is_boolean ( ) {
47+ let new_validity =
48+ array
49+ . validity ( ) ?
50+ . cast_nullability ( dtype. nullability ( ) , array. len ( ) , ctx) ?;
51+ return Ok ( Some (
52+ BoolArray :: new ( array. to_bit_buffer ( ) , new_validity) . into_array ( ) ,
53+ ) ) ;
4354 }
4455
56+ let DType :: Primitive ( new_ptype, new_nullability) = dtype else {
57+ return Ok ( None ) ;
58+ } ;
59+
4560 let new_validity =
4661 array
4762 . validity ( ) ?
48- . cast_nullability ( dtype. nullability ( ) , array. len ( ) , ctx) ?;
49- Ok ( Some (
50- BoolArray :: new ( array. to_bit_buffer ( ) , new_validity) . into_array ( ) ,
51- ) )
63+ . cast_nullability ( * new_nullability, array. len ( ) , ctx) ?;
64+
65+ let bits = array. to_bit_buffer ( ) ;
66+ let len = bits. len ( ) ;
67+
68+ Ok ( Some ( match_each_native_ptype ! ( * new_ptype, |T | {
69+ let ( one, zero) = ( <T as One >:: one( ) , <T as Zero >:: zero( ) ) ;
70+ let mut buffer = BufferMut :: <T >:: with_capacity( len) ;
71+ buffer. extend( bits. iter( ) . map( |v| if v { one } else { zero } ) ) ;
72+ PrimitiveArray :: new( buffer. freeze( ) , new_validity) . into_array( )
73+ } ) ) )
5274 }
5375}
5476
@@ -67,6 +89,7 @@ mod tests {
6789 use crate :: compute:: conformance:: cast:: test_cast_conformance;
6890 use crate :: dtype:: DType ;
6991 use crate :: dtype:: Nullability ;
92+ use crate :: dtype:: PType ;
7093
7194 static SESSION : LazyLock < VortexSession > = LazyLock :: new ( crate :: array_session) ;
7295
@@ -102,4 +125,22 @@ mod tests {
102125 fn test_cast_bool_conformance ( #[ case] array : BoolArray ) {
103126 test_cast_conformance ( & array. into_array ( ) ) ;
104127 }
128+
129+ #[ rstest]
130+ #[ case( PType :: I8 ) ]
131+ #[ case( PType :: I32 ) ]
132+ #[ case( PType :: I64 ) ]
133+ #[ case( PType :: U8 ) ]
134+ #[ case( PType :: U64 ) ]
135+ #[ case( PType :: F32 ) ]
136+ #[ case( PType :: F64 ) ]
137+ fn cast_bool_to_primitive ( #[ case] target : PType ) {
138+ let mut ctx = SESSION . create_execution_ctx ( ) ;
139+ let arr = BoolArray :: from_iter ( vec ! [ true , false , true ] ) . into_array ( ) ;
140+ let out = arr
141+ . cast ( DType :: Primitive ( target, Nullability :: NonNullable ) )
142+ . unwrap ( ) ;
143+ let out = out. execute :: < Canonical > ( & mut ctx) . unwrap ( ) . into_array ( ) ;
144+ assert_eq ! ( out. len( ) , 3 ) ;
145+ }
105146}
0 commit comments