@@ -9,6 +9,7 @@ use arrow_ord::cmp;
99use arrow_ord:: ord:: make_comparator;
1010use arrow_schema:: SortOptions ;
1111use vortex_error:: VortexResult ;
12+ use vortex_error:: vortex_err;
1213
1314use crate :: ArrayRef ;
1415use crate :: Canonical ;
@@ -131,7 +132,7 @@ pub(crate) fn execute_compare(
131132 // Constant-constant fast path
132133 if let ( Some ( lhs_const) , Some ( rhs_const) ) = ( lhs. as_opt :: < Constant > ( ) , rhs. as_opt :: < Constant > ( ) )
133134 {
134- let result = scalar_cmp ( lhs_const. scalar ( ) , rhs_const. scalar ( ) , op) ;
135+ let result = scalar_cmp ( lhs_const. scalar ( ) , rhs_const. scalar ( ) , op) ? ;
135136 return Ok ( ConstantArray :: new ( result, lhs. len ( ) ) . into_array ( ) ) ;
136137 }
137138
@@ -150,7 +151,7 @@ fn arrow_compare_arrays(
150151
151152 // Arrow's vectorized comparison kernels don't support nested types.
152153 // For nested types, fall back to `make_comparator` which does element-wise comparison.
153- let array : BooleanArray = if left. dtype ( ) . is_nested ( ) || right. dtype ( ) . is_nested ( ) {
154+ let arrow_array : BooleanArray = if left. dtype ( ) . is_nested ( ) || right. dtype ( ) . is_nested ( ) {
154155 let rhs = right. to_array ( ) . into_arrow_preferred ( ) ?;
155156 let lhs = left. to_array ( ) . into_arrow ( rhs. data_type ( ) ) ?;
156157
@@ -176,24 +177,36 @@ fn arrow_compare_arrays(
176177 CompareOperator :: Lte => cmp:: lt_eq ( & lhs, & rhs) ?,
177178 }
178179 } ;
179- from_arrow_array_with_len ( & array, left. len ( ) , nullable)
180+
181+ from_arrow_array_with_len ( & arrow_array, left. len ( ) , nullable)
180182}
181183
182- pub fn scalar_cmp ( lhs : & Scalar , rhs : & Scalar , operator : CompareOperator ) -> Scalar {
184+ pub fn scalar_cmp ( lhs : & Scalar , rhs : & Scalar , operator : CompareOperator ) -> VortexResult < Scalar > {
183185 if lhs. is_null ( ) | rhs. is_null ( ) {
184- Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) )
185- } else {
186- let b = match operator {
187- CompareOperator :: Eq => lhs == rhs,
188- CompareOperator :: NotEq => lhs != rhs,
189- CompareOperator :: Gt => lhs > rhs,
190- CompareOperator :: Gte => lhs >= rhs,
191- CompareOperator :: Lt => lhs < rhs,
192- CompareOperator :: Lte => lhs <= rhs,
193- } ;
194-
195- Scalar :: bool ( b, lhs. dtype ( ) . nullability ( ) | rhs. dtype ( ) . nullability ( ) )
186+ return Ok ( Scalar :: null ( DType :: Bool ( Nullability :: Nullable ) ) ) ;
196187 }
188+
189+ let nullability = lhs. dtype ( ) . nullability ( ) | rhs. dtype ( ) . nullability ( ) ;
190+
191+ // `partial_cmp` yields `None` when the scalars cannot be compared; we turn that into an error so a type mismatch is never silently folded into a boolean result.
192+ let ordering = lhs. partial_cmp ( rhs) . ok_or_else ( || {
193+ vortex_err ! (
194+ "Cannot compare scalars with incompatible types: {} and {}" ,
195+ lhs. dtype( ) ,
196+ rhs. dtype( )
197+ )
198+ } ) ?;
199+
200+ let b = match operator {
201+ CompareOperator :: Eq => ordering. is_eq ( ) ,
202+ CompareOperator :: NotEq => ordering. is_ne ( ) ,
203+ CompareOperator :: Gt => ordering. is_gt ( ) ,
204+ CompareOperator :: Gte => ordering. is_ge ( ) ,
205+ CompareOperator :: Lt => ordering. is_lt ( ) ,
206+ CompareOperator :: Lte => ordering. is_le ( ) ,
207+ } ;
208+
209+ Ok ( Scalar :: bool ( b, nullability) )
197210}
198211
199212/// Compare two Arrow arrays element-wise using [`make_comparator`].
@@ -251,8 +264,13 @@ mod tests {
251264 use crate :: dtype:: FieldNames ;
252265 use crate :: dtype:: Nullability ;
253266 use crate :: dtype:: PType ;
267+ use crate :: extension:: datetime:: TimeUnit ;
268+ use crate :: extension:: datetime:: Timestamp ;
269+ use crate :: extension:: datetime:: TimestampOptions ;
254270 use crate :: scalar:: Scalar ;
255271 use crate :: scalar_fn:: fns:: binary:: compare:: ConstantArray ;
272+ use crate :: scalar_fn:: fns:: binary:: scalar_cmp;
273+ use crate :: scalar_fn:: fns:: operators:: CompareOperator ;
256274 use crate :: scalar_fn:: fns:: operators:: Operator ;
257275 use crate :: test_harness:: to_int_indices;
258276 use crate :: validity:: Validity ;
@@ -479,6 +497,35 @@ mod tests {
479497 assert_arrays_eq ! ( result, expected) ;
480498 }
481499
500+ /// Regression test: `scalar_cmp` must error when comparing scalars with incompatible
501+ /// extension types (e.g., timestamps with different time units) rather than silently
502+ /// returning a wrong result.
503+ #[ test]
504+ fn scalar_cmp_incompatible_extension_types_errors ( ) {
505+ let ms_scalar = Scalar :: extension :: < Timestamp > (
506+ TimestampOptions {
507+ unit : TimeUnit :: Milliseconds ,
508+ tz : None ,
509+ } ,
510+ Scalar :: from ( 1704067200000i64 ) ,
511+ ) ;
512+ let s_scalar = Scalar :: extension :: < Timestamp > (
513+ TimestampOptions {
514+ unit : TimeUnit :: Seconds ,
515+ tz : None ,
516+ } ,
517+ Scalar :: from ( 1704067200i64 ) ,
518+ ) ;
519+
520+ // Every comparison operator, equality included, must error on incompatible types.
521+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: Gt ) . is_err( ) ) ;
522+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: Lt ) . is_err( ) ) ;
523+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: Gte ) . is_err( ) ) ;
524+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: Lte ) . is_err( ) ) ;
525+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: Eq ) . is_err( ) ) ;
526+ assert ! ( scalar_cmp( & ms_scalar, & s_scalar, CompareOperator :: NotEq ) . is_err( ) ) ;
527+ }
528+
482529 #[ test]
483530 fn test_empty_list ( ) {
484531 let list = ListViewArray :: new (
0 commit comments