|
1 | 1 | // SPDX-License-Identifier: Apache-2.0 |
2 | 2 | // SPDX-FileCopyrightText: Copyright the Vortex contributors |
3 | 3 |
|
| 4 | +use std::mem::align_of; |
| 5 | +use std::mem::size_of; |
| 6 | + |
4 | 7 | use num_traits::AsPrimitive; |
5 | 8 | use num_traits::NumCast; |
| 9 | +use vortex_buffer::BitBuffer; |
6 | 10 | use vortex_buffer::Buffer; |
7 | 11 | use vortex_buffer::BufferMut; |
| 12 | +use vortex_buffer::lane_ops_indexed::ReinterpretSink; |
8 | 13 | use vortex_buffer::lane_ops_indexed::map_no_validity; |
9 | 14 | use vortex_buffer::lane_ops_indexed::try_map_no_validity; |
10 | 15 | use vortex_buffer::lane_ops_indexed::try_map_with_mask; |
| 16 | +use vortex_buffer::lane_ops_indexed::try_map_with_mask_in_place; |
11 | 17 | use vortex_error::VortexResult; |
12 | 18 | use vortex_error::vortex_bail; |
13 | 19 | use vortex_error::vortex_err; |
@@ -147,9 +153,21 @@ where |
147 | 153 | // Skip the fallible kernel when the conversion is infallible by type alone (widening) or |
148 | 154 | // when cached min/max prove every value fits in `T`. |
149 | 155 | let target_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); |
150 | | - if casts_losslessly_to(F::PTYPE, T::PTYPE) |
151 | | - || cached_values_fit_in(array, &target_dtype) == Some(true) |
| 156 | + let infallible = casts_losslessly_to(F::PTYPE, T::PTYPE) |
| 157 | + || cached_values_fit_in(array, &target_dtype) == Some(true); |
| 158 | + |
| 159 | + // Same-bit-width in-place fast path: when F and T have the same byte width and the |
| 160 | + // buffer is uniquely owned, mutate in place and transmute the wrapper. Saves the |
| 161 | + // output allocation. Falls through to the out-of-place path when the buffer is shared |
| 162 | + // (the common case under the current borrow-based kernel API). |
| 163 | + let same_bit_width = F::PTYPE.byte_width() == T::PTYPE.byte_width(); |
| 164 | + if same_bit_width |
| 165 | + && let Ok(buffer_mut) = array.into_owned().try_into_buffer_mut::<F>() |
152 | 166 | { |
| 167 | + return cast_buffer_in_place::<F, T>(buffer_mut, array, new_validity, ctx, infallible); |
| 168 | + } |
| 169 | + |
| 170 | + if infallible { |
153 | 171 | let mut buffer = BufferMut::<T>::with_capacity(values.len()); |
154 | 172 | // Truncating `as`-cast — safe here because stats prove every valid value fits. |
155 | 173 | // Null lanes' underlying garbage gets truncated/wrapped (harmless: the result |
@@ -204,6 +222,72 @@ where |
204 | 222 | Ok(PrimitiveArray::new(buffer, new_validity).into_array()) |
205 | 223 | } |
206 | 224 |
|
| 225 | +/// In-place cast of an owned `BufferMut<F>` to `BufferMut<T>` when `F` and `T` have the |
| 226 | +/// same byte width. Each slot is read as `F`, converted, and written back as `T`-bits |
| 227 | +/// using `BufferMut`'s transmute family. Avoids allocating a second output buffer. |
| 228 | +/// |
| 229 | +/// The caller has already verified `F::PTYPE.byte_width() == T::PTYPE.byte_width()`. |
| 230 | +fn cast_buffer_in_place<F, T>( |
| 231 | + buffer: BufferMut<F>, |
| 232 | + array: ArrayView<'_, Primitive>, |
| 233 | + new_validity: Validity, |
| 234 | + ctx: &mut ExecutionCtx, |
| 235 | + infallible: bool, |
| 236 | +) -> VortexResult<ArrayRef> |
| 237 | +where |
| 238 | + F: NativePType + AsPrimitive<T>, |
| 239 | + T: NativePType, |
| 240 | +{ |
| 241 | + debug_assert_eq!(size_of::<F>(), size_of::<T>()); |
| 242 | + debug_assert_eq!(align_of::<F>(), align_of::<T>()); |
| 243 | + |
| 244 | + if infallible { |
| 245 | + // `map_each_in_place` does the BufferMut<F> → BufferMut<T> transmute internally |
| 246 | + // (same size + alignment for primitives of equal byte width) and walks each slot |
| 247 | + // with the closure. |
| 248 | + let result: BufferMut<T> = buffer.map_each_in_place(|v: F| v.as_()); |
| 249 | + return Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array()); |
| 250 | + } |
| 251 | + |
| 252 | + let mask = array.validity()?.execute_mask(array.len(), ctx)?; |
| 253 | + let overflow = || { |
| 254 | + vortex_err!( |
| 255 | + Compute: "Cannot cast {} to {} — value exceeds target range", |
| 256 | + F::PTYPE, T::PTYPE, |
| 257 | + ) |
| 258 | + }; |
| 259 | + |
| 260 | + // All-null short-circuit: zero out the buffer and skip the conversion loop entirely. |
| 261 | + if matches!(mask, Mask::AllFalse(_)) { |
| 262 | + // SAFETY: same size + alignment by NativePType same-byte-width invariant. |
| 263 | + let mut t_buf: BufferMut<T> = unsafe { buffer.transmute::<T>() }; |
| 264 | + t_buf.as_mut_slice().fill(T::zero()); |
| 265 | + return Ok(PrimitiveArray::new(t_buf.freeze(), new_validity).into_array()); |
| 266 | + } |
| 267 | + |
| 268 | + let bit_buffer = match &mask { |
| 269 | + Mask::AllTrue(n) => BitBuffer::new_set(*n), |
| 270 | + Mask::AllFalse(_) => unreachable!("handled above"), |
| 271 | + Mask::Values(m) => m.bit_buffer().clone(), |
| 272 | + }; |
| 273 | + |
| 274 | + let mut buffer = buffer; |
| 275 | + try_map_with_mask_in_place( |
| 276 | + ReinterpretSink::<F, T>::new(buffer.as_mut_slice()), |
| 277 | + &bit_buffer, |
| 278 | + |f_val: F, valid| -> Option<T> { |
| 279 | + <T as NumCast>::from(f_val).or_else(|| (!valid).then(T::zero)) |
| 280 | + }, |
| 281 | + ) |
| 282 | + .map_err(|_| overflow())?; |
| 283 | + |
| 284 | + // SAFETY: same size + alignment for NativePType same-byte-width pairs. Every F-slot |
| 285 | + // now holds a valid T-bit pattern because `ReinterpretSink::set_unchecked` wrote a |
| 286 | + // real `T` at every visited lane. |
| 287 | + let result: BufferMut<T> = unsafe { buffer.transmute::<T>() }; |
| 288 | + Ok(PrimitiveArray::new(result.freeze(), new_validity).into_array()) |
| 289 | +} |
| 290 | + |
207 | 291 | fn reinterpret( |
208 | 292 | array: ArrayView<'_, Primitive>, |
209 | 293 | new_ptype: PType, |
|
0 commit comments