Skip to content

Commit cb9b138

Browse files
authored
Remove PartialOrd implementation for ScalarValue (#7742)
## Summary This might help with #7699, but it might not. I made the changes though so here we go. Removes the `PartialOrd` implementation for `ScalarValue`, ensuring that we are only able to compare `Scalar`s which carry a `DType`. ## API Changes Removes the `PartialOrd` impl. ## Testing Existing tests should suffice since this just moves some logic around. Signed-off-by: Connor Tsui <connor.tsui20@gmail.com>
1 parent deb7de0 commit cb9b138

3 files changed

Lines changed: 112 additions & 30 deletions

File tree

vortex-array/public-api.lock

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14368,10 +14368,6 @@ impl core::cmp::PartialEq for vortex_array::scalar::ScalarValue
1436814368

1436914369
pub fn vortex_array::scalar::ScalarValue::eq(&self, &vortex_array::scalar::ScalarValue) -> bool
1437014370

14371-
impl core::cmp::PartialOrd for vortex_array::scalar::ScalarValue
14372-
14373-
pub fn vortex_array::scalar::ScalarValue::partial_cmp(&self, &Self) -> core::option::Option<core::cmp::Ordering>
14374-
1437514371
impl core::convert::From<&[u8]> for vortex_array::scalar::ScalarValue
1437614372

1437714373
pub fn vortex_array::scalar::ScalarValue::from(&[u8]) -> Self

vortex-array/src/scalar/scalar_impl.rs

Lines changed: 112 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use vortex_error::vortex_panic;
1414
use crate::dtype::DType;
1515
use crate::dtype::NativeDType;
1616
use crate::dtype::PType;
17+
use crate::dtype::StructFields;
1718
use crate::scalar::Scalar;
1819
use crate::scalar::ScalarValue;
1920

@@ -263,6 +264,16 @@ impl Scalar {
263264
}
264265
}
265266

267+
/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability in
268+
/// equality comparisons, we must also ignore it when hashing to maintain the invariant that equal
269+
/// values have equal hashes.
270+
impl Hash for Scalar {
271+
fn hash<H: Hasher>(&self, state: &mut H) {
272+
self.dtype.as_nonnullable().hash(state);
273+
self.value.hash(state);
274+
}
275+
}
276+
266277
/// We implement `PartialEq` manually because we want to ignore nullability when comparing scalars.
267278
/// Two scalars with the same value but different nullability should be considered equal.
268279
///
@@ -288,7 +299,14 @@ impl PartialOrd for Scalar {
288299
/// - Non-null values are compared according to their natural ordering
289300
///
290301
/// # Examples
291-
/// ```ignore
302+
///
303+
/// ```
304+
/// use std::cmp::Ordering;
305+
/// use vortex_array::dtype::DType;
306+
/// use vortex_array::dtype::Nullability;
307+
/// use vortex_array::dtype::PType;
308+
/// use vortex_array::scalar::Scalar;
309+
///
292310
/// // Same types compare successfully
293311
/// let a = Scalar::primitive(10i32, Nullability::NonNullable);
294312
/// let b = Scalar::primitive(20i32, Nullability::NonNullable);
@@ -308,16 +326,101 @@ impl PartialOrd for Scalar {
308326
if !self.dtype().eq_ignore_nullability(other.dtype()) {
309327
return None;
310328
}
311-
self.value().partial_cmp(&other.value())
329+
330+
partial_cmp_scalar_values(self.dtype(), self.value(), other.value())
312331
}
313332
}
314333

