Skip to content

Commit b4c28ab

Browse files
committed
fix scalar partial ordering comparison
Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent b921999 commit b4c28ab

4 files changed

Lines changed: 155 additions & 22 deletions

File tree

fuzz/src/array/compare.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,11 +134,12 @@ pub fn compare_canonical_array(
134134
let scalar_vals: Vec<Scalar> = (0..array.len())
135135
.map(|i| array.scalar_at(i).vortex_expect("scalar_at"))
136136
.collect();
137-
BoolArray::from_iter(
138-
scalar_vals
139-
.iter()
140-
.map(|v| scalar_cmp(v, value, operator).as_bool().value()),
141-
)
137+
BoolArray::from_iter(scalar_vals.iter().map(|v| {
138+
scalar_cmp(v, value, operator)
139+
.vortex_expect("tried to compare different typed scalars")
140+
.as_bool()
141+
.value()
142+
}))
142143
.into_array()
143144
}
144145
d @ (DType::Null | DType::Extension(_)) => {

vortex-array/src/scalar/scalar_impl.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,10 @@ impl Scalar {
263263

264264
/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
265265
/// Two scalars with the same value but different nullability should be considered equal.
266+
///
267+
/// Note that this has **different** behavior than the [`PartialOrd`] implementation since the
268+
/// [`PartialOrd`] returns `None` if the types are different, whereas this `PartialEq`
269+
/// implementation simply returns `false`.
266270
impl PartialEq for Scalar {
267271
fn eq(&self, other: &Self) -> bool {
268272
self.dtype.eq_ignore_nullability(&other.dtype) && self.value == other.value

vortex-array/src/scalar_fn/fns/binary/compare.rs

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ pub(crate) fn execute_compare(
131131
// Constant-constant fast path
132132
if let (Some(lhs_const), Some(rhs_const)) = (lhs.as_opt::<Constant>(), rhs.as_opt::<Constant>())
133133
{
134-
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op);
134+
let result = scalar_cmp(lhs_const.scalar(), rhs_const.scalar(), op)?;
135135
return Ok(ConstantArray::new(result, lhs.len()).into_array());
136136
}
137137

@@ -150,7 +150,7 @@ fn arrow_compare_arrays(
150150

151151
// Arrow's vectorized comparison kernels don't support nested types.
152152
// For nested types, fall back to `make_comparator` which does element-wise comparison.
153-
let array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
153+
let arrow_array: BooleanArray = if left.dtype().is_nested() || right.dtype().is_nested() {
154154
let rhs = right.to_array().into_arrow_preferred()?;
155155
let lhs = left.to_array().into_arrow(rhs.data_type())?;
156156

@@ -176,24 +176,43 @@ fn arrow_compare_arrays(
176176
CompareOperator::Lte => cmp::lt_eq(&lhs, &rhs)?,
177177
}
178178
};
179-
from_arrow_array_with_len(&array, left.len(), nullable)
179+
180+
from_arrow_array_with_len(&arrow_array, left.len(), nullable)
180181
}
181182

182-
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> Scalar {
183+
pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: CompareOperator) -> VortexResult<Scalar> {
183184
if lhs.is_null() | rhs.is_null() {
184-
Scalar::null(DType::Bool(Nullability::Nullable))
185-
} else {
186-
let b = match operator {
187-
CompareOperator::Eq => lhs == rhs,
188-
CompareOperator::NotEq => lhs != rhs,
189-
CompareOperator::Gt => lhs > rhs,
190-
CompareOperator::Gte => lhs >= rhs,
191-
CompareOperator::Lt => lhs < rhs,
192-
CompareOperator::Lte => lhs <= rhs,
193-
};
185+
return Ok(Scalar::null(DType::Bool(Nullability::Nullable)));
186+
}
187+
188+
let nullability = lhs.dtype().nullability() | rhs.dtype().nullability();
194189

195-
Scalar::bool(b, lhs.dtype().nullability() | rhs.dtype().nullability())
190+
// Equality and inequality can be determined without `partial_cmp`.
191+
match operator {
192+
CompareOperator::Eq => return Ok(Scalar::bool(lhs == rhs, nullability)),
193+
CompareOperator::NotEq => return Ok(Scalar::bool(lhs != rhs, nullability)),
194+
_ => {}
196195
}
196+
197+
// We do this instead of `<` and `>` to ensure we do not lose a type mismatch error.
198+
let ordering = lhs.partial_cmp(rhs).ok_or_else(|| {
199+
vortex_error::vortex_err!(
200+
"Cannot compare scalars with incompatible types: {} and {}",
201+
lhs.dtype(),
202+
rhs.dtype()
203+
)
204+
})?;
205+
206+
let b = match operator {
207+
CompareOperator::Gt => ordering.is_gt(),
208+
CompareOperator::Gte => ordering.is_ge(),
209+
CompareOperator::Lt => ordering.is_lt(),
210+
CompareOperator::Lte => ordering.is_le(),
211+
// Already handled above.
212+
CompareOperator::Eq | CompareOperator::NotEq => unreachable!(),
213+
};
214+
215+
Ok(Scalar::bool(b, nullability))
197216
}
198217

199218
/// Compare two Arrow arrays element-wise using [`make_comparator`].
@@ -251,8 +270,12 @@ mod tests {
251270
use crate::dtype::FieldNames;
252271
use crate::dtype::Nullability;
253272
use crate::dtype::PType;
273+
use crate::extension::datetime::TimeUnit;
274+
use crate::extension::datetime::Timestamp;
275+
use crate::extension::datetime::TimestampOptions;
254276
use crate::scalar::Scalar;
255277
use crate::scalar_fn::fns::binary::compare::ConstantArray;
278+
use crate::scalar_fn::fns::operators::CompareOperator;
256279
use crate::scalar_fn::fns::operators::Operator;
257280
use crate::test_harness::to_int_indices;
258281
use crate::validity::Validity;
@@ -479,6 +502,49 @@ mod tests {
479502
assert_arrays_eq!(result, expected);
480503
}
481504

505+
/// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
506+
/// extension types (e.g., timestamps with different time units) rather than silently
507+
/// returning a wrong result.
508+
#[test]
509+
fn scalar_cmp_incompatible_extension_types_errors() {
510+
let ms_scalar = Scalar::extension::<Timestamp>(
511+
TimestampOptions {
512+
unit: TimeUnit::Milliseconds,
513+
tz: None,
514+
},
515+
Scalar::from(1704067200000i64),
516+
);
517+
let s_scalar = Scalar::extension::<Timestamp>(
518+
TimestampOptions {
519+
unit: TimeUnit::Seconds,
520+
tz: None,
521+
},
522+
Scalar::from(1704067200i64),
523+
);
524+
525+
// Ordering comparisons must error on incompatible types.
526+
assert!(super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gt).is_err());
527+
assert!(super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lt).is_err());
528+
assert!(super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Gte).is_err());
529+
assert!(super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Lte).is_err());
530+
531+
// Equality comparisons should succeed (and return false since the types differ).
532+
assert_eq!(
533+
super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::Eq)
534+
.unwrap()
535+
.as_bool()
536+
.value(),
537+
Some(false),
538+
);
539+
assert_eq!(
540+
super::scalar_cmp(&ms_scalar, &s_scalar, CompareOperator::NotEq)
541+
.unwrap()
542+
.as_bool()
543+
.value(),
544+
Some(true),
545+
);
546+
}
547+
482548
#[test]
483549
fn test_empty_list() {
484550
let list = ListViewArray::new(

vortex-file/src/tests.rs

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,7 +1597,7 @@ async fn test_writer_with_statistics() -> VortexResult<()> {
15971597
}
15981598

15991599
#[tokio::test]
1600-
async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
1600+
async fn timestamp_unit_mismatch() -> Result<(), Box<dyn std::error::Error>> {
16011601
// Write file with MILLISECONDS timestamps
16021602
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
16031603
.into_array();
@@ -1621,6 +1621,62 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
16211621
)),
16221622
);
16231623

