Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 115 additions & 21 deletions vortex-array/src/arrays/primitive/compute/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use vortex_buffer::Buffer;
use vortex_buffer::BufferMut;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_mask::AllOr;
use vortex_mask::Mask;
Expand All @@ -13,8 +14,11 @@ use crate::ExecutionCtx;
use crate::IntoArray;
use crate::arrays::Primitive;
use crate::arrays::PrimitiveArray;
use crate::compute;
use crate::dtype::DType;
use crate::dtype::NativePType;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::match_each_native_ptype;
use crate::scalar_fn::fns::cast::CastKernel;
use crate::vtable::ValidityHelper;
Expand All @@ -36,7 +40,7 @@ impl CastKernel for Primitive {
.clone()
.cast_nullability(new_nullability, array.len())?;

// If the bit width is the same, we can short-circuit and simply update the validity
// Same ptype: zero-copy, just update validity.
if array.ptype() == new_ptype {
// SAFETY: validity and data buffer still have same length
return Ok(Some(unsafe {
Expand All @@ -49,9 +53,35 @@ impl CastKernel for Primitive {
}));
}

// Same-width integers have identical bit representations due to 2's
// complement. If all values fit in the target range, reinterpret with
// no allocation.
if array.ptype().is_int()
&& new_ptype.is_int()
&& array.ptype().byte_width() == new_ptype.byte_width()
{
if !values_fit_in(array, new_ptype) {
vortex_bail!(
Compute: "Cannot cast {} to {} — values exceed target range",
array.ptype(),
new_ptype,
);
}
// SAFETY: both types are integers with the same size and alignment, and
// min/max confirm all valid values are representable in the target type.
return Ok(Some(unsafe {
PrimitiveArray::new_unchecked_from_handle(
array.buffer_handle().clone(),
new_ptype,
new_validity,
)
.into_array()
}));
}

let mask = array.validity_mask()?;

// Otherwise, we need to cast the values one-by-one
// Otherwise, we need to cast the values one-by-one.
Ok(Some(match_each_native_ptype!(new_ptype, |T| {
match_each_native_ptype!(array.ptype(), |F| {
PrimitiveArray::new(cast::<F, T>(array.as_slice(), mask)?, new_validity)
Expand All @@ -61,34 +91,35 @@ impl CastKernel for Primitive {
}
}

/// Returns `true` if all valid values in `array` are representable as `target_ptype`.
fn values_fit_in(array: &PrimitiveArray, target_ptype: PType) -> bool {
let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable);
compute::min_max(&array.clone().into_array())
Comment thread
0ax1 marked this conversation as resolved.
.ok()
.flatten()
.is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok())
}

fn cast<F: NativePType, T: NativePType>(array: &[F], mask: Mask) -> VortexResult<Buffer<T>> {
let try_cast = |src: F| -> VortexResult<T> {
T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE))
};
match mask.bit_buffer() {
AllOr::None => Ok(Buffer::zeroed(array.len())),
AllOr::All => {
let mut buffer = BufferMut::with_capacity(array.len());
for item in array {
let item = T::from(*item).ok_or_else(
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
)?;
for &src in array {
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(item) }
unsafe { buffer.push_unchecked(try_cast(src)?) }
}
Ok(buffer.freeze())
}
AllOr::None => Ok(Buffer::zeroed(array.len())),
AllOr::Some(b) => {
// TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values
let mut buffer = BufferMut::with_capacity(array.len());
for (item, valid) in array.iter().zip(b.iter()) {
if valid {
let item = T::from(*item).ok_or_else(
|| vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE),
)?;
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(item) }
} else {
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(T::default()) }
}
for (&src, valid) in array.iter().zip(b.iter()) {
let dst = if valid { try_cast(src)? } else { T::default() };
// SAFETY: we've pre-allocated the required capacity
unsafe { buffer.push_unchecked(dst) }
}
Ok(buffer.freeze())
}
Expand Down Expand Up @@ -183,7 +214,7 @@ mod test {
.and_then(|a| a.to_canonical().map(|c| c.into_array()))
.unwrap_err();
assert!(matches!(error, VortexError::Compute(..)));
assert!(error.to_string().contains("Failed to cast -1 to U32"));
assert!(error.to_string().contains("values exceed target range"));
}

#[test]
Expand Down Expand Up @@ -223,6 +254,69 @@ mod test {
);
}

/// Casting between integers of equal width, with every value in range, must
/// hand back the very same data buffer (checked via pointer identity) instead
/// of allocating a new one.
#[test]
fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> {
    let unsigned = PrimitiveArray::from_iter([0u32, 10, 100]);
    // Record the address before the array is consumed by the cast.
    let addr_before = unsigned.as_slice::<u32>().as_ptr() as usize;

    let signed = unsigned
        .into_array()
        .cast(PType::I32.into())?
        .to_primitive();
    let addr_after = signed.as_slice::<i32>().as_ptr() as usize;

    // Same address means the buffer was reused, not copied.
    assert_eq!(addr_before, addr_after);
    assert_arrays_eq!(signed, PrimitiveArray::from_iter([0i32, 10, 100]));
    Ok(())
}

/// Same-width integer cast where a valid value does not fit in the target
/// type: the same-width path must reject the cast with a `Compute` error
/// ("values exceed target range") rather than reinterpreting the buffer.
#[test]
fn cast_same_width_int_out_of_range_errors() {
    let arr = buffer![u32::MAX].into_array();
    let err = arr
        .cast(PType::I32.into())
        .and_then(|a| a.to_canonical().map(|c| c.into_array()))
        .unwrap_err();
    assert!(matches!(err, VortexError::Compute(..)));
    // Pin the message so a regression to a different failure path is noticed.
    assert!(err.to_string().contains("values exceed target range"));
}

/// An array whose entries are all invalid must cast between same-width types
/// successfully, regardless of the bytes stored in its buffer.
#[test]
fn cast_same_width_all_null() -> vortex_error::VortexResult<()> {
    let nullable_i8 = DType::Primitive(PType::I8, Nullability::Nullable);
    // 0xFF does not fit in i8, but every slot is null so it must not matter.
    let all_null = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid);

    let result = all_null.into_array().cast(nullable_i8)?.to_primitive();

    assert_eq!(result.len(), 2);
    assert!(matches!(result.validity(), Validity::AllInvalid));
    Ok(())
}

/// Nullable same-width integer cast: a value that is out of range for the
/// target type but sits in a null slot must not cause the cast to fail.
#[test]
fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> {
    // Slot 0 stores u32::MAX (not representable as i32) but is marked invalid,
    // so only the in-range values 0 and 42 are subject to the range check.
    let validity = Validity::from_iter([false, true, true]);
    let input = PrimitiveArray::new(buffer![u32::MAX, 0u32, 42u32], validity);
    let nullable_i32 = DType::Primitive(PType::I32, Nullability::Nullable);

    let output = input.into_array().cast(nullable_i32)?.to_primitive();

    let expected = PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)]);
    assert_arrays_eq!(output, expected);
    Ok(())
}

#[rstest]
#[case(buffer![0u8, 1, 2, 3, 255].into_array())]
#[case(buffer![0u16, 100, 1000, 65535].into_array())]
Expand Down
Loading