@@ -5,9 +5,6 @@ use vortex_buffer::Buffer;
55use vortex_buffer:: BufferMut ;
66use vortex_error:: VortexResult ;
77use vortex_error:: vortex_bail;
8- use vortex_error:: vortex_err;
9- use vortex_mask:: AllOr ;
10- use vortex_mask:: Mask ;
118
129use crate :: ArrayRef ;
1310use crate :: ExecutionCtx ;
@@ -53,20 +50,21 @@ impl CastKernel for Primitive {
5350 } ) ) ;
5451 }
5552
53+ if !values_fit_in ( array, new_ptype, ctx) {
54+ vortex_bail ! (
55+ Compute : "Cannot cast {} to {} — values exceed target range" ,
56+ array. ptype( ) ,
57+ new_ptype,
58+ ) ;
59+ }
60+
5661 // Same-width integers have identical bit representations due to 2's
5762 // complement. If all values fit in the target range, reinterpret with
5863 // no allocation.
5964 if array. ptype ( ) . is_int ( )
6065 && new_ptype. is_int ( )
6166 && array. ptype ( ) . byte_width ( ) == new_ptype. byte_width ( )
6267 {
63- if !values_fit_in ( array, new_ptype, ctx) {
64- vortex_bail ! (
65- Compute : "Cannot cast {} to {} — values exceed target range" ,
66- array. ptype( ) ,
67- new_ptype,
68- ) ;
69- }
7068 // SAFETY: both types are integers with the same size and alignment, and
7169 // min/max confirm all valid values are representable in the target type.
7270 return Ok ( Some ( unsafe {
@@ -79,13 +77,10 @@ impl CastKernel for Primitive {
7977 } ) ) ;
8078 }
8179
82- let mask = array. validity_mask ( ) ;
83-
84- // Otherwise, we need to cast the values one-by-one.
80+ // Otherwise, cast the values element-wise.
8581 Ok ( Some ( match_each_native_ptype ! ( new_ptype, |T | {
8682 match_each_native_ptype!( array. ptype( ) , |F | {
87- PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) , mask) ?, new_validity)
88- . into_array( )
83+ PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) ) , new_validity) . into_array( )
8984 } )
9085 } ) ) )
9186 }
@@ -104,30 +99,12 @@ fn values_fit_in(
10499 . is_none_or ( |mm| mm. min . cast ( & target_dtype) . is_ok ( ) && mm. max . cast ( & target_dtype) . is_ok ( ) )
105100}
106101
107- fn cast < F : NativePType , T : NativePType > ( array : & [ F ] , mask : Mask ) -> VortexResult < Buffer < T > > {
108- let try_cast = |src : F | -> VortexResult < T > {
109- T :: from ( src) . ok_or_else ( || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , src, T :: PTYPE ) )
110- } ;
111- match mask. bit_buffer ( ) {
112- AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
113- AllOr :: All => {
114- let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
115- for & src in array {
116- // SAFETY: we've pre-allocated the required capacity
117- unsafe { buffer. push_unchecked ( try_cast ( src) ?) }
118- }
119- Ok ( buffer. freeze ( ) )
120- }
121- AllOr :: Some ( b) => {
122- let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
123- for ( & src, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
124- let dst = if valid { try_cast ( src) ? } else { T :: default ( ) } ;
125- // SAFETY: we've pre-allocated the required capacity
126- unsafe { buffer. push_unchecked ( dst) }
127- }
128- Ok ( buffer. freeze ( ) )
129- }
130- }
102+ /// Caller must ensure all valid values are representable via `values_fit_in`;
103+ /// `unwrap_or_default` only fires at invalid positions where the physical
104+ /// value is out of range.
105+ fn cast < F : NativePType , T : NativePType > ( array : & [ F ] ) -> Buffer < T > {
106+ BufferMut :: from_trusted_len_iter ( array. iter ( ) . map ( |& src| T :: from ( src) . unwrap_or_default ( ) ) )
107+ . freeze ( )
131108}
132109
133110#[ cfg( test) ]
@@ -329,7 +306,9 @@ mod test {
329306 #[ case( buffer![ -1000000i32 , -1 , 0 , 1 , 1000000 ] . into_array( ) ) ]
330307 #[ case( buffer![ -1000000000i64 , -1 , 0 , 1 , 1000000000 ] . into_array( ) ) ]
331308 #[ case( buffer![ 0.0f32 , 1.5 , -2.5 , 100.0 , 1e6 ] . into_array( ) ) ]
309+ #[ case( buffer![ f32 :: NAN , f32 :: INFINITY , f32 :: NEG_INFINITY , 0.0f32 ] . into_array( ) ) ]
332310 #[ case( buffer![ 0.0f64 , 1.5 , -2.5 , 100.0 , 1e12 ] . into_array( ) ) ]
311+ #[ case( buffer![ f64 :: NAN , f64 :: INFINITY , f64 :: NEG_INFINITY , 0.0f64 ] . into_array( ) ) ]
333312 #[ case( PrimitiveArray :: from_option_iter( [ Some ( 1u8 ) , None , Some ( 255 ) , Some ( 0 ) , None ] ) . into_array( ) ) ]
334313 #[ case( PrimitiveArray :: from_option_iter( [ Some ( 1i32 ) , None , Some ( -100 ) , Some ( 0 ) , None ] ) . into_array( ) ) ]
335314 #[ case( buffer![ 42u32 ] . into_array( ) ) ]
0 commit comments