Skip to content

Commit b519e78

Browse files
committed
perf[vortex-array]: use from_trusted_len_iter in primitive casts
Some of our scan profiles show 10% of scan cpu time is spent in integer widening casts (nullable dictionary codes). This commit simplifies primitive casts by hoisting a lot of hot loop branching logic. Specifically, this commit relies on values_fit_in to verify representability so that we can avoid a potential validity and error check in the hot loop. Additionally from_trusted_len_iter lets the destination BufferMut optimize the actual cast instead of using push_unchecked for each element. Signed-off-by: Alfonso Subiotto Marques <alfonso.subiotto@polarsignals.com>
1 parent fbfa072 commit b519e78

2 files changed

Lines changed: 38 additions & 39 deletions

File tree

vortex-array/benches/cast_primitive.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use vortex_array::builtins::ArrayBuiltins;
99
use vortex_array::dtype::DType;
1010
use vortex_array::dtype::Nullability;
1111
use vortex_array::dtype::PType;
12+
use vortex_array::expr::stats::Stat;
1213

1314
fn main() {
1415
divan::main();
@@ -28,6 +29,8 @@ fn cast_u16_to_u32(bencher: Bencher) {
2829
}
2930
}))
3031
.into_array();
32+
// Pre-compute min/max so values_fit_in is a cache hit during the benchmark.
33+
arr.statistics().compute_all(&[Stat::Min, Stat::Max]).ok();
3134
bencher.with_inputs(|| arr.clone()).bench_refs(|a| {
3235
#[expect(clippy::unwrap_used)]
3336
a.cast(DType::Primitive(PType::U32, Nullability::Nullable))

vortex-array/src/arrays/primitive/compute/cast.rs

Lines changed: 35 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@ use vortex_buffer::Buffer;
55
use vortex_buffer::BufferMut;
66
use vortex_error::VortexResult;
77
use vortex_error::vortex_bail;
8-
use vortex_error::vortex_err;
9-
use vortex_mask::AllOr;
10-
use vortex_mask::Mask;
118

129
use crate::ArrayRef;
1310
use crate::ExecutionCtx;
@@ -53,20 +50,21 @@ impl CastKernel for Primitive {
5350
}));
5451
}
5552

53+
if !values_fit_in(array, new_ptype, ctx) {
54+
vortex_bail!(
55+
Compute: "Cannot cast {} to {} — values exceed target range",
56+
array.ptype(),
57+
new_ptype,
58+
);
59+
}
60+
5661
// Same-width integers have identical bit representations due to 2's
5762
// complement. If all values fit in the target range, reinterpret with
5863
// no allocation.
5964
if array.ptype().is_int()
6065
&& new_ptype.is_int()
6166
&& array.ptype().byte_width() == new_ptype.byte_width()
6267
{
63-
if !values_fit_in(array, new_ptype, ctx) {
64-
vortex_bail!(
65-
Compute: "Cannot cast {} to {} — values exceed target range",
66-
array.ptype(),
67-
new_ptype,
68-
);
69-
}
7068
// SAFETY: both types are integers with the same size and alignment, and
7169
// min/max confirm all valid values are representable in the target type.
7270
return Ok(Some(unsafe {
@@ -79,13 +77,10 @@ impl CastKernel for Primitive {
7977
}));
8078
}
8179

82-
let mask = array.validity_mask();
83-
84-
// Otherwise, we need to cast the values one-by-one.
80+
// Otherwise, cast the values element-wise.
8581
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
8682
match_each_native_ptype!(array.ptype(), |F| {
87-
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
88-
.into_array()
83+
PrimitiveArray::new(cast::<F, T>(array.as_slice()), new_validity).into_array()
8984
})
9085
})))
9186
}
@@ -104,30 +99,12 @@ fn values_fit_in(
10499
.is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
105100
}
106101

107-
fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
108-
let try_cast = |src: F| -> VortexResult<T> {
109-
T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
110-
};
111-
match mask.bit_buffer() {
112-
AllOr::None => Ok(Buffer::zeroed(array.len())),
113-
AllOr::All => {
114-
let mut buffer = BufferMut::with_capacity(array.len());
115-
for &src in array {
116-
// SAFETY: we've pre-allocated the required capacity
117-
unsafe { buffer.push_unchecked(try_cast(src)?) }
118-
}
119-
Ok(buffer.freeze())
120-
}
121-
AllOr::Some(b) => {
122-
let mut buffer = BufferMut::with_capacity(array.len());
123-
for (&src, valid) in array.iter().zip(b.iter()) {
124-
let dst = if valid { try_cast(src)? } else { T::default() };
125-
// SAFETY: we've pre-allocated the required capacity
126-
unsafe { buffer.push_unchecked(dst) }
127-
}
128-
Ok(buffer.freeze())
129-
}
130-
}
102+
/// Caller must ensure all valid values are representable via `values_fit_in`;
103+
/// `unwrap_or_default` only fires at invalid positions where the physical
104+
/// value is out of range.
105+
fn cast<F: NativePType, T: NativePType>(array: &[F]) -> Buffer<T> {
106+
BufferMut::from_trusted_len_iter(array.iter().map(|&src| T::from(src).unwrap_or_default()))
107+
.freeze()
131108
}
132109

133110
#[cfg(test)]
@@ -319,6 +296,23 @@ mod test {
319296
Ok(())
320297
}
321298

299+
#[test]
300+
fn cast_u32_to_u8_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
301+
let arr = PrimitiveArray::new(
302+
buffer![1000u32, 10u32, 42u32],
303+
Validity::from_iter([false, true, true]),
304+
);
305+
let casted = arr
306+
.into_array()
307+
.cast(DType::Primitive(PType::U8, Nullability::Nullable))?
308+
.to_primitive();
309+
assert_arrays_eq!(
310+
casted,
311+
PrimitiveArray::from_option_iter([None, Some(10u8), Some(42)])
312+
);
313+
Ok(())
314+
}
315+
322316
#[rstest]
323317
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
324318
#[case(buffer![0u16, 100, 1000, 65535].into_array())]
@@ -329,7 +323,9 @@ mod test {
329323
#[case(buffer![-1000000i32, -1, 0, 1, 1000000].into_array())]
330324
#[case(buffer![-1000000000i64, -1, 0, 1, 1000000000].into_array())]
331325
#[case(buffer![0.0f32, 1.5, -2.5, 100.0, 1e6].into_array())]
326+
#[case(buffer![f32::NAN, f32::INFINITY, f32::NEG_INFINITY, 0.0f32].into_array())]
332327
#[case(buffer![0.0f64, 1.5, -2.5, 100.0, 1e12].into_array())]
328+
#[case(buffer![f64::NAN, f64::INFINITY, f64::NEG_INFINITY, 0.0f64].into_array())]
333329
#[case(PrimitiveArray::from_option_iter([Some(1u8), None, Some(255), Some(0), None]).into_array())]
334330
#[case(PrimitiveArray::from_option_iter([Some(1i32), None, Some(-100), Some(0), None]).into_array())]
335331
#[case(buffer![42u32].into_array())]

0 commit comments

Comments
 (0)