Skip to content

Commit a968853

Browse files
committed
f
Signed-off-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 99fd5d9 commit a968853

2 files changed

Lines changed: 362 additions & 281 deletions

File tree

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use std::mem::align_of;
5+
use std::mem::size_of;
6+
47
use num_traits::AsPrimitive;
58
use num_traits::NumCast;
9+
use vortex_buffer::BitBuffer;
610
use vortex_buffer::Buffer;
711
use vortex_buffer::BufferMut;
12+
use vortex_buffer::lane_ops_indexed::ReinterpretSink;
813
use vortex_buffer::lane_ops_indexed::map_no_validity;
914
use vortex_buffer::lane_ops_indexed::try_map_no_validity;
1015
use vortex_buffer::lane_ops_indexed::try_map_with_mask;
16+
use vortex_buffer::lane_ops_indexed::try_map_with_mask_in_place;
1117
use vortex_error::VortexResult;
1218
use vortex_error::vortex_bail;
1319
use vortex_error::vortex_err;
@@ -147,9 +153,21 @@ where
147153
// Skip the fallible kernel when the conversion is infallible by type alone (widening) or
148154
// when cached min/max prove every value fits in `T`.
149155
let target_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable);
150-
if casts_losslessly_to(F::PTYPE, T::PTYPE)
151-
|| cached_values_fit_in(array, &target_dtype) == Some(true)
156+
let infallible = casts_losslessly_to(F::PTYPE, T::PTYPE)
157+
|| cached_values_fit_in(array, &target_dtype) == Some(true);
158+
159+
// Same-bit-width in-place fast path: when F and T have the same byte width and the
160+
// buffer is uniquely owned, mutate in place and transmute the wrapper. Saves the
161+
// output allocation. Falls through to the out-of-place path when the buffer is shared
162+
// (the common case under the current borrow-based kernel API).
163+
let same_bit_width = F::PTYPE.byte_width() == T::PTYPE.byte_width();
164+
if same_bit_width
165+
&& let Ok(buffer_mut) = array.into_owned().try_into_buffer_mut::<F>()
152166
{
167+
return cast_buffer_in_place::<F, T>(buffer_mut, array, new_validity, ctx, infallible);
168+
}
169+
170+
if infallible {
153171
let mut buffer = BufferMut::<T>::with_capacity(values.len());
154172
// Truncating `as`-cast — safe here because stats prove every valid value fits.
155173
// Null lanes' underlying garbage gets truncated/wrapped (harmless: the result
@@ -204,6 +222,72 @@ where
204222
Ok(PrimitiveArray::new(buffer, new_validity).into_array())
205223
}
206224

225+
/// In-place cast of an owned `BufferMut<F>` to `BufferMut<T>` when `F` and `T` have the
226+
/// same byte width. Each slot is read as `F`, converted, and written back as `T`-bits
227+
/// using `BufferMut`'s transmute family. Avoids allocating a second output buffer.
228+
///
229+
/// The caller has already verified `F::PTYPE.byte_width() == T::PTYPE.byte_width()`.
230+
fn cast_buffer_in_place<F, T>(
231+
buffer: BufferMut<F>,
232+
array: ArrayView<'_, Primitive>,
233+
new_validity: Validity,
234+
ctx: &mut ExecutionCtx,
235+
infallible: bool,
236+
) -> VortexResult<ArrayRef>
237+
where
238+
F: NativePType + AsPrimitive<T>,
239+
T: NativePType,
240+
{
241+
debug_assert_eq!(size_of::<F>(), size_of::<T>());
242+
debug_assert_eq!(align_of::<F>(), align_of::<T>());
243+
244+
if infallible {
245+
// `map_each_in_place` does the BufferMut<F> → BufferMut<T> transmute internally
246+
// (same size + alignment for primitives of equal byte width) and walks each slot
247+
// with the closure.
248+
let result: BufferMut<T> = buffer.map_each_in_place(|v: F| v.as_());
249+
return Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array());
250+
}
251+
252+
let mask = array.validity()?.execute_mask(array.len(), ctx)?;
253+
let overflow = || {
254+
vortex_err!(
255+
Compute: "Cannot cast {} to {} — value exceeds target range",
256+
F::PTYPE, T::PTYPE,
257+
)
258+
};
259+
260+
// All-null short-circuit: zero out the buffer and skip the conversion loop entirely.
261+
if matches!(mask, Mask::AllFalse(_)) {
262+
// SAFETY: same size + alignment by NativePType same-byte-width invariant.
263+
let mut t_buf: BufferMut<T> = unsafe { buffer.transmute::<T>() };
264+
t_buf.as_mut_slice().fill(T::zero());
265+
return Ok(PrimitiveArray::new(t_buf.freeze(), new_validity).into_array());
266+
}
267+
268+
let bit_buffer = match &mask {
269+
Mask::AllTrue(n) => BitBuffer::new_set(*n),
270+
Mask::AllFalse(_) => unreachable!("handled above"),
271+
Mask::Values(m) => m.bit_buffer().clone(),
272+
};
273+
274+
let mut buffer = buffer;
275+
try_map_with_mask_in_place(
276+
ReinterpretSink::<F, T>::new(buffer.as_mut_slice()),
277+
&bit_buffer,
278+
|f_val: F, valid| -> Option<T> {
279+
<T as NumCast>::from(f_val).or_else(|| (!valid).then(T::zero))
280+
},
281+
)
282+
.map_err(|_| overflow())?;
283+
284+
// SAFETY: same size + alignment for NativePType same-byte-width pairs. Every F-slot
285+
// now holds a valid T-bit pattern because `ReinterpretSink::set_unchecked` wrote a
286+
// real `T` at every visited lane.
287+
let result: BufferMut<T> = unsafe { buffer.transmute::<T>() };
288+
Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array())
289+
}
290+
207291
fn reinterpret(
208292
array: ArrayView<'_, Primitive>,
209293
new_ptype: PType,

0 commit comments

Comments
 (0)