diff --git a/vortex-array/src/arrays/primitive/compute/cast.rs b/vortex-array/src/arrays/primitive/compute/cast.rs index 932d120d50f..8f8e1683e45 100644 --- a/vortex-array/src/arrays/primitive/compute/cast.rs +++ b/vortex-array/src/arrays/primitive/compute/cast.rs @@ -4,6 +4,7 @@ use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexResult; +use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_mask::AllOr; use vortex_mask::Mask; @@ -13,8 +14,11 @@ use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::Primitive; use crate::arrays::PrimitiveArray; +use crate::compute; use crate::dtype::DType; use crate::dtype::NativePType; +use crate::dtype::Nullability; +use crate::dtype::PType; use crate::match_each_native_ptype; use crate::scalar_fn::fns::cast::CastKernel; use crate::vtable::ValidityHelper; @@ -36,7 +40,7 @@ impl CastKernel for Primitive { .clone() .cast_nullability(new_nullability, array.len())?; - // If the bit width is the same, we can short-circuit and simply update the validity + // Same ptype: zero-copy, just update validity. if array.ptype() == new_ptype { // SAFETY: validity and data buffer still have same length return Ok(Some(unsafe { @@ -49,9 +53,35 @@ impl CastKernel for Primitive { })); } + // Same-width integers have identical bit representations due to 2's + // complement. If all values fit in the target range, reinterpret with + // no allocation. + if array.ptype().is_int() + && new_ptype.is_int() + && array.ptype().byte_width() == new_ptype.byte_width() + { + if !values_fit_in(array, new_ptype) { + vortex_bail!( + Compute: "Cannot cast {} to {} — values exceed target range", + array.ptype(), + new_ptype, + ); + } + // SAFETY: both types are integers with the same size and alignment, and + // min/max confirm all valid values are representable in the target type. + return Ok(Some(unsafe { + PrimitiveArray::new_unchecked_from_handle( + array.buffer_handle().clone(), + new_ptype, + new_validity, + ) + .into_array() + })); + } + let mask = array.validity_mask()?; - // Otherwise, we need to cast the values one-by-one + // Otherwise, we need to cast the values one-by-one. Ok(Some(match_each_native_ptype!(new_ptype, |T| { match_each_native_ptype!(array.ptype(), |F| { PrimitiveArray::new(cast::(array.as_slice(), mask)?, new_validity) @@ -61,34 +91,35 @@ impl CastKernel for Primitive { } } +/// Returns `true` if all valid values in `array` are representable as `target_ptype`. +fn values_fit_in(array: &PrimitiveArray, target_ptype: PType) -> bool { + let target_dtype = DType::Primitive(target_ptype, Nullability::NonNullable); + compute::min_max(&array.clone().into_array()) + .ok() + .flatten() + .is_none_or(|mm| mm.min.cast(&target_dtype).is_ok() && mm.max.cast(&target_dtype).is_ok()) +} + fn cast(array: &[F], mask: Mask) -> VortexResult> { + let try_cast = |src: F| -> VortexResult { + T::from(src).ok_or_else(|| vortex_err!(Compute: "Failed to cast {} to {:?}", src, T::PTYPE)) + }; match mask.bit_buffer() { + AllOr::None => Ok(Buffer::zeroed(array.len())), AllOr::All => { let mut buffer = BufferMut::with_capacity(array.len()); - for item in array { - let item = T::from(*item).ok_or_else( - || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE), - )?; + for &src in array { // SAFETY: we've pre-allocated the required capacity - unsafe { buffer.push_unchecked(item) } + unsafe { buffer.push_unchecked(try_cast(src)?) } } Ok(buffer.freeze()) } - AllOr::None => Ok(Buffer::zeroed(array.len())), AllOr::Some(b) => { - // TODO(robert): Depending on density of the buffer might be better to prefill Buffer and only write valid values let mut buffer = BufferMut::with_capacity(array.len()); - for (item, valid) in array.iter().zip(b.iter()) { - if valid { - let item = T::from(*item).ok_or_else( - || vortex_err!(Compute: "Failed to cast {} to {:?}", item, T::PTYPE), - )?; - // SAFETY: we've pre-allocated the required capacity - unsafe { buffer.push_unchecked(item) } - } else { - // SAFETY: we've pre-allocated the required capacity - unsafe { buffer.push_unchecked(T::default()) } - } + for (&src, valid) in array.iter().zip(b.iter()) { + let dst = if valid { try_cast(src)? } else { T::default() }; + // SAFETY: we've pre-allocated the required capacity + unsafe { buffer.push_unchecked(dst) } } Ok(buffer.freeze()) } @@ -183,7 +214,7 @@ mod test { .and_then(|a| a.to_canonical().map(|c| c.into_array())) .unwrap_err(); assert!(matches!(error, VortexError::Compute(..))); - assert!(error.to_string().contains("Failed to cast -1 to U32")); + assert!(error.to_string().contains("values exceed target range")); } #[test] @@ -223,6 +254,69 @@ mod test { ); } + /// Same-width integer cast where all values fit: should reinterpret the + /// buffer without allocation (pointer identity). + #[test] + fn cast_same_width_int_reinterprets_buffer() -> vortex_error::VortexResult<()> { + let src = PrimitiveArray::from_iter([0u32, 10, 100]); + let src_ptr = src.as_slice::().as_ptr(); + + let dst = src.into_array().cast(PType::I32.into())?.to_primitive(); + let dst_ptr = dst.as_slice::().as_ptr(); + + // Zero-copy: the data pointer should be identical. + assert_eq!(src_ptr as usize, dst_ptr as usize); + assert_arrays_eq!(dst, PrimitiveArray::from_iter([0i32, 10, 100])); + Ok(()) + } + + /// Same-width integer cast where values don't fit: should fall through + /// to the allocating path and produce an error. + #[test] + fn cast_same_width_int_out_of_range_errors() { + let arr = buffer![u32::MAX].into_array(); + let err = arr + .cast(PType::I32.into()) + .and_then(|a| a.to_canonical().map(|c| c.into_array())) + .unwrap_err(); + assert!(matches!(err, VortexError::Compute(..))); + } + + /// All-null array cast between same-width types should succeed without + /// touching the buffer contents. + #[test] + fn cast_same_width_all_null() -> vortex_error::VortexResult<()> { + let arr = PrimitiveArray::new(buffer![0xFFu8, 0xFF], Validity::AllInvalid); + let casted = arr + .into_array() + .cast(DType::Primitive(PType::I8, Nullability::Nullable))? + .to_primitive(); + assert_eq!(casted.len(), 2); + assert!(matches!(casted.validity(), Validity::AllInvalid)); + Ok(()) + } + + /// Same-width integer cast with nullable values: out-of-range nulls should + /// not prevent the cast from succeeding. + #[test] + fn cast_same_width_int_nullable_with_out_of_range_nulls() -> vortex_error::VortexResult<()> { + // The null position holds u32::MAX which doesn't fit in i32, but it's + // masked as invalid so the cast should still succeed via reinterpret. + let arr = PrimitiveArray::new( + buffer![u32::MAX, 0u32, 42u32], + Validity::from_iter([false, true, true]), + ); + let casted = arr + .into_array() + .cast(DType::Primitive(PType::I32, Nullability::Nullable))? + .to_primitive(); + assert_arrays_eq!( + casted, + PrimitiveArray::from_option_iter([None, Some(0i32), Some(42)]) + ); + Ok(()) + } + #[rstest] #[case(buffer![0u8, 1, 2, 3, 255].into_array())] #[case(buffer![0u16, 100, 1000, 65535].into_array())]