@@ -13,8 +13,11 @@ use crate::ExecutionCtx;
1313use crate :: IntoArray ;
1414use crate :: arrays:: Primitive ;
1515use crate :: arrays:: PrimitiveArray ;
16+ use crate :: compute;
1617use crate :: dtype:: DType ;
1718use crate :: dtype:: NativePType ;
19+ use crate :: dtype:: Nullability ;
20+ use crate :: dtype:: PType ;
1821use crate :: match_each_native_ptype;
1922use crate :: scalar_fn:: fns:: cast:: CastKernel ;
2023use crate :: vtable:: ValidityHelper ;
@@ -36,7 +39,7 @@ impl CastKernel for Primitive {
3639 . clone ( )
3740 . cast_nullability ( new_nullability, array. len ( ) ) ?;
3841
39- // If the bit width is the same, we can short-circuit and simply update the validity
42+ // Same ptype: zero-copy, just update validity.
4043 if array. ptype ( ) == new_ptype {
4144 // SAFETY: validity and data buffer still have same length
4245 return Ok ( Some ( unsafe {
@@ -49,9 +52,29 @@ impl CastKernel for Primitive {
4952 } ) ) ;
5053 }
5154
55+ // Same-width integers have identical bit representations due to 2's
56+ // complement. If all values fit in the target range, reinterpret with
57+ // no allocation.
58+ if array. ptype ( ) . is_int ( )
59+ && new_ptype. is_int ( )
60+ && array. ptype ( ) . byte_width ( ) == new_ptype. byte_width ( )
61+ && values_fit_in ( array, new_ptype)
62+ {
63+ // SAFETY: both types are integers with the same size and alignment, and
64+ // min/max confirm all valid values are representable in the target type.
65+ return Ok ( Some ( unsafe {
66+ PrimitiveArray :: new_unchecked_from_handle (
67+ array. buffer_handle ( ) . clone ( ) ,
68+ new_ptype,
69+ new_validity,
70+ )
71+ . into_array ( )
72+ } ) ) ;
73+ }
74+
5275 let mask = array. validity_mask ( ) ?;
5376
54- // Otherwise, we need to cast the values one-by-one
77+ // Otherwise, we need to cast the values one-by-one.
5578 Ok ( Some ( match_each_native_ptype ! ( new_ptype, |T | {
5679 match_each_native_ptype!( array. ptype( ) , |F | {
5780 PrimitiveArray :: new( cast:: <F , T >( array. as_slice( ) , mask) ?, new_validity)
@@ -61,34 +84,35 @@ impl CastKernel for Primitive {
6184 }
6285}
6386
87+ /// Returns `true` if all valid values in `array` are representable as `target_ptype`.
88+ fn values_fit_in ( array : & PrimitiveArray , target_ptype : PType ) -> bool {
89+ let target_dtype = DType :: Primitive ( target_ptype, Nullability :: NonNullable ) ;
90+ compute:: min_max ( & array. clone ( ) . into_array ( ) )
91+ . ok ( )
92+ . flatten ( )
93+ . is_none_or ( |mm| mm. min . cast ( & target_dtype) . is_ok ( ) && mm. max . cast ( & target_dtype) . is_ok ( ) )
94+ }
95+
6496fn cast < F : NativePType , T : NativePType > ( array : & [ F ] , mask : Mask ) -> VortexResult < Buffer < T > > {
97+ let try_cast = |src : F | -> VortexResult < T > {
98+ T :: from ( src) . ok_or_else ( || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , src, T :: PTYPE ) )
99+ } ;
65100 match mask. bit_buffer ( ) {
101+ AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
66102 AllOr :: All => {
67103 let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
68- for item in array {
69- let item = T :: from ( * item) . ok_or_else (
70- || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , item, T :: PTYPE ) ,
71- ) ?;
104+ for & src in array {
72105 // SAFETY: we've pre-allocated the required capacity
73- unsafe { buffer. push_unchecked ( item ) }
106+ unsafe { buffer. push_unchecked ( try_cast ( src ) ? ) }
74107 }
75108 Ok ( buffer. freeze ( ) )
76109 }
77- AllOr :: None => Ok ( Buffer :: zeroed ( array. len ( ) ) ) ,
78110 AllOr :: Some ( b) => {
79- // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
80111 let mut buffer = BufferMut :: with_capacity ( array. len ( ) ) ;
81- for ( item, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
82- if valid {
83- let item = T :: from ( * item) . ok_or_else (
84- || vortex_err ! ( Compute : "Failed to cast {} to {:?}" , item, T :: PTYPE ) ,
85- ) ?;
86- // SAFETY: we've pre-allocated the required capacity
87- unsafe { buffer. push_unchecked ( item) }
88- } else {
89- // SAFETY: we've pre-allocated the required capacity
90- unsafe { buffer. push_unchecked ( T :: default ( ) ) }
91- }
112+ for ( & src, valid) in array. iter ( ) . zip ( b. iter ( ) ) {
113+ let dst = if valid { try_cast ( src) ? } else { T :: default ( ) } ;
114+ // SAFETY: we've pre-allocated the required capacity
115+ unsafe { buffer. push_unchecked ( dst) }
92116 }
93117 Ok ( buffer. freeze ( ) )
94118 }
@@ -223,6 +247,69 @@ mod test {
223247 ) ;
224248 }
225249
/// A `u32 -> i32` cast whose values all fit in the target type must hand back
/// the very same data buffer (verified via pointer identity) rather than
/// allocating a fresh one.
#[test]
fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
    let input = PrimitiveArray::from_iter([0u32, 10, 100]);
    let input_addr = input.as_slice::<u32>().as_ptr() as usize;

    let output = input.into_array().cast(PType::I32.into())?.to_primitive();
    let output_addr = output.as_slice::<i32>().as_ptr() as usize;

    // Zero-copy: source and destination must start at the same address.
    assert_eq!(input_addr, output_addr);
    assert_arrays_eq!(output, PrimitiveArray::from_iter([0i32, 10, 100]));
    Ok(())
}
265+
/// A same-width integer cast must not reinterpret values that do not fit:
/// `u32::MAX` has no `i32` representation, so the cast is expected to fall
/// through to the allocating path and report a compute error.
#[test]
fn cast_same_width_int_out_of_range_errors() {
    let source = buffer![u32::MAX].into_array();
    let outcome = source
        .cast(PType::I32.into())
        .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));
    let failure = outcome.unwrap_err();
    assert!(matches!(failure, VortexError::Compute(..)));
}
277+
/// Casting an all-null `u8` array to nullable `i8` must succeed even though
/// the stored bytes (`0xFF`) are out of range for `i8`: the length and the
/// all-invalid validity are expected to carry over unchanged.
#[test]
fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
    let all_null = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);
    let target = DType::Primitive(PType::I8, Nullability::Nullable);

    let result = all_null.into_array().cast(target)?.to_primitive();

    assert_eq!(result.len(), 2);
    assert!(matches!(result.validity(), Validity::AllInvalid));
    Ok(())
}
291+
/// Out-of-range values hidden behind null slots must not block a same-width
/// integer cast on a nullable array.
#[test]
fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
    // Slot 0 stores u32::MAX, which has no i32 representation, but it is
    // marked invalid, so only the valid values 0 and 42 should matter.
    let validity = Validity::from_iter([false, true, true]);
    let source = PrimitiveArray::new(buffer![u32::MAX, 0u32, 42u32], validity);

    let target = DType::Primitive(PType::I32, Nullability::Nullable);
    let result = source.into_array().cast(target)?.to_primitive();

    assert_arrays_eq!(
        result,
        PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)])
    );
    Ok(())
}
312+
226313 #[ rstest]
227314 #[ case( buffer![ 0u8 , 1 , 2 , 3 , 255 ] . into_array( ) ) ]
228315 #[ case( buffer![ 0u16 , 100 , 1000 , 65535 ] . into_array( ) ) ]
0 commit comments