@@ -5,6 +5,9 @@ use num_traits::AsPrimitive;
55use num_traits:: NumCast ;
66use vortex_buffer:: Buffer ;
77use vortex_buffer:: BufferMut ;
8+ use vortex_compute:: lane_kernels:: IndexedSinkExt ;
9+ use vortex_compute:: lane_kernels:: IndexedSourceExt ;
10+ use vortex_compute:: lane_kernels:: ReinterpretSink ;
811use vortex_error:: VortexResult ;
912use vortex_error:: vortex_bail;
1013use vortex_error:: vortex_err;
@@ -102,9 +105,7 @@ impl CastKernel for Primitive {
102105 }
103106}
104107
105- /// Cast values from `F` to `T`. For infallible casts this is a pure pass; for fallible casts
106- /// each valid value goes through a checked `NumCast::from` and the kernel bails if any of them
107- /// overflow `T`. Invalid positions use the wrapping `as` cast since their values are masked out.
108+ /// Cast Primitive values from `F` to `T`.
108109fn cast_values < F , T > (
109110 array : ArrayView < ' _ , Primitive > ,
110111 new_validity : Validity ,
@@ -114,53 +115,99 @@ where
114115 F : NativePType + AsPrimitive < T > ,
115116 T : NativePType ,
116117{
117- let values = array. as_slice :: < F > ( ) ;
118-
119- // Fast path: statically infallible, or cached min/max prove every valid value fits in `T`.
120- // The cached check never triggers a stats computation — if the bounds aren't already known
121- // we fall through to the per-lane loop below.
122- if values_always_fit ( F :: PTYPE , T :: PTYPE ) || values_fit_in ( array, T :: PTYPE , ctx, false ) {
123- return Ok ( PrimitiveArray :: new ( cast :: < F , T > ( values) , new_validity) . into_array ( ) ) ;
124- }
125-
126- // TODO(joe): if the values source and target have the same bit-width we can
127- // mutate in place.
128-
129- // Fallible: invalid lanes are pre-multiplied to zero so the checked cast always succeeds for
130- // them; valid lanes go through `NumCast::from` and the whole cast bails on the first overflow.
131- let mask = array. validity ( ) ?. execute_mask ( array. len ( ) , ctx) ?;
132118 let overflow = || {
133119 vortex_err ! (
134120 Compute : "Cannot cast {} to {} — value exceeds target range" ,
135121 F :: PTYPE , T :: PTYPE ,
136122 )
137123 } ;
138- let buffer: Buffer < T > = match & mask {
139- Mask :: AllTrue ( _) => BufferMut :: try_from_trusted_len_iter (
124+
125+ // Returns `true` if every value of `from` is representable in `to` without loss.
126+ fn casts_losslessly_to ( from : PType , to : PType ) -> bool {
127+ from. least_supertype ( to) == Some ( to)
128+ }
129+
130+ // Skip the fallible kernel when type widening or (cached) min/max prove every value fits.
131+ let target_dtype = DType :: Primitive ( T :: PTYPE , Nullability :: NonNullable ) ;
132+ let infallible = casts_losslessly_to ( F :: PTYPE , T :: PTYPE )
133+ || cached_values_fit_in ( array, & target_dtype) . unwrap_or ( false ) ;
134+
135+ let len = array. len ( ) ;
136+
137+ // If F and T have the same byte width, try to take unique ownership of the buffer.
138+ let same_bit_width = F :: PTYPE . byte_width ( ) == T :: PTYPE . byte_width ( ) ;
139+ let owned: Option < BufferMut < F > > = same_bit_width
140+ . then ( || array. into_owned ( ) . try_into_buffer_mut :: < F > ( ) . ok ( ) )
141+ . flatten ( ) ;
142+ let values: & [ F ] = array. as_slice :: < F > ( ) ;
143+
144+ if infallible {
145+ return match owned {
146+ Some ( mut buf) => {
147+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) ) . map_into_in_place ( |v : F | v. as_ ( ) ) ;
148+ // SAFETY: same size + alignment for NativePType
149+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
150+ Ok ( PrimitiveArray :: new ( result. freeze ( ) , new_validity) . into_array ( ) )
151+ }
152+ None => {
153+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
154+ values. map_into ( & mut buffer. spare_capacity_mut ( ) [ ..len] , |v| v. as_ ( ) ) ;
155+ // SAFETY: map_into initializes every lane.
156+ unsafe { buffer. set_len ( len) } ;
157+ Ok ( PrimitiveArray :: new ( buffer. freeze ( ) , new_validity) . into_array ( ) )
158+ }
159+ } ;
160+ }
161+
162+ let mask = array. validity ( ) ?. execute_mask ( len, ctx) ?;
163+
164+ let buffer: Buffer < T > = match ( & mask, owned) {
165+ ( Mask :: AllTrue ( _) , Some ( mut buf) ) => {
166+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) )
167+ . try_map_in_place ( |v : F | <T as NumCast >:: from ( v) )
168+ . map_err ( |_| overflow ( ) ) ?;
169+ // SAFETY: same size + alignment for NativePType
170+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
171+ result. freeze ( )
172+ }
173+ ( Mask :: AllTrue ( _) , None ) => {
174+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
175+ values
176+ . try_map_into ( & mut buffer. spare_capacity_mut ( ) [ ..len] , |v| {
177+ <T as NumCast >:: from ( v)
178+ } )
179+ . map_err ( |_| overflow ( ) ) ?;
180+ // SAFETY: initialized every lane.
181+ unsafe { buffer. set_len ( len) } ;
182+ buffer. freeze ( )
183+ }
184+ ( Mask :: AllFalse ( _) , _) => BufferMut :: < T > :: zeroed ( len) . freeze ( ) ,
185+ ( Mask :: Values ( m) , Some ( mut buf) ) => {
186+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) )
187+ . try_map_masked_in_place ( m. bit_buffer ( ) , |v : F | <T as NumCast >:: from ( v) )
188+ . map_err ( |_| overflow ( ) ) ?;
189+ // SAFETY: same size + alignment for NativePType
190+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
191+ result. freeze ( )
192+ }
193+ ( Mask :: Values ( m) , None ) => {
194+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
140195 values
141- . iter ( )
142- . map ( |& v| <T as NumCast >:: from ( v) . ok_or_else ( overflow) ) ,
143- ) ?
144- . freeze ( ) ,
145- Mask :: AllFalse ( _) => BufferMut :: < T > :: zeroed ( values. len ( ) ) . freeze ( ) ,
146- Mask :: Values ( m) => BufferMut :: try_from_trusted_len_iter (
147- values. iter ( ) . zip ( m. bit_buffer ( ) . iter ( ) ) . map ( |( & v, valid) | {
148- let factor = if valid { F :: one ( ) } else { F :: zero ( ) } ;
149- <T as NumCast >:: from ( v * factor) . ok_or_else ( overflow)
150- } ) ,
151- ) ?
152- . freeze ( ) ,
196+ . try_map_masked_into (
197+ m. bit_buffer ( ) ,
198+ & mut buffer. spare_capacity_mut ( ) [ ..len] ,
199+ |v| <T as NumCast >:: from ( v) ,
200+ )
201+ . map_err ( |_| overflow ( ) ) ?;
202+ // SAFETY: initialized every lane.
203+ unsafe { buffer. set_len ( len) } ;
204+ buffer. freeze ( )
205+ }
153206 } ;
154207
155208 Ok ( PrimitiveArray :: new ( buffer, new_validity) . into_array ( ) )
156209}
157210
158- /// Out-of-range values at invalid positions are truncated/wrapped by `as`, which is fine because
159- /// they are masked out by validity.
160- fn cast < F : NativePType + AsPrimitive < T > , T : NativePType > ( array : & [ F ] ) -> Buffer < T > {
161- BufferMut :: from_trusted_len_iter ( array. iter ( ) . map ( |& src| src. as_ ( ) ) ) . freeze ( )
162- }
163-
164211fn reinterpret (
165212 array : ArrayView < ' _ , Primitive > ,
166213 new_ptype : PType ,
@@ -178,23 +225,6 @@ fn reinterpret(
178225 . into_array ( )
179226}
180227
181- /// Returns `true` if every value of `src` is guaranteed representable in `target` without
182- /// overflow. Precision may be lost (e.g. large integers cast to `f32`), but the cast can never
183- /// produce an out-of-range result.
184- fn values_always_fit ( src : PType , target : PType ) -> bool {
185- if src == target {
186- return true ;
187- }
188- if src. is_int ( ) && target. is_int ( ) {
189- return target. byte_width ( ) > src. byte_width ( )
190- && ( src. is_unsigned_int ( ) || target. is_signed_int ( ) ) ;
191- }
192- if src. is_float ( ) && target. is_float ( ) {
193- return target. byte_width ( ) > src. byte_width ( ) ;
194- }
195- src. is_int ( ) && matches ! ( target, PType :: F32 | PType :: F64 )
196- }
197-
198228/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
199229///
200230/// Cached min/max statistics are consulted first. If either bound is missing, the function either
0 commit comments