Skip to content

Commit c4b9949

Browse files
Fix scalar partial ordering comparison (#6999)
## Summary

When we compare `Scalar`s, we have to be careful not to use `<` or `>` unless we know that the scalars have the same type. If they have different types, the comparison will not panic or raise an error; it will just return `false` for `>`.

This changes the `scalar_cmp` function inside the comparison execution to use the `partial_cmp` method directly, mapping a `None` result to an error that is then raised.

## Testing

Adds 2 regression tests: the first is the case I stumbled on, and the second is a more targeted one.

## Unresolved Questions

- Why is this function public?
- Are there any other places where we do scalar comparison with the built-in `>` operators? If so, we may need to fix them.

---------

Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
Co-authored-by: Joe Isaacs <joe.isaacs@live.co.uk>
1 parent 6d6faf3 commit c4b9949

5 files changed

Lines changed: 138 additions & 24 deletions

File tree

fuzz/src/array/compare.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ pub fn compare_canonical_array(
134134
let scalar_vals: Vec<Scalar> = (0..array.len())
135135
.map(|i| array.scalar_at(i).vortex_expect("scalar_at"))
136136
.collect();
137-
BoolArray::from_iter(
138-
scalar_vals
139-
.iter()
140-
.map(|v| scalar_cmp(v, value, operator).as_bool().value()),
141-
)
137+
BoolArray::from_iter(scalar_vals.iter().map(|v| {
138+
scalar_cmp(v, value, operator)
139+
.vortex_expect("tried to compare different typed scalars")
140+
.as_bool()
141+
.value()
142+
}))
142143
.into_array()
143144
}
144145
d @ (DType::Null | DType::Extension(_) | DType::Variant(_)) => {

vortex-array/public-api.lock

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16984,7 +16984,7 @@ pub fn vortex_array::scalar_fn::fns::binary::compare_nested_arrow_arrays(lhs: &d
1698416984

1698516985
pub fn vortex_array::scalar_fn::fns::binary::or_kleene(lhs: &vortex_array::ArrayRef, rhs: &vortex_array::ArrayRef) -> vortex_error::VortexResult<vortex_array::ArrayRef>
1698616986

16987-
pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_array::scalar::Scalar
16987+
pub fn vortex_array::scalar_fn::fns::binary::scalar_cmp(lhs: &vortex_array::scalar::Scalar, rhs: &vortex_array::scalar::Scalar, operator: vortex_array::scalar_fn::fns::operators::CompareOperator) -> vortex_error::VortexResult<vortex_array::scalar::Scalar>
1698816988

1698916989
pub mod vortex_array::scalar_fn::fns::case_when
1699016990

vortex-array/src/scalar/scalar_impl.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,10 @@ impl Scalar {
265265

266266
/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
267267
/// Two scalars with the same value but different nullability should be considered equal.
268+
///
269+
/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
270+
/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
271+
/// implementation simply returns `false`.
268272
impl PartialEq for Scalar {
269273
fn eq(&self, other: &Self) -> bool {
270274
self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value

vortex-array/src/scalar_fn/fns/binary/compare.rs

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use arrow_ord::cmp;
99
use arrow_ord::ord::make_comparator;
1010
use arrow_schema::SortOptions;
1111
use vortex_error::VortexResult;
12+
use vortex_error::vortex_err;
1213

1314
use crate::ArrayRef;
1415
use crate::Canonical;
@@ -131,7 +132,7 @@ pub(crate) fn execute_compare(
131132
// Constant-constant fast path
132133
if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
133134
{
134-
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
135+
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
135136
return Ok(ConstantArray::new(result, lhs.len()).into_array());
136137
}
137138

@@ -150,7 +151,7 @@ fn arrow_compare_arrays(
150151

151152
// Arrow's vectorized comparison kernels don't support nested types.
152153
// For nested types, fall back to `make_comparator` which does element-wise comparison.
153-
let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
154+
let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
154155
let rhs = right.to_array().into_arrow_preferred()?;
155156
let lhs = left.to_array().into_arrow(rhs.data_type())?;
156157

@@ -176,24 +177,36 @@ fn arrow_compare_arrays(
176177
CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
177178
}
178179
};
179-
from_arrow_array_with_len(&array, left.len(), nullable)
180+
181+
from_arrow_array_with_len(&arrow_array, left.len(), nullable)
180182
}
181183

182-
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
184+
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
183185
if lhs.is_null() | rhs.is_null() {
184-
Scalar::null(DType::Bool(Nullability::Nullable))
185-
} else {
186-
let b = match operator {
187-
CompareOperator::Eq => lhs == rhs,
188-
CompareOperator::NotEq => lhs != rhs,
189-
CompareOperator::Gt => lhs > rhs,
190-
CompareOperator::Gte => lhs >= rhs,
191-
CompareOperator::Lt => lhs < rhs,
192-
CompareOperator::Lte => lhs <= rhs,
193-
};
194-
195-
Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
186+
return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
196187
}
188+
189+
let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
190+
191+
// We use `partial_cmp` to ensure we do not lose a type mismatch error.
192+
let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
193+
vortex_err!(
194+
"Cannot compare scalars with incompatible types: {} and {}",
195+
lhs.dtype(),
196+
rhs.dtype()
197+
)
198+
})?;
199+
200+
let b = match operator {
201+
CompareOperator::Eq => ordering.is_eq(),
202+
CompareOperator::NotEq => ordering.is_ne(),
203+
CompareOperator::Gt => ordering.is_gt(),
204+
CompareOperator::Gte => ordering.is_ge(),
205+
CompareOperator::Lt => ordering.is_lt(),
206+
CompareOperator::Lte => ordering.is_le(),
207+
};
208+
209+
Ok(Scalar::bool(b, nullability))
197210
}
198211

199212
/// Compare two Arrow arrays element-wise using [`make_comparator`].
@@ -251,8 +264,13 @@ mod tests {
251264
use crate::dtype::FieldNames;
252265
use crate::dtype::Nullability;
253266
use crate::dtype::PType;
267+
use crate::extension::datetime::TimeUnit;
268+
use crate::extension::datetime::Timestamp;
269+
use crate::extension::datetime::TimestampOptions;
254270
use crate::scalar::Scalar;
255271
use crate::scalar_fn::fns::binary::compare::ConstantArray;
272+
use crate::scalar_fn::fns::binary::scalar_cmp;
273+
use crate::scalar_fn::fns::operators::CompareOperator;
256274
use crate::scalar_fn::fns::operators::Operator;
257275
use crate::test_harness::to_int_indices;
258276
use crate::validity::Validity;
@@ -479,6 +497,35 @@ mod tests {
479497
assert_arrays_eq!(result, expected);
480498
}
481499

500+
/// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
501+
/// extension types (e.g., timestamps with different time units) rather than silently
502+
/// returning a wrong result.
503+
#[test]
504+
fn scalar_cmp_incompatible_extension_types_errors() {
505+
let ms_scalar = Scalar::extension::<Timestamp>(
506+
TimestampOptions {
507+
unit: TimeUnit::Milliseconds,
508+
tz: None,
509+
},
510+
Scalar::from(1704067200000i64),
511+
);
512+
let s_scalar = Scalar::extension::<Timestamp>(
513+
TimestampOptions {
514+
unit: TimeUnit::Seconds,
515+
tz: None,
516+
},
517+
Scalar::from(1704067200i64),
518+
);
519+
520+
// All comparisons, including equality, must error on incompatible types.
521+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
522+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
523+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
524+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
525+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq).is_err());
526+
assert!(scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq).is_err());
527+
}
528+
482529
#[test]
483530
fn test_empty_list() {
484531
let list = ListViewArray::new(

vortex-file/src/tests.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,7 +1597,7 @@ async fn test_writer_with_statistics() -> VortexResult<()> {
15971597
}
15981598

15991599
#[tokio::test]
1600-
async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
1600+
async fn timestamp_unit_mismatch() -> Result<(), Box<dyn std::error::Error>> {
16011601
// Write file with MILLISECONDS timestamps
16021602
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
16031603
.into_array();
@@ -1621,6 +1621,62 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
16211621
)),
16221622
);
16231623

1624+
let mut stream = SESSION
1625+
.open_options()
1626+
.open_buffer(buf)?
1627+
.scan()?
1628+
.with_filter(filter_expr)
1629+
.into_array_stream()?;
1630+
1631+
let result = stream.try_next().await;
1632+
1633+
assert!(result.is_err());
1634+
1635+
Ok(())
1636+
}
1637+
1638+
/// Regression test: filtering a milliseconds timestamp column with a seconds scalar should
1639+
/// always error, regardless of how the internal children of `DateTimePartsArray` are encoded.
1640+
///
1641+
/// This test forces `ConstantArray` encoding for the seconds/subseconds children by using a
1642+
/// compressor with Dict included (which triggers distinct-value computation, letting
1643+
/// `ConstantScheme` win for `[0, 0, 0]`). The scanner should still detect the time unit
1644+
/// mismatch and error, not silently return wrong results.
1645+
#[tokio::test]
1646+
async fn timestamp_unit_mismatch_errors_with_constant_children()
1647+
-> Result<(), Box<dyn std::error::Error>> {
1648+
// Build a compressor where ConstantScheme wins for [0, 0, 0] by including Dict
1649+
// (which enables distinct-value computation).
1650+
let compressor = vortex_btrblocks::BtrBlocksCompressor::default();
1651+
1652+
// Write file with MILLISECONDS timestamps using this compressor.
1653+
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
1654+
.into_array();
1655+
let temporal = TemporalArray::new_timestamp(ts_array, TimeUnit::Milliseconds, None);
1656+
1657+
let strategy = crate::strategy::WriteStrategyBuilder::default()
1658+
.with_compressor(compressor)
1659+
.build();
1660+
1661+
let mut buf = ByteBufferMut::empty();
1662+
SESSION
1663+
.write_options()
1664+
.with_strategy(strategy)
1665+
.write(&mut buf, temporal.into_array().to_array_stream())
1666+
.await?;
1667+
1668+
// Read with SECONDS filter scalar — should error due to time unit mismatch.
1669+
let filter_expr = gt(
1670+
root(),
1671+
lit(Scalar::extension::<Timestamp>(
1672+
TimestampOptions {
1673+
unit: TimeUnit::Seconds,
1674+
tz: None,
1675+
},
1676+
Scalar::from(1704153600i64),
1677+
)),
1678+
);
1679+
16241680
let stream = SESSION
16251681
.open_options()
16261682
.open_buffer(buf)?
@@ -1630,7 +1686,13 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
16301686

16311687
let results = stream.try_collect::<Vec<_>>().await;
16321688

1633-
assert!(results.is_err());
1689+
assert!(
1690+
results.is_err(),
1691+
"Expected error from timestamp unit mismatch (ms vs s), but got {} results. \
1692+
This indicates the scanner silently applied the filter incorrectly when \
1693+
DateTimePartsArray children use ConstantArray encoding.",
1694+
results.unwrap().len()
1695+
);
16341696

16351697
Ok(())
16361698
}

0 commit comments

Comments
 (0)