11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4+ use num_traits:: AsPrimitive ;
45use vortex_buffer:: Buffer ;
56use vortex_buffer:: BufferMut ;
67use vortex_error:: VortexResult ;
78use vortex_error:: vortex_bail;
8- use vortex_error:: vortex_err;
9- use vortex_mask:: AllOr ;
10- use vortex_mask:: Mask ;
119
1210use crate :: ArrayRef ;
1311use crate :: ExecutionCtx ;
@@ -53,20 +51,21 @@ impl CastKernel for Primitive {
5351 } ) ) ;
5452 }
5553
54+ if !values_fit_in ( array, new_ptype, ctx) {
55+ vortex_bail ! (
56+ Compute : "Cannot cast {} to {} — values exceed target range" ,
57+ array. ptype( ) ,
58+ new_ptype,
59+ ) ;
60+ }
61+
5662 // Same-width integers have identical bit representations due to 2's
5763 // complement. If all values fit in the target range, reinterpret with
5864 // no allocation.
5965 if array. ptype ( ) . is_int ( )
6066 && new_ptype. is_int ( )
6167 && array. ptype ( ) . byte_width ( ) == new_ptype. byte_width ( )
6268 {
63- if !values_fit_in ( array, new_ptype, ctx) {
64- vortex_bail ! (
65- Compute : "Cannot cast {} to {} — values exceed target range" ,
66- array. ptype( ) ,
67- new_ptype,
68- ) ;
69- }
7069 // SAFETY: both types are integers with the same size and alignment, and
7170 // min/max confirm all valid values are representable in the target type.
7271 return Ok ( Some ( unsafe {
@@ -79,13 +78,10 @@ impl CastKernel for Primitive {
7978 } ) ) ;
8079 }
8180
82- let mask = array. validity_mask ( ) ;
83-
84- // Otherwise, we need to cast the values one-by-one.
81+ // Otherwise, cast the values element-wise.
8582 Ok ( Some ( match_each_native_ptype ! ( new_ptype, |T | {
8683 match_each_native_ptype!( array. ptype( ) , |F | {
87- PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) , mask) ?, new_validity)
88- . into_array( )
84+ PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) ) , new_validity) . into_array( )
8985 } )
9086 } ) ) )
9187 }
@@ -104,30 +100,11 @@ fn values_fit_in(
104100 . is_none_or ( |mm| mm. min . cast ( & target_dtype) . is_ok ( ) && mm. max . cast ( & target_dtype) . is_ok ( ) )
105101}
106102
107- fn cast < F : NativePType , T : NativePType > ( array : & [ F ] , mask : Mask ) -> VortexResult < Buffer < T > > {
108- let try_cast = |src : F | -> VortexResult < T > {
109- T :: from ( src) . ok_or_else ( || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , src, T :: PTYPE ) )
110- } ;
111- match mask. bit_buffer ( ) {
112- AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
113- AllOr :: All => {
114- let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
115- for & src in array {
116- // SAFETY: we've pre-allocated the required capacity
117- unsafe { buffer. push_unchecked ( try_cast ( src) ?) }
118- }
119- Ok ( buffer. freeze ( ) )
120- }
121- AllOr :: Some ( b) => {
122- let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
123- for ( & src, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
124- let dst = if valid { try_cast ( src) ? } else { T :: default ( ) } ;
125- // SAFETY: we've pre-allocated the required capacity
126- unsafe { buffer. push_unchecked ( dst) }
127- }
128- Ok ( buffer. freeze ( ) )
129- }
130- }
103+ /// Caller must ensure all valid values are representable via `values_fit_in`.
104+ /// Out-of-range values at invalid positions are truncated/wrapped by `as`,
105+ /// which is fine because they are masked out by validity.
106+ fn cast < F : NativePType + AsPrimitive < T > , T : NativePType > ( array : & [ F ] ) -> Buffer < T > {
107+ BufferMut :: from_trusted_len_iter ( array. iter ( ) . map ( |& src| src. as_ ( ) ) ) . freeze ( )
131108}
132109
133110#[ cfg( test) ]
@@ -319,6 +296,23 @@ mod test {
319296 Ok ( ( ) )
320297 }
321298
299+ #[ test]
300+ fn cast_u32_to_u8_with_out_of_range_nulls ( ) -> vortex_error:: VortexResult < ( ) > {
301+ let arr = PrimitiveArray :: new (
302+ buffer ! [ 1000u32 , 10u32 , 42u32 ] ,
303+ Validity :: from_iter ( [ false , true , true ] ) ,
304+ ) ;
305+ let casted = arr
306+ . into_array ( )
307+ . cast ( DType :: Primitive ( PType :: U8 , Nullability :: Nullable ) ) ?
308+ . to_primitive ( ) ;
309+ assert_arrays_eq ! (
310+ casted,
311+ PrimitiveArray :: from_option_iter( [ None , Some ( 10u8 ) , Some ( 42 ) ] )
312+ ) ;
313+ Ok ( ( ) )
314+ }
315+
322316 #[ rstest]
323317 #[ case( buffer![ 0u8 , 1 , 2 , 3 , 255 ] . into_array( ) ) ]
324318 #[ case( buffer![ 0u16 , 100 , 1000 , 65535 ] . into_array( ) ) ]
@@ -329,7 +323,9 @@ mod test {
329323 #[ case( buffer![ -1000000i32 , -1 , 0 , 1 , 1000000 ] . into_array( ) ) ]
330324 #[ case( buffer![ -1000000000i64 , -1 , 0 , 1 , 1000000000 ] . into_array( ) ) ]
331325 #[ case( buffer![ 0.0f32 , 1.5 , -2.5 , 100.0 , 1e6 ] . into_array( ) ) ]
326+ #[ case( buffer![ f32 :: NAN , f32 :: INFINITY , f32 :: NEG_INFINITY , 0.0f32 ] . into_array( ) ) ]
332327 #[ case( buffer![ 0.0f64 , 1.5 , -2.5 , 100.0 , 1e12 ] . into_array( ) ) ]
328+ #[ case( buffer![ f64 :: NAN , f64 :: INFINITY , f64 :: NEG_INFINITY , 0.0f64 ] . into_array( ) ) ]
333329 #[ case( PrimitiveArray :: from_option_iter( [ Some ( 1u8 ) , None , Some ( 255 ) , Some ( 0 ) , None ] ) . into_array( ) ) ]
334330 #[ case( PrimitiveArray :: from_option_iter( [ Some ( 1i32 ) , None , Some ( -100 ) , Some ( 0 ) , None ] ) . into_array( ) ) ]
335331 #[ case( buffer![ 42u32 ] . into_array( ) ) ]
0 commit comments