1624+
let mut stream = SESSION
1625+
.open_options()
1626+
.open_buffer(buf)?
1627+
.scan()?
1628+
.with_filter(filter_expr)
1629+
.into_array_stream()?;
1630+
1631+
let result = stream.try_next().await;
1632+
1633+
assert!(result.is_err());
1634+
1635+
Ok(())
1636+
}
1637+
1638+
/// Regression test: filtering a milliseconds timestamp column with a seconds scalar should
1639+
/// always error, regardless of how the internal children of `DateTimePartsArray` are encoded.
1640+
///
1641+
/// This test forces `ConstantArray` encoding for the seconds/subseconds children by using a
1642+
/// compressor with Dict excluded (which triggers distinct-value computation, letting
1643+
/// `ConstantScheme` win for `[0, 0, 0]`). The scanner should still detect the time unit
1644+
/// mismatch and error, not silently return wrong results.
1645+
#[tokio::test]
1646+
async fn timestamp_unit_mismatch_errors_with_constant_children()
1647+
-> Result<(), Box<dyn std::error::Error>> {
1648+
// Build a compressor where ConstantScheme wins for [0, 0, 0] by including Dict
1649+
// (which enables distinct-value computation).
1650+
let compressor = vortex_btrblocks::BtrBlocksCompressor::default();
1651+
1652+
// Write file with MILLISECONDS timestamps using this compressor.
1653+
let ts_array = PrimitiveArray::from_iter(vec![1704067200000i64, 1704153600000, 1704240000000])
1654+
.into_array();
1655+
let temporal = TemporalArray::new_timestamp(ts_array, TimeUnit::Milliseconds, None);
1656+
1657+
let strategy = crate::strategy::WriteStrategyBuilder::default()
1658+
.with_compressor(compressor)
1659+
.build();
1660+
1661+
let mut buf = ByteBufferMut::empty();
1662+
SESSION
1663+
.write_options()
1664+
.with_strategy(strategy)
1665+
.write(&mut buf, temporal.into_array().to_array_stream())
1666+
.await?;
1667+
1668+
// Read with SECONDS filter scalar — should error due to time unit mismatch.
1669+
let filter_expr = gt(
1670+
root(),
1671+
lit(Scalar::extension::<Timestamp>(
1672+
TimestampOptions {
1673+
unit: TimeUnit::Seconds,
1674+
tz: None,
1675+
},
1676+
Scalar::from(1704153600i64),
1677+
)),
1678+
);
1679+
16241680
let stream = SESSION
16251681
.open_options()
16261682
.open_buffer(buf)?
@@ -1630,7 +1686,13 @@ async fn main_test() -> Result<(), Box<dyn std::error::Error>> {
16301686

16311687
let results = stream.try_collect::<Vec<_>>().await;
16321688

1633-
assert!(results.is_err());
1689+
assert!(
1690+
results.is_err(),
1691+
"Expected error from timestamp unit mismatch (ms vs s), but got {} results. \
1692+
This indicates the scanner silently applied the filter incorrectly when \
1693+
DateTimePartsArray children use ConstantArray encoding.",
1694+
results.unwrap().len()
1695+
);
16341696

16351697
Ok(())
16361698
}

0 commit comments

Comments
 (0)