@@ -14,6 +14,7 @@ use vortex_error::vortex_panic;
1414use crate :: dtype:: DType ;
1515use crate :: dtype:: NativeDType ;
1616use crate :: dtype:: PType ;
17+ use crate :: dtype:: StructFields ;
1718use crate :: scalar:: Scalar ;
1819use crate :: scalar:: ScalarValue ;
1920
@@ -263,6 +264,16 @@ impl Scalar {
263264 }
264265}
265266
267+ /// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability in
268+ /// equality comparisons, we must also ignore it when hashing to maintain the invariant that equal
269+ /// values have equal hashes.
270+ impl Hash for Scalar {
271+ fn hash < H : Hasher > ( & self , state : & mut H ) {
272+ self . dtype . as_nonnullable ( ) . hash ( state) ;
273+ self . value . hash ( state) ;
274+ }
275+ }
276+
266277/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
267278/// Two scalars with the same value but different nullability should be considered equal.
268279///
@@ -288,7 +299,14 @@ impl PartialOrd for Scalar {
288299 /// - Non-null values are compared according to their natural ordering
289300 ///
290301 /// # Examples
291- /// ```ignore
302+ ///
303+ /// ```
304+ /// use std::cmp::Ordering;
305+ /// use vortex_array::dtype::DType;
306+ /// use vortex_array::dtype::Nullability;
307+ /// use vortex_array::dtype::PType;
308+ /// use vortex_array::scalar::Scalar;
309+ ///
292310 /// // Same types compare successfully
293311 /// let a = Scalar::primitive(10i32, Nullability::NonNullable);
294312 /// let b = Scalar::primitive(20i32, Nullability::NonNullable);
@@ -308,16 +326,101 @@ impl PartialOrd for Scalar {
308326 if !self . dtype ( ) . eq_ignore_nullability ( other. dtype ( ) ) {
309327 return None ;
310328 }
311- self . value ( ) . partial_cmp ( & other. value ( ) )
329+
330+ partial_cmp_scalar_values ( self . dtype ( ) , self . value ( ) , other. value ( ) )
312331 }
313332}
314333
315- /// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability
316- /// in equality comparisons, we must also ignore it when hashing to maintain the invariant that
317- /// equal values have equal hashes.
318- impl Hash for Scalar {
319- fn hash < H : Hasher > ( & self , state : & mut H ) {
320- self . dtype . as_nonnullable ( ) . hash ( state) ;
321- self . value . hash ( state) ;
334+ /// Compare two optional scalar values using `dtype` for nested tuple interpretation.
335+ fn partial_cmp_scalar_values (
336+ dtype : & DType ,
337+ lhs : Option < & ScalarValue > ,
338+ rhs : Option < & ScalarValue > ,
339+ ) -> Option < Ordering > {
340+ match ( lhs, rhs) {
341+ ( None , None ) => Some ( Ordering :: Equal ) ,
342+ ( None , Some ( _) ) => Some ( Ordering :: Less ) ,
343+ ( Some ( _) , None ) => Some ( Ordering :: Greater ) ,
344+ ( Some ( lhs) , Some ( rhs) ) => partial_cmp_non_null_scalar_values ( dtype, lhs, rhs) ,
345+ }
346+ }
347+
348+ /// Compare two non-null scalar values, consulting `dtype` only for tuple-backed values.
349+ fn partial_cmp_non_null_scalar_values (
350+ dtype : & DType ,
351+ lhs : & ScalarValue ,
352+ rhs : & ScalarValue ,
353+ ) -> Option < Ordering > {
354+ // `Scalar::validate` guarantees that a scalar's value matches its dtype. Most of the scalar
355+ // value variants have only 1 method of comparison, regardless of the dtype.
356+ match ( lhs, rhs) {
357+ ( ScalarValue :: Bool ( lhs) , ScalarValue :: Bool ( rhs) ) => lhs. partial_cmp ( rhs) ,
358+ ( ScalarValue :: Primitive ( lhs) , ScalarValue :: Primitive ( rhs) ) => lhs. partial_cmp ( rhs) ,
359+ ( ScalarValue :: Decimal ( lhs) , ScalarValue :: Decimal ( rhs) ) => lhs. partial_cmp ( rhs) ,
360+ ( ScalarValue :: Utf8 ( lhs) , ScalarValue :: Utf8 ( rhs) ) => lhs. partial_cmp ( rhs) ,
361+ ( ScalarValue :: Binary ( lhs) , ScalarValue :: Binary ( rhs) ) => lhs. partial_cmp ( rhs) ,
362+ // `Tuple` is the exception here. Since it backs lists, fixed-size lists, and structs, we
363+ // need the dtype to know whether children share one element dtype or use per-field dtypes.
364+ ( ScalarValue :: Tuple ( lhs) , ScalarValue :: Tuple ( rhs) ) => {
365+ partial_cmp_tuple_values ( dtype, lhs, rhs)
366+ }
367+ // Variant values can have a different dtype in each row, so it doesn't make sense to
368+ // compare them.
369+ ( ScalarValue :: Variant ( _) , ScalarValue :: Variant ( _) ) => None ,
370+ _ => None ,
371+ }
372+ }
373+
374+ /// Compare tuple values according to the list, fixed-size list, or struct dtype layout.
375+ fn partial_cmp_tuple_values (
376+ dtype : & DType ,
377+ lhs : & [ Option < ScalarValue > ] ,
378+ rhs : & [ Option < ScalarValue > ] ,
379+ ) -> Option < Ordering > {
380+ match dtype {
381+ DType :: List ( element_dtype, _) | DType :: FixedSizeList ( element_dtype, ..) => {
382+ partial_cmp_list_values ( element_dtype, lhs, rhs)
383+ }
384+ DType :: Struct ( fields, _) => partial_cmp_struct_values ( fields, lhs, rhs) ,
385+ DType :: Extension ( ext_dtype) => {
386+ partial_cmp_tuple_values ( ext_dtype. storage_dtype ( ) , lhs, rhs)
387+ }
388+ _ => None ,
389+ }
390+ }
391+
392+ /// Compare list tuple values using the shared element dtype for each element.
393+ fn partial_cmp_list_values (
394+ element_dtype : & DType ,
395+ lhs : & [ Option < ScalarValue > ] ,
396+ rhs : & [ Option < ScalarValue > ] ,
397+ ) -> Option < Ordering > {
398+ for ( lhs, rhs) in lhs. iter ( ) . zip ( rhs. iter ( ) ) {
399+ match partial_cmp_scalar_values ( element_dtype, lhs. as_ref ( ) , rhs. as_ref ( ) ) ? {
400+ Ordering :: Equal => continue ,
401+ ordering => return Some ( ordering) ,
402+ }
403+ }
404+
405+ Some ( lhs. len ( ) . cmp ( & rhs. len ( ) ) )
406+ }
407+
408+ /// Compare struct tuple values using each field's dtype in field order.
409+ fn partial_cmp_struct_values (
410+ fields : & StructFields ,
411+ lhs : & [ Option < ScalarValue > ] ,
412+ rhs : & [ Option < ScalarValue > ] ,
413+ ) -> Option < Ordering > {
414+ if lhs. len ( ) != fields. nfields ( ) || rhs. len ( ) != fields. nfields ( ) {
415+ return None ;
322416 }
417+
418+ for ( ( field_dtype, lhs) , rhs) in fields. fields ( ) . zip ( lhs. iter ( ) ) . zip ( rhs. iter ( ) ) {
419+ match partial_cmp_scalar_values ( & field_dtype, lhs. as_ref ( ) , rhs. as_ref ( ) ) ? {
420+ Ordering :: Equal => continue ,
421+ ordering => return Some ( ordering) ,
422+ }
423+ }
424+
425+ Some ( Ordering :: Equal )
323426}
0 commit comments