Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions fuzz/src/array/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,12 @@ pub fn compare_canonical_array(
let scalar_vals: Vec<Scalar> = (0..array.len())
.map(|i| array.scalar_at(i).vortex_expect("scalar_at"))
.collect();
BoolArray::from_iter(
scalar_vals
.iter()
.map(|v| scalar_cmp(v, value, operator).as_bool().value()),
)
BoolArray::from_iter(scalar_vals.iter().map(|v| {
scalar_cmp(v, value, operator)
.vortex_expect("tried to compare different typed scalars")
.as_bool()
.value()
}))
.into_array()
}
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/public-api.lock
Original file line number Diff line number Diff line change
Expand Up @@ -17096,7 +17096,7 @@ pub fn vortex_array::scalar_fn::fns::binary::compare_nested_arrow_arrays(lhs: &d

pub fn vortex_array::scalar_fn::fns::binary::or_kleene(lhs: &vortex_array::ArrayRef, rhs: &vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>

pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_array::scalar::Scalar
pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>

pub mod vortex_array::scalar_fn::fns::case_when

Expand Down
4 changes: 4 additions & 0 deletions vortex-array/src/scalar/scalar_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,10 @@ impl Scalar {

/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
/// Two scalars with the same value but different nullability should be considered equal.
///
/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
/// implementation simply returns `false`.
impl PartialEq for Scalar {
fn eq(&self, other: &Self) -> bool {
self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value
Expand Down
79 changes: 63 additions & 16 deletions vortex-array/src/scalar_fn/fns/binary/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use arrow_ord::cmp;
use arrow_ord::ord::make_comparator;
use arrow_schema::SortOptions;
use vortex_error::VortexResult;
use vortex_error::vortex_err;

use crate::ArrayRef;
use crate::Canonical;
Expand Down Expand Up @@ -131,7 +132,7 @@ pub(crate) fn execute_compare(
// Constant-constant fast path
if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
{
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
return Ok(ConstantArray::new(result, lhs.len()).into_array());
}

Expand All @@ -150,7 +151,7 @@ fn arrow_compare_arrays(

// Arrow's vectorized comparison kernels don't support nested types.
// For nested types, fall back to `make_comparator` which does element-wise comparison.
let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
let rhs = right.to_array().into_arrow_preferred()?;
let lhs = left.to_array().into_arrow(rhs.data_type())?;

Expand All @@ -176,24 +177,36 @@ fn arrow_compare_arrays(
CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
}
};
from_arrow_array_with_len(&array, left.len(), nullable)

from_arrow_array_with_len(&arrow_array, left.len(), nullable)
}

pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
Comment thread
joseph-isaacs marked this conversation as resolved.
if lhs.is_null() | rhs.is_null() {
Scalar::null(DType::Bool(Nullability::Nullable))
} else {
let b = match operator {
CompareOperator::Eq => lhs == rhs,
CompareOperator::NotEq => lhs != rhs,
CompareOperator::Gt => lhs > rhs,
CompareOperator::Gte => lhs >= rhs,
CompareOperator::Lt => lhs < rhs,
CompareOperator::Lte => lhs <= rhs,
};

Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
}

let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();

// We use `partial_cmp` to ensure we do not lose a type mismatch error.
let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
vortex_err!(
"Cannot compare scalars with incompatible types: {} and {}",
lhs.dtype(),
rhs.dtype()
)
})?;

let b = match operator {
CompareOperator::Eq => ordering.is_eq(),
CompareOperator::NotEq => ordering.is_ne(),
CompareOperator::Gt => ordering.is_gt(),
CompareOperator::Gte => ordering.is_ge(),
CompareOperator::Lt => ordering.is_lt(),
CompareOperator::Lte => ordering.is_le(),
};

Ok(Scalar::bool(b, nullability))
}

/// Compare two Arrow arrays element-wise using [`make_comparator`].
Expand Down Expand Up @@ -251,8 +264,13 @@ mod tests {
use crate::dtype::FieldNames;
use crate::dtype::Nullability;
use crate::dtype::PType;
use crate::extension::datetime::TimeUnit;
use crate::extension::datetime::Timestamp;
use crate::extension::datetime::TimestampOptions;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::binary::compare::ConstantArray;
use crate::scalar_fn::fns::binary::scalar_cmp;
use crate::scalar_fn::fns::operators::CompareOperator;
use crate::scalar_fn::fns::operators::Operator;
use crate::test_harness::to_int_indices;
use crate::validity::Validity;
Expand Down Expand Up @@ -479,6 +497,35 @@ mod tests {
assert_arrays_eq!(result, expected);
}

/// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
/// extension types (e.g., timestamps with different time units) rather than silently
/// returning a wrong result.
#[test]
fn scalar_cmp_incompatible_extension_types_errors() {
let ms_scalar = Scalar::extension::<Timestamp>(
TimestampOptions {
unit: TimeUnit::Milliseconds,
tz: None,
},
Scalar::from(1704067200000i64),
);
let s_scalar = Scalar::extension::<Timestamp>(
TimestampOptions {
unit: TimeUnit::Seconds,
tz: None,
},
Scalar::from(1704067200i64),
);

// Ordering comparisons must error on incompatible types.
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
}

#[test]
fn test_empty_list() {
let list = ListViewArray::new(
Expand Down
66 changes: 64 additions & 2 deletions vortex-file/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1597,7 +1597,7 @@ async fn test_writer_with_statistics() -> VortexResult<()> {
}

#[tokio::test]
async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
async fn timestamp_unit_mismatch() -> Result<(), Box<dyn std::error::Error>> {
// Write file with MILLISECONDS timestamps
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
.into_array();
Expand All @@ -1621,6 +1621,62 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
)),
);

let mut stream = SESSION
.open_options()
.open_buffer(buf)?
.scan()?
.with_filter(filter_expr)
.into_array_stream()?;

let result = stream.try_next().await;

assert!(result.is_err());

Ok(())
}

/// Regression test: filtering a milliseconds timestamp column with a seconds scalar should
/// always error, regardless of how the internal children of `DateTimePartsArray` are encoded.
///
/// This test forces `ConstantArray` encoding for the seconds/subseconds children by using a
/// compressor with Dict excluded (which triggers distinct-value computation, letting
/// `ConstantScheme` win for `[0, 0, 0]`). The scanner should still detect the time unit
/// mismatch and error, not silently return wrong results.
#[tokio::test]
async fn timestamp_unit_mismatch_errors_with_constant_children()
-> Result<(), Box<dyn std::error::Error>> {
// Build a compressor where ConstantScheme wins for [0, 0, 0] by including Dict
// (which enables distinct-value computation).
let compressor = vortex_btrblocks::BtrBlocksCompressor::default();

// Write file with MILLISECONDS timestamps using this compressor.
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
.into_array();
let temporal = TemporalArray::new_timestamp(ts_array, TimeUnit::Milliseconds, None);

let strategy = crate::strategy::WriteStrategyBuilder::default()
.with_compressor(compressor)
.build();

let mut buf = ByteBufferMut::empty();
SESSION
.write_options()
.with_strategy(strategy)
.write(&mut buf, temporal.into_array().to_array_stream())
.await?;

// Read with SECONDS filter scalar — should error due to time unit mismatch.
let filter_expr = gt(
root(),
lit(Scalar::extension::<Timestamp>(
TimestampOptions {
unit: TimeUnit::Seconds,
tz: None,
},
Scalar::from(1704153600i64),
)),
);

let stream = SESSION
.open_options()
.open_buffer(buf)?
Expand All @@ -1630,7 +1686,13 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {

let results = stream.try_collect::<Vec<_>>().await;

assert!(results.is_err());
assert!(
results.is_err(),
"Expected error from timestamp unit mismatch (ms vs s), but got {} results. \
This indicates the scanner silently applied the filter incorrectly when \
DateTimePartsArray children use ConstantArray encoding.",
results.unwrap().len()
);

Ok(())
}
Loading