315-
/// We implement `Hash` manually to be consistent with `PartialEq`. Since we ignore nullability
316-
/// in equality comparisons, we must also ignore it when hashing to maintain the invariant that
317-
/// equal values have equal hashes.
318-
impl Hash for Scalar {
319-
fn hash<H: Hasher>(&self, state: &mut H) {
320-
self.dtype.as_nonnullable().hash(state);
321-
self.value.hash(state);
334+
/// Compare two optional scalar values using `dtype` for nested tuple interpretation.
335+
fn partial_cmp_scalar_values(
336+
dtype: &DType,
337+
lhs: Option<&ScalarValue>,
338+
rhs: Option<&ScalarValue>,
339+
) -> Option<Ordering> {
340+
match (lhs, rhs) {
341+
(None, None) => Some(Ordering::Equal),
342+
(None, Some(_)) => Some(Ordering::Less),
343+
(Some(_), None) => Some(Ordering::Greater),
344+
(Some(lhs), Some(rhs)) => partial_cmp_non_null_scalar_values(dtype, lhs, rhs),
345+
}
346+
}
347+
348+
/// Compare two non-null scalar values, consulting `dtype` only for tuple-backed values.
349+
fn partial_cmp_non_null_scalar_values(
350+
dtype: &DType,
351+
lhs: &ScalarValue,
352+
rhs: &ScalarValue,
353+
) -> Option<Ordering> {
354+
// `Scalar::validate` guarantees that a scalar's value matches its dtype. Most of the scalar
355+
// value variants have only 1 method of comparison, regardless of the dtype.
356+
match (lhs, rhs) {
357+
(ScalarValue::Bool(lhs), ScalarValue::Bool(rhs)) => lhs.partial_cmp(rhs),
358+
(ScalarValue::Primitive(lhs), ScalarValue::Primitive(rhs)) => lhs.partial_cmp(rhs),
359+
(ScalarValue::Decimal(lhs), ScalarValue::Decimal(rhs)) => lhs.partial_cmp(rhs),
360+
(ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => lhs.partial_cmp(rhs),
361+
(ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => lhs.partial_cmp(rhs),
362+
// `Tuple` is the exception here. Since it backs lists, fixed-size lists, and structs, we
363+
// need the dtype to know whether children share one element dtype or use per-field dtypes.
364+
(ScalarValue::Tuple(lhs), ScalarValue::Tuple(rhs)) => {
365+
partial_cmp_tuple_values(dtype, lhs, rhs)
366+
}
367+
// Variant values can have a different dtype in each row, so it doesn't make sense to
368+
// compare them.
369+
(ScalarValue::Variant(_), ScalarValue::Variant(_)) => None,
370+
_ => None,
371+
}
372+
}
373+
374+
/// Compare tuple values according to the list, fixed-size list, or struct dtype layout.
375+
fn partial_cmp_tuple_values(
376+
dtype: &DType,
377+
lhs: &[Option<ScalarValue>],
378+
rhs: &[Option<ScalarValue>],
379+
) -> Option<Ordering> {
380+
match dtype {
381+
DType::List(element_dtype, _) | DType::FixedSizeList(element_dtype, ..) => {
382+
partial_cmp_list_values(element_dtype, lhs, rhs)
383+
}
384+
DType::Struct(fields, _) => partial_cmp_struct_values(fields, lhs, rhs),
385+
DType::Extension(ext_dtype) => {
386+
partial_cmp_tuple_values(ext_dtype.storage_dtype(), lhs, rhs)
387+
}
388+
_ => None,
389+
}
390+
}
391+
392+
/// Compare list tuple values using the shared element dtype for each element.
393+
fn partial_cmp_list_values(
394+
element_dtype: &DType,
395+
lhs: &[Option<ScalarValue>],
396+
rhs: &[Option<ScalarValue>],
397+
) -> Option<Ordering> {
398+
for (lhs, rhs) in lhs.iter().zip(rhs.iter()) {
399+
match partial_cmp_scalar_values(element_dtype, lhs.as_ref(), rhs.as_ref())? {
400+
Ordering::Equal => continue,
401+
ordering => return Some(ordering),
402+
}
403+
}
404+
405+
Some(lhs.len().cmp(&rhs.len()))
406+
}
407+
408+
/// Compare struct tuple values using each field's dtype in field order.
409+
fn partial_cmp_struct_values(
410+
fields: &StructFields,
411+
lhs: &[Option<ScalarValue>],
412+
rhs: &[Option<ScalarValue>],
413+
) -> Option<Ordering> {
414+
if lhs.len() != fields.nfields() || rhs.len() != fields.nfields() {
415+
return None;
322416
}
417+
418+
for ((field_dtype, lhs), rhs) in fields.fields().zip(lhs.iter()).zip(rhs.iter()) {
419+
match partial_cmp_scalar_values(&field_dtype, lhs.as_ref(), rhs.as_ref())? {
420+
Ordering::Equal => continue,
421+
ordering => return Some(ordering),
422+
}
423+
}
424+
425+
Some(Ordering::Equal)
323426
}

vortex-array/src/scalar/scalar_value.rs

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
//! Core [`ScalarValue`] type definition.
55
6-
use std::cmp::Ordering;
76
use std::fmt::Display;
87
use std::fmt::Formatter;
98

@@ -111,22 +110,6 @@ impl ScalarValue {
111110
}
112111
}
113112

114-
impl PartialOrd for ScalarValue {
115-
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
116-
match (self, other) {
117-
(ScalarValue::Bool(a), ScalarValue::Bool(b)) => a.partial_cmp(b),
118-
(ScalarValue::Primitive(a), ScalarValue::Primitive(b)) => a.partial_cmp(b),
119-
(ScalarValue::Decimal(a), ScalarValue::Decimal(b)) => a.partial_cmp(b),
120-
(ScalarValue::Utf8(a), ScalarValue::Utf8(b)) => a.partial_cmp(b),
121-
(ScalarValue::Binary(a), ScalarValue::Binary(b)) => a.partial_cmp(b),
122-
(ScalarValue::Tuple(a), ScalarValue::Tuple(b)) => a.partial_cmp(b),
123-
(ScalarValue::Variant(a), ScalarValue::Variant(b)) => a.partial_cmp(b),
124-
// (ScalarValue::Extension(a), ScalarValue::Extension(b)) => a.partial_cmp(b),
125-
_ => None,
126-
}
127-
}
128-
}
129-
130113
impl Display for ScalarValue {
131114
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
132115
match self {

0 commit comments

Comments
 (0)