Skip to content

Commit 507a37b

Browse files
committed
perf: skip allocation for prim cast if possible
Signed-off-by: Alexander Droste <alexander.droste@protonmail.com>
1 parent 876813b commit 507a37b

1 file changed

Lines changed: 107 additions & 20 deletions

File tree

  • vortex-array/src/arrays/primitive/compute

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 107 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@ use crate::ExecutionCtx;
1313
use crate::IntoArray;
1414
use crate::arrays::Primitive;
1515
use crate::arrays::PrimitiveArray;
16+
use crate::compute;
1617
use crate::dtype::DType;
1718
use crate::dtype::NativePType;
19+
use crate::dtype::Nullability;
20+
use crate::dtype::PType;
1821
use crate::match_each_native_ptype;
1922
use crate::scalar_fn::fns::cast::CastKernel;
2023
use crate::vtable::ValidityHelper;
@@ -36,7 +39,7 @@ impl CastKernel for Primitive {
3639
.clone()
3740
.cast_nullability(new_nullability, array.len())?;
3841

39-
// If the bit width is the same, we can short-circuit and simply update the validity
42+
// Same ptype: zero-copy, just update validity.
4043
if array.ptype() == new_ptype {
4144
// SAFETY: validity and data buffer still have same length
4245
return Ok(Some(unsafe {
@@ -49,9 +52,29 @@ impl CastKernel for Primitive {
4952
}));
5053
}
5154

55+
// Same-width integers have identical bit representations due to 2's
56+
// complement. If all values fit in the target range, reinterpret with
57+
// no allocation.
58+
if array.ptype().is_int()
59+
&& new_ptype.is_int()
60+
&& array.ptype().byte_width() == new_ptype.byte_width()
61+
&& values_fit_in(array, new_ptype)
62+
{
63+
// SAFETY: both types are integers with the same size and alignment, and
64+
// min/max confirm all valid values are representable in the target type.
65+
return Ok(Some(unsafe {
66+
PrimitiveArray::new_unchecked_from_handle(
67+
array.buffer_handle().clone(),
68+
new_ptype,
69+
new_validity,
70+
)
71+
.into_array()
72+
}));
73+
}
74+
5275
let mask = array.validity_mask()?;
5376

54-
// Otherwise, we need to cast the values one-by-one
77+
// Otherwise, we need to cast the values one-by-one.
5578
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
5679
match_each_native_ptype!(array.ptype(), |F| {
5780
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
@@ -61,34 +84,35 @@ impl CastKernel for Primitive {
6184
}
6285
}
6386

87+
/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
88+
fn values_fit_in(array: &PrimitiveArray, target_ptype: PType) -> bool {
89+
let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
90+
compute::min_max(&array.clone().into_array())
91+
.ok()
92+
.flatten()
93+
.is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
94+
}
95+
6496
fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
97+
let try_cast = |src: F| -> VortexResult<T> {
98+
T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
99+
};
65100
match mask.bit_buffer() {
101+
AllOr::None => Ok(Buffer::zeroed(array.len())),
66102
AllOr::All => {
67103
let mut buffer = BufferMut::with_capacity(array.len());
68-
for item in array {
69-
let item = T::from(*item).ok_or_else(
70-
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
71-
)?;
104+
for &src in array {
72105
// SAFETY: we've pre-allocated the required capacity
73-
unsafe { buffer.push_unchecked(item) }
106+
unsafe { buffer.push_unchecked(try_cast(src)?) }
74107
}
75108
Ok(buffer.freeze())
76109
}
77-
AllOr::None => Ok(Buffer::zeroed(array.len())),
78110
AllOr::Some(b) => {
79-
// TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
80111
let mut buffer = BufferMut::with_capacity(array.len());
81-
for (item, valid) in array.iter().zip(b.iter()) {
82-
if valid {
83-
let item = T::from(*item).ok_or_else(
84-
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
85-
)?;
86-
// SAFETY: we've pre-allocated the required capacity
87-
unsafe { buffer.push_unchecked(item) }
88-
} else {
89-
// SAFETY: we've pre-allocated the required capacity
90-
unsafe { buffer.push_unchecked(T::default()) }
91-
}
112+
for (&src, valid) in array.iter().zip(b.iter()) {
113+
let dst = if valid { try_cast(src)? } else { T::default() };
114+
// SAFETY: we've pre-allocated the required capacity
115+
unsafe { buffer.push_unchecked(dst) }
92116
}
93117
Ok(buffer.freeze())
94118
}
@@ -223,6 +247,69 @@ mod test {
223247
);
224248
}
225249

250+
/// When every value fits, a same-width integer cast must hand back the
/// original data buffer instead of allocating (checked via pointer identity).
#[test]
fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
    let source = PrimitiveArray::from_iter([0u32, 10, 100]);
    // Record the address before `source` is consumed by the cast.
    let source_addr = source.as_slice::<u32>().as_ptr() as usize;

    let casted = source.into_array().cast(PType::I32.into())?.to_primitive();
    let casted_addr = casted.as_slice::<i32>().as_ptr() as usize;

    // Zero-copy: both arrays must point at the same data.
    assert_eq!(source_addr, casted_addr);
    assert_arrays_eq!(casted, PrimitiveArray::from_iter([0i32, 10, 100]));
    Ok(())
}
265+
266+
/// A same-width integer cast with a value outside the target range must not
/// take the reinterpret shortcut; it should surface a compute error instead.
#[test]
fn cast_same_width_int_out_of_range_errors() {
    let canonicalized = buffer![u32::MAX]
        .into_array()
        .cast(PType::I32.into())
        .and_then(|casted| casted.to_canonical().map(|canonical| canonical.into_array()));

    let failure = canonicalized.unwrap_err();
    assert!(matches!(failure, VortexError::Compute(..)));
}
277+
278+
/// Casting an all-null array between same-width types must succeed and keep
/// the array all-invalid, whatever bytes sit in the underlying buffer.
#[test]
fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
    let nullable_i8 = DType::Primitive(PType::I8, Nullability::Nullable);
    let all_null = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);

    let casted = all_null.into_array().cast(nullable_i8)?.to_primitive();

    assert_eq!(casted.len(), 2);
    assert!(matches!(casted.validity(), Validity::AllInvalid));
    Ok(())
}
291+
292+
/// Out-of-range payloads hidden behind null slots must not block a
/// same-width integer cast of a nullable array.
#[test]
fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
    // Slot 0 stores u32::MAX, which has no i32 representation, but it is
    // marked invalid and therefore must be ignored by the range check.
    let values = buffer![u32::MAX, 0u32, 42u32];
    let validity = Validity::from_iter([false, true, true]);

    let casted = PrimitiveArray::new(values, validity)
        .into_array()
        .cast(DType::Primitive(PType::I32, Nullability::Nullable))?
        .to_primitive();

    let expected = PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)]);
    assert_arrays_eq!(casted, expected);
    Ok(())
}
312+
226313
#[rstest]
227314
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
228315
#[case(buffer![0u16, 100, 1000, 65535].into_array())]

0 commit comments

Comments
 (0)