diff --git a/fuzz/src/array/compare.rs b/fuzz/src/array/compare.rs index d16c5c19107..0f639548a3c 100644 --- a/fuzz/src/array/compare.rs +++ b/fuzz/src/array/compare.rs @@ -134,11 +134,12 @@ pub fn compare_canonical_array( let scalar_vals: Vec = (0..array.len()) .map(|i| array.scalar_at(i).vortex_expect("scalar_at")) .collect(); - BoolArray::from_iter( - scalar_vals - .iter() - .map(|v| scalar_cmp(v, value, operator).as_bool().value()), - ) + BoolArray::from_iter(scalar_vals.iter().map(|v| { + scalar_cmp(v, value, operator) + .vortex_expect("tried to compare different typed scalars") + .as_bool() + .value() + })) .into_array() } d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => { diff --git a/vortex-array/public-api.lock b/vortex-array/public-api.lock index 353f92b4d44..9219bb75cbb 100644 --- a/vortex-array/public-api.lock +++ b/vortex-array/public-api.lock @@ -17096,7 +17096,7 @@ pub fn vortex_array::scalar_fn::fns::binary::compare_nested_arrow_arrays(lhs: &d pub fn vortex_array::scalar_fn::fns::binary::or_kleene(lhs: &vortex_array::ArrayRef, rhs: &vortex_array::ArrayRef) -> vortex_error::VortexResult -pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_array::scalar::Scalar +pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_error::VortexResult pub mod vortex_array::scalar_fn::fns::case_when diff --git a/vortex-array/src/scalar/scalar_impl.rs b/vortex-array/src/scalar/scalar_impl.rs index f8c3da8ec57..edb102fe384 100644 --- a/vortex-array/src/scalar/scalar_impl.rs +++ b/vortex-array/src/scalar/scalar_impl.rs @@ -265,6 +265,10 @@ impl Scalar { /// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars. 
/// Two scalars with the same value but different nullability should be considered equal. +/// +/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the +/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq` +/// implementation simply returns `false`. impl PartialEq for Scalar { fn eq(&self, other: &Self) -> bool { self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value diff --git a/vortex-array/src/scalar_fn/fns/binary/compare.rs b/vortex-array/src/scalar_fn/fns/binary/compare.rs index 81af2c5f868..7c31ee6aab2 100644 --- a/vortex-array/src/scalar_fn/fns/binary/compare.rs +++ b/vortex-array/src/scalar_fn/fns/binary/compare.rs @@ -9,6 +9,7 @@ use arrow_ord::cmp; use arrow_ord::ord::make_comparator; use arrow_schema::SortOptions; use vortex_error::VortexResult; +use vortex_error::vortex_err; use crate::ArrayRef; use crate::Canonical; @@ -131,7 +132,7 @@ pub(crate) fn execute_compare( // Constant-constant fast path if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::(), rhs.as_opt::()) { - let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op); + let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?; return Ok(ConstantArray::new(result, lhs.len()).into_array()); } @@ -150,7 +151,7 @@ fn arrow_compare_arrays( // Arrow's vectorized comparison kernels don't support nested types. // For nested types, fall back to `make_comparator` which does element-wise comparison. 
- let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() { + let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() { let rhs = right.to_array().into_arrow_preferred()?; let lhs = left.to_array().into_arrow(rhs.data_type())?; @@ -176,24 +177,36 @@ fn arrow_compare_arrays( CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?, } }; - from_arrow_array_with_len(&array, left.len(), nullable) + + from_arrow_array_with_len(&arrow_array, left.len(), nullable) } -pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar { +pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult { if lhs.is_null() | rhs.is_null() { - Scalar::null(DType::Bool(Nullability::Nullable)) - } else { - let b = match operator { - CompareOperator::Eq => lhs == rhs, - CompareOperator::NotEq => lhs != rhs, - CompareOperator::Gt => lhs > rhs, - CompareOperator::Gte => lhs >= rhs, - CompareOperator::Lt => lhs < rhs, - CompareOperator::Lte => lhs <= rhs, - }; - - Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability()) + return Ok(Scalar::null(DType::Bool(Nullability::Nullable))); } + + let nullability = lhs.dtype().nullability() | rhs.dtype().nullability(); + + // We use `partial_cmp` to ensure we do not lose a type mismatch error. + let ordering = lhs.partial_cmp(rhs).ok_or_else(|| { + vortex_err!( + "Cannot compare scalars with incompatible types: {} and {}", + lhs.dtype(), + rhs.dtype() + ) + })?; + + let b = match operator { + CompareOperator::Eq => ordering.is_eq(), + CompareOperator::NotEq => ordering.is_ne(), + CompareOperator::Gt => ordering.is_gt(), + CompareOperator::Gte => ordering.is_ge(), + CompareOperator::Lt => ordering.is_lt(), + CompareOperator::Lte => ordering.is_le(), + }; + + Ok(Scalar::bool(b, nullability)) } /// Compare two Arrow arrays element-wise using [`make_comparator`]. 
@@ -251,8 +264,13 @@ mod tests { use crate::dtype::FieldNames; use crate::dtype::Nullability; use crate::dtype::PType; + use crate::extension::datetime::TimeUnit; + use crate::extension::datetime::Timestamp; + use crate::extension::datetime::TimestampOptions; use crate::scalar::Scalar; use crate::scalar_fn::fns::binary::compare::ConstantArray; + use crate::scalar_fn::fns::binary::scalar_cmp; + use crate::scalar_fn::fns::operators::CompareOperator; use crate::scalar_fn::fns::operators::Operator; use crate::test_harness::to_int_indices; use crate::validity::Validity; @@ -479,6 +497,35 @@ mod tests { assert_arrays_eq!(result, expected); } + /// Regression test: `scalar_cmp` must error when comparing scalars with incompatible + /// extension types (e.g., timestamps with different time units) rather than silently + /// returning a wrong result. + #[test] + fn scalar_cmp_incompatible_extension_types_errors() { + let ms_scalar = Scalar::extension::( + TimestampOptions { + unit: TimeUnit::Milliseconds, + tz: None, + }, + Scalar::from(1704067200000i64), + ); + let s_scalar = Scalar::extension::( + TimestampOptions { + unit: TimeUnit::Seconds, + tz: None, + }, + Scalar::from(1704067200i64), + ); + + // All operators, including `Eq` and `NotEq`, must error on incompatible types. 
+ assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err()); + assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err()); + assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err()); + assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err()); + assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err()); + assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err()); + } + #[test] fn test_empty_list() { let list = ListViewArray::new( diff --git a/vortex-file/src/tests.rs b/vortex-file/src/tests.rs index c13b41e9311..e477146e5cd 100644 --- a/vortex-file/src/tests.rs +++ b/vortex-file/src/tests.rs @@ -1597,7 +1597,7 @@ async fn test_writer_with_statistics() -> VortexResult<()> { } #[tokio::test] -async fn main_test() -> Result<(), Box> { +async fn timestamp_unit_mismatch() -> Result<(), Box> { // Write file with MILLISECONDS timestamps let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000]) .into_array(); @@ -1621,6 +1621,62 @@ async fn main_test() -> Result<(), Box> { )), ); + let mut stream = SESSION + .open_options() + .open_buffer(buf)? + .scan()? + .with_filter(filter_expr) + .into_array_stream()?; + + let result = stream.try_next().await; + + assert!(result.is_err()); + + Ok(()) +} + +/// Regression test: filtering a milliseconds timestamp column with a seconds scalar should +/// always error, regardless of how the internal children of `DateTimePartsArray` are encoded. +/// +/// This test forces `ConstantArray` encoding for the seconds/subseconds children by using a +/// compressor with Dict included (which triggers distinct-value computation, letting +/// `ConstantScheme` win for `[0, 0, 0]`). The scanner should still detect the time unit +/// mismatch and error, not silently return wrong results. 
+#[tokio::test] +async fn timestamp_unit_mismatch_errors_with_constant_children() +-> Result<(), Box> { + // Build a compressor where ConstantScheme wins for [0, 0, 0] by including Dict + // (which enables distinct-value computation). + let compressor = vortex_btrblocks::BtrBlocksCompressor::default(); + + // Write file with MILLISECONDS timestamps using this compressor. + let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000]) + .into_array(); + let temporal = TemporalArray::new_timestamp(ts_array, TimeUnit::Milliseconds, None); + + let strategy = crate::strategy::WriteStrategyBuilder::default() + .with_compressor(compressor) + .build(); + + let mut buf = ByteBufferMut::empty(); + SESSION + .write_options() + .with_strategy(strategy) + .write(&mut buf, temporal.into_array().to_array_stream()) + .await?; + + // Read with SECONDS filter scalar — should error due to time unit mismatch. + let filter_expr = gt( + root(), + lit(Scalar::extension::( + TimestampOptions { + unit: TimeUnit::Seconds, + tz: None, + }, + Scalar::from(1704153600i64), + )), + ); + let stream = SESSION .open_options() .open_buffer(buf)? @@ -1630,7 +1686,13 @@ async fn main_test() -> Result<(), Box> { let results = stream.try_collect::>().await; - assert!(results.is_err()); + assert!( + results.is_err(), + "Expected error from timestamp unit mismatch (ms vs s), but got {} results. \ + This indicates the scanner silently applied the filter incorrectly when \ + DateTimePartsArray children use ConstantArray encoding.", + results.unwrap().len() + ); Ok(()) }