11// SPDX-License-Identifier: Apache-2.0
22// SPDX-FileCopyrightText: Copyright the Vortex contributors
33
4- use std:: mem:: align_of;
5- use std:: mem:: size_of;
6-
74use num_traits:: AsPrimitive ;
85use num_traits:: NumCast ;
9- use vortex_buffer:: BitBuffer ;
106use vortex_buffer:: Buffer ;
117use vortex_buffer:: BufferMut ;
128use vortex_buffer:: lane_ops_indexed:: ReinterpretSink ;
139use vortex_buffer:: lane_ops_indexed:: map_no_validity;
10+ use vortex_buffer:: lane_ops_indexed:: map_no_validity_in_place;
1411use vortex_buffer:: lane_ops_indexed:: try_map_no_validity;
12+ use vortex_buffer:: lane_ops_indexed:: try_map_no_validity_in_place;
1513use vortex_buffer:: lane_ops_indexed:: try_map_with_mask;
1614use vortex_buffer:: lane_ops_indexed:: try_map_with_mask_in_place;
1715use vortex_error:: VortexResult ;
@@ -132,7 +130,6 @@ where
132130 F : NativePType + AsPrimitive < T > ,
133131 T : NativePType ,
134132{
135- let values = array. as_slice :: < F > ( ) ;
136133 let overflow = || {
137134 vortex_err ! (
138135 Compute : "Cannot cast {} to {} — value exceeds target range" ,
@@ -156,138 +153,112 @@ where
156153 let infallible = casts_losslessly_to ( F :: PTYPE , T :: PTYPE )
157154 || cached_values_fit_in ( array, & target_dtype) == Some ( true ) ;
158155
159- // Same-bit-width in-place fast path: when F and T have the same byte width and the
160- // buffer is uniquely owned, mutate in place and transmute the wrapper. Saves the
161- // output allocation. Falls through to the out-of-place path when the buffer is shared
162- // (the common case under the current borrow-based kernel API).
156+ let len = array. len ( ) ;
157+
158+ // Same-bit-width in-place fast path: when F and T have the same byte width, try to take
159+ // unique ownership of the buffer. If successful, each kernel call site below mutates in
160+ // place via `ReinterpretSink` and transmutes the wrapper at the end, saving the output
161+ // allocation. Falls back to the out-of-place path (borrowed slice + fresh buffer) when
162+ // the buffer is shared — the common case under the current borrow-based kernel API.
163163 let same_bit_width = F :: PTYPE . byte_width ( ) == T :: PTYPE . byte_width ( ) ;
164- if same_bit_width
165- && let Ok ( buffer_mut) = array. into_owned ( ) . try_into_buffer_mut :: < F > ( )
166- {
167- return cast_buffer_in_place :: < F , T > ( buffer_mut, array, new_validity, ctx, infallible) ;
168- }
164+ let owned: Option < BufferMut < F > > = if same_bit_width {
165+ array. into_owned ( ) . try_into_buffer_mut :: < F > ( ) . ok ( )
166+ } else {
167+ None
168+ } ;
169+ let values: & [ F ] = array. as_slice :: < F > ( ) ;
169170
170171 if infallible {
171- let mut buffer = BufferMut :: < T > :: with_capacity ( values. len ( ) ) ;
172- // Truncating `as`-cast — safe here because stats prove every valid value fits.
173- // Null lanes' underlying garbage gets truncated/wrapped (harmless: the result
174- // validity bitmap masks them downstream).
175- map_no_validity (
176- values,
177- & mut buffer. spare_capacity_mut ( ) [ ..values. len ( ) ] ,
178- |v| v. as_ ( ) ,
179- ) ;
180- // SAFETY: map_no_validity initializes every lane.
181- unsafe { buffer. set_len ( values. len ( ) ) } ;
182- return Ok ( PrimitiveArray :: new ( buffer. freeze ( ) , new_validity) . into_array ( ) ) ;
172+ // Truncating `as`-cast — safe here because static type analysis or cached stats prove
173+ // every valid value fits. Null lanes' underlying garbage gets truncated/wrapped
174+ // (harmless: the result validity bitmap masks them downstream).
175+ return match owned {
176+ Some ( mut buf) => {
177+ map_no_validity_in_place (
178+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) ) ,
179+ |v : F | v. as_ ( ) ,
180+ ) ;
181+ // SAFETY: same size + alignment for NativePType same-byte-width pairs;
182+ // every F-slot was overwritten with a real `T` bit pattern.
183+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
184+ Ok ( PrimitiveArray :: new ( result. freeze ( ) , new_validity) . into_array ( ) )
185+ }
186+ None => {
187+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
188+ map_no_validity ( values, & mut buffer. spare_capacity_mut ( ) [ ..len] , |v| v. as_ ( ) ) ;
189+ // SAFETY: map_no_validity initializes every lane.
190+ unsafe { buffer. set_len ( len) } ;
191+ Ok ( PrimitiveArray :: new ( buffer. freeze ( ) , new_validity) . into_array ( ) )
192+ }
193+ } ;
183194 }
184195
185- let mask = array. validity ( ) ?. execute_mask ( array . len ( ) , ctx) ?;
196+ let mask = array. validity ( ) ?. execute_mask ( len, ctx) ?;
186197
187- let buffer: Buffer < T > = match & mask {
188- Mask :: AllTrue ( _) => {
189- let mut buffer = BufferMut :: < T > :: with_capacity ( values. len ( ) ) ;
190- try_map_no_validity (
191- values,
192- & mut buffer. spare_capacity_mut ( ) [ ..values. len ( ) ] ,
193- |v| <T as NumCast >:: from ( v) ,
198+ let buffer: Buffer < T > = match ( & mask, owned) {
199+ ( Mask :: AllTrue ( _) , Some ( mut buf) ) => {
200+ try_map_no_validity_in_place (
201+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) ) ,
202+ |v : F | <T as NumCast >:: from ( v) ,
194203 )
195204 . map_err ( |_| overflow ( ) ) ?;
205+ // SAFETY: same size + alignment for NativePType same-byte-width pairs;
206+ // every F-slot now holds a `T` bit pattern written by `ReinterpretSink`.
207+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
208+ result. freeze ( )
209+ }
210+ ( Mask :: AllTrue ( _) , None ) => {
211+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
212+ try_map_no_validity ( values, & mut buffer. spare_capacity_mut ( ) [ ..len] , |v| {
213+ <T as NumCast >:: from ( v)
214+ } )
215+ . map_err ( |_| overflow ( ) ) ?;
196216 // SAFETY: try_map_no_validity returned Ok, so it initialized every lane.
197- unsafe { buffer. set_len ( values . len ( ) ) } ;
217+ unsafe { buffer. set_len ( len) } ;
198218 buffer. freeze ( )
199219 }
200- Mask :: AllFalse ( _) => BufferMut :: < T > :: zeroed ( values. len ( ) ) . freeze ( ) ,
201- Mask :: Values ( m) => {
202- let mut buffer = BufferMut :: < T > :: with_capacity ( values. len ( ) ) ;
220+ ( Mask :: AllFalse ( _) , Some ( buf) ) => {
221+ // SAFETY: same size + alignment by NativePType same-byte-width invariant.
222+ let mut t_buf: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
223+ t_buf. as_mut_slice ( ) . fill ( T :: zero ( ) ) ;
224+ t_buf. freeze ( )
225+ }
226+ ( Mask :: AllFalse ( _) , None ) => BufferMut :: < T > :: zeroed ( len) . freeze ( ) ,
227+ ( Mask :: Values ( m) , Some ( mut buf) ) => {
228+ try_map_with_mask_in_place (
229+ ReinterpretSink :: < F , T > :: new ( buf. as_mut_slice ( ) ) ,
230+ m. bit_buffer ( ) ,
231+ |v : F , valid| <T as NumCast >:: from ( v) . or_else ( || ( !valid) . then ( T :: zero) ) ,
232+ )
233+ . map_err ( |_| overflow ( ) ) ?;
234+ // SAFETY: same size + alignment for NativePType same-byte-width pairs;
235+ // every F-slot now holds a `T` bit pattern written by `ReinterpretSink`.
236+ let result: BufferMut < T > = unsafe { buf. transmute :: < T > ( ) } ;
237+ result. freeze ( )
238+ }
239+ ( Mask :: Values ( m) , None ) => {
240+ let mut buffer = BufferMut :: < T > :: with_capacity ( len) ;
203241 try_map_with_mask (
204242 values,
205243 m. bit_buffer ( ) ,
206- & mut buffer. spare_capacity_mut ( ) [ ..values. len ( ) ] ,
207- // Lazy validity: only consult `valid` on the failure branch. For
208- // widening / statically-infallible casts, `NumCast::from` is always
209- // `Some` so the `or_else` is provably dead — LLVM DCEs the validity
210- // path entirely, giving the same codegen as the maskless kernel.
211- // For narrowing, `valid` is only read at lanes that actually
212- // overflowed (a cold check on top of the cast).
244+ & mut buffer. spare_capacity_mut ( ) [ ..len] ,
245+ // Lazy validity: only consult `valid` on the failure branch. For widening /
246+ // statically-infallible casts, `NumCast::from` is always `Some` so the
247+ // `or_else` is provably dead — LLVM DCEs the validity path entirely, giving
248+ // the same codegen as the maskless kernel. For narrowing, `valid` is only
249+ // read at lanes that actually overflowed (a cold check on top of the cast).
213250 |v, valid| <T as NumCast >:: from ( v) . or_else ( || ( !valid) . then ( T :: zero) ) ,
214251 )
215252 . map_err ( |_| overflow ( ) ) ?;
216253 // SAFETY: try_map_with_mask returned Ok, so it initialized every lane.
217- unsafe { buffer. set_len ( values . len ( ) ) } ;
254+ unsafe { buffer. set_len ( len) } ;
218255 buffer. freeze ( )
219256 }
220257 } ;
221258
222259 Ok ( PrimitiveArray :: new ( buffer, new_validity) . into_array ( ) )
223260}
224261
225- /// In-place cast of an owned `BufferMut<F>` to `BufferMut<T>` when `F` and `T` have the
226- /// same byte width. Each slot is read as `F`, converted, and written back as `T`-bits
227- /// using `BufferMut`'s transmute family. Avoids allocating a second output buffer.
228- ///
229- /// The caller has already verified `F::PTYPE.byte_width() == T::PTYPE.byte_width()`.
230- fn cast_buffer_in_place < F , T > (
231- buffer : BufferMut < F > ,
232- array : ArrayView < ' _ , Primitive > ,
233- new_validity : Validity ,
234- ctx : & mut ExecutionCtx ,
235- infallible : bool ,
236- ) -> VortexResult < ArrayRef >
237- where
238- F : NativePType + AsPrimitive < T > ,
239- T : NativePType ,
240- {
241- debug_assert_eq ! ( size_of:: <F >( ) , size_of:: <T >( ) ) ;
242- debug_assert_eq ! ( align_of:: <F >( ) , align_of:: <T >( ) ) ;
243-
244- if infallible {
245- // `map_each_in_place` does the BufferMut<F> → BufferMut<T> transmute internally
246- // (same size + alignment for primitives of equal byte width) and walks each slot
247- // with the closure.
248- let result: BufferMut < T > = buffer. map_each_in_place ( |v : F | v. as_ ( ) ) ;
249- return Ok ( PrimitiveArray :: new ( result. freeze ( ) , new_validity) . into_array ( ) ) ;
250- }
251-
252- let mask = array. validity ( ) ?. execute_mask ( array. len ( ) , ctx) ?;
253- let overflow = || {
254- vortex_err ! (
255- Compute : "Cannot cast {} to {} — value exceeds target range" ,
256- F :: PTYPE , T :: PTYPE ,
257- )
258- } ;
259-
260- // All-null short-circuit: zero out the buffer and skip the conversion loop entirely.
261- if matches ! ( mask, Mask :: AllFalse ( _) ) {
262- // SAFETY: same size + alignment by NativePType same-byte-width invariant.
263- let mut t_buf: BufferMut < T > = unsafe { buffer. transmute :: < T > ( ) } ;
264- t_buf. as_mut_slice ( ) . fill ( T :: zero ( ) ) ;
265- return Ok ( PrimitiveArray :: new ( t_buf. freeze ( ) , new_validity) . into_array ( ) ) ;
266- }
267-
268- let bit_buffer = match & mask {
269- Mask :: AllTrue ( n) => BitBuffer :: new_set ( * n) ,
270- Mask :: AllFalse ( _) => unreachable ! ( "handled above" ) ,
271- Mask :: Values ( m) => m. bit_buffer ( ) . clone ( ) ,
272- } ;
273-
274- let mut buffer = buffer;
275- try_map_with_mask_in_place (
276- ReinterpretSink :: < F , T > :: new ( buffer. as_mut_slice ( ) ) ,
277- & bit_buffer,
278- |f_val : F , valid| -> Option < T > {
279- <T as NumCast >:: from ( f_val) . or_else ( || ( !valid) . then ( T :: zero) )
280- } ,
281- )
282- . map_err ( |_| overflow ( ) ) ?;
283-
284- // SAFETY: same size + alignment for NativePType same-byte-width pairs. Every F-slot
285- // now holds a valid T-bit pattern because `ReinterpretSink::set_unchecked` wrote a
286- // real `T` at every visited lane.
287- let result: BufferMut < T > = unsafe { buffer. transmute :: < T > ( ) } ;
288- Ok ( PrimitiveArray :: new ( result. freeze ( ) , new_validity) . into_array ( ) )
289- }
290-
291262fn reinterpret (
292263 array : ArrayView < ' _ , Primitive > ,
293264 new_ptype : PType ,
0 commit comments