Skip to content

Commit d0d400c

Browse files
authored
Use AllNonDistinct in assert_arrays_eq and implement it for variant types (#8546)
AllNonDistinct has exactly the semantics we want and should be faster; in case of failure, we run the slow method to produce a precise error message. Signed-off-by: Robert Kruszewski <github@robertk.io>
1 parent f41dc23 commit d0d400c

9 files changed

Lines changed: 233 additions & 6 deletions

File tree

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::fmt::Debug;
5+
6+
use vortex_array::ArrayRef;
7+
use vortex_array::ExecutionCtx;
8+
use vortex_array::aggregate_fn::AggregateFnRef;
9+
use vortex_array::aggregate_fn::fns::all_non_distinct::AllNonDistinct;
10+
use vortex_array::aggregate_fn::fns::all_non_distinct::all_non_distinct;
11+
use vortex_array::aggregate_fn::kernels::DynAggregateKernel;
12+
use vortex_array::arrays::Struct;
13+
use vortex_array::arrays::struct_::StructArrayExt;
14+
use vortex_array::dtype::Nullability;
15+
use vortex_array::scalar::Scalar;
16+
use vortex_error::VortexResult;
17+
18+
use crate::ParquetVariant;
19+
use crate::ParquetVariantArrayExt;
20+
21+
/// Lets `AllNonDistinct` compare two `ParquetVariant` arrays without canonicalizing them.
22+
///
23+
/// `AllNonDistinct` accumulates over a `Struct{lhs, rhs}` batch, so this kernel is registered for
24+
/// the struct encoding and inspects the two children. When both are `ParquetVariant`, we compare
25+
/// the typed (`typed_value`) arrays if both sides are shredded, and fall back to the raw `value`
26+
/// arrays otherwise. Comparing these child arrays directly avoids re-canonicalizing the variant
27+
/// (which would recurse through the `Variant` canonical form).
28+
#[derive(Debug)]
29+
pub struct AllNonDistinctParquetVariant;
30+
31+
impl DynAggregateKernel for AllNonDistinctParquetVariant {
32+
fn aggregate(
33+
&self,
34+
aggregate_fn: &AggregateFnRef,
35+
batch: &ArrayRef,
36+
ctx: &mut ExecutionCtx,
37+
) -> VortexResult<Option<Scalar>> {
38+
if !aggregate_fn.is::<AllNonDistinct>() {
39+
return Ok(None);
40+
}
41+
42+
let Some(batch) = batch.as_opt::<Struct>() else {
43+
return Ok(None);
44+
};
45+
let lhs = batch.unmasked_field(0);
46+
let rhs = batch.unmasked_field(1);
47+
let (Some(lhs), Some(rhs)) = (
48+
lhs.as_opt::<ParquetVariant>(),
49+
rhs.as_opt::<ParquetVariant>(),
50+
) else {
51+
return Ok(None);
52+
};
53+
54+
let typed_identical = match (lhs.typed_value_array(), rhs.typed_value_array()) {
55+
(Some(lhs_typed), Some(rhs_typed)) => {
56+
if lhs_typed.dtype().eq_ignore_nullability(rhs_typed.dtype()) {
57+
all_non_distinct(lhs_typed, rhs_typed, ctx)?
58+
} else {
59+
return Ok(None);
60+
}
61+
}
62+
_ => true,
63+
};
64+
65+
if typed_identical {
66+
let values_identical = match (lhs.value_array(), rhs.value_array()) {
67+
(Some(lhs_value), Some(rhs_value)) => all_non_distinct(lhs_value, rhs_value, ctx)?,
68+
(None, None) => true,
69+
// Mixed shredding layouts: let the generic canonical path handle it.
70+
_ => return Ok(None),
71+
};
72+
Ok(Some(Scalar::bool(
73+
values_identical,
74+
Nullability::NonNullable,
75+
)))
76+
} else {
77+
Ok(Some(Scalar::bool(false, Nullability::NonNullable)))
78+
}
79+
}
80+
}
81+
82+
#[cfg(test)]
mod tests {
    use std::sync::LazyLock;

    use vortex_array::ArrayRef;
    use vortex_array::IntoArray;
    use vortex_array::VortexSessionExecute;
    use vortex_array::aggregate_fn::fns::all_non_distinct::all_non_distinct;
    use vortex_array::arrays::VarBinViewArray;
    use vortex_array::validity::Validity;
    use vortex_buffer::buffer;
    use vortex_error::VortexResult;
    use vortex_session::VortexSession;

    use crate::ParquetVariant;

    /// Shared session with this crate's kernels registered through `crate::initialize`.
    static SESSION: LazyLock<VortexSession> = LazyLock::new(|| {
        let session = vortex_array::array_session();
        crate::initialize(&session);
        session
    });

    /// Builds a non-nullable metadata column holding `len` copies of a minimal metadata blob.
    fn metadata(len: usize) -> ArrayRef {
        let rows = vec![b"\x01\x00"; len];
        VarBinViewArray::from_iter_bin(rows).into_array()
    }

    /// Builds a non-nullable binary column to use as the `value` child.
    fn binary<T: AsRef<[u8]>>(values: impl IntoIterator<Item = T>) -> ArrayRef {
        VarBinViewArray::from_iter_bin(values).into_array()
    }

    /// Assembles a non-nullable `ParquetVariant` array of `len` rows from optional children.
    fn parquet_variant(
        len: usize,
        value: Option<ArrayRef>,
        typed_value: Option<ArrayRef>,
    ) -> VortexResult<ArrayRef> {
        let variant =
            ParquetVariant::try_new(Validity::NonNullable, metadata(len), value, typed_value)?;
        Ok(variant.into_array())
    }

    /// Runs `all_non_distinct` over the pair using a fresh execution context from `SESSION`.
    fn non_distinct(lhs: &ArrayRef, rhs: &ArrayRef) -> VortexResult<bool> {
        let mut ctx = SESSION.create_execution_ctx();
        all_non_distinct(lhs, rhs, &mut ctx)
    }

    #[test]
    fn all_non_distinct_matches_equal_unshredded() -> VortexResult<()> {
        let lhs = parquet_variant(2, Some(binary([b"\x10", b"\x11"])), None)?;
        let rhs = parquet_variant(2, Some(binary([b"\x10", b"\x11"])), None)?;
        assert!(non_distinct(&lhs, &rhs)?);
        Ok(())
    }

    #[test]
    fn all_non_distinct_detects_distinct_unshredded() -> VortexResult<()> {
        let lhs = parquet_variant(2, Some(binary([b"\x10", b"\x11"])), None)?;
        let rhs = parquet_variant(2, Some(binary([b"\x10", b"\x12"])), None)?;
        assert!(!non_distinct(&lhs, &rhs)?);
        Ok(())
    }

    #[test]
    fn all_non_distinct_matches_equal_value_and_typed() -> VortexResult<()> {
        let typed = || buffer![1i32, 2].into_array();
        let lhs = parquet_variant(2, Some(binary([b"\x10", b"\x11"])), Some(typed()))?;
        let rhs = parquet_variant(2, Some(binary([b"\x10", b"\x11"])), Some(typed()))?;
        assert!(non_distinct(&lhs, &rhs)?);
        Ok(())
    }

    #[test]
    fn all_non_distinct_empty_is_true() -> VortexResult<()> {
        let no_rows = || binary(Vec::<&[u8]>::new());
        let lhs = parquet_variant(0, Some(no_rows()), None)?;
        let rhs = parquet_variant(0, Some(no_rows()), None)?;
        assert!(non_distinct(&lhs, &rhs)?);
        Ok(())
    }
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
mod allnondistinct;
5+
6+
pub use allnondistinct::*;

encodings/parquet-variant/src/kernel.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@ use vortex_array::ArrayVTable;
2020
use vortex_array::ArrayView;
2121
use vortex_array::ExecutionCtx;
2222
use vortex_array::IntoArray;
23+
use vortex_array::aggregate_fn::AggregateFnVTable;
24+
use vortex_array::aggregate_fn::fns::all_non_distinct::AllNonDistinct;
25+
use vortex_array::aggregate_fn::session::AggregateFnSessionExt;
2326
use vortex_array::arrays::Dict;
2427
use vortex_array::arrays::Extension;
2528
use vortex_array::arrays::Filter;
2629
use vortex_array::arrays::Slice;
30+
use vortex_array::arrays::Struct;
2731
use vortex_array::arrays::dict::TakeExecute;
2832
use vortex_array::arrays::dict::TakeExecuteAdaptor;
2933
use vortex_array::arrays::extension::ExtensionArrayExt;
@@ -52,6 +56,7 @@ use vortex_session::VortexSession;
5256

5357
use crate::ParquetVariant;
5458
use crate::ParquetVariantArrayExt;
59+
use crate::compute::AllNonDistinctParquetVariant;
5560

5661
pub(crate) fn initialize(session: &VortexSession) {
5762
let kernels = session.kernels();
@@ -76,6 +81,12 @@ pub(crate) fn initialize(session: &VortexSession) {
7681
Extension,
7782
JsonExtensionToVariantKernel,
7883
);
84+
let aggregates = session.aggregate_fns();
85+
aggregates.register_aggregate_kernel(
86+
Struct.id(),
87+
Some(AllNonDistinct.id()),
88+
&AllNonDistinctParquetVariant,
89+
);
7990
}
8091

8192
#[derive(Default, Debug)]

encodings/parquet-variant/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
2727
mod array;
2828
mod arrow;
29+
mod compute;
2930
#[cfg(test)]
3031
mod json_to_variant_tests;
3132
mod kernel;

vortex-array/src/aggregate_fn/fns/all_non_distinct/mod.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ mod struct_;
1212
#[cfg(test)]
1313
mod tests;
1414
mod varbin;
15+
mod variant;
1516

1617
use std::sync::LazyLock;
1718

@@ -38,6 +39,7 @@ use crate::aggregate_fn::AggregateFnId;
3839
use crate::aggregate_fn::AggregateFnVTable;
3940
use crate::aggregate_fn::DynAccumulator;
4041
use crate::aggregate_fn::EmptyOptions;
42+
use crate::aggregate_fn::fns::all_non_distinct::variant::check_variant_identical;
4143
use crate::arrays::StructArray;
4244
use crate::arrays::struct_::StructArrayExt;
4345
use crate::dtype::DType;
@@ -264,8 +266,8 @@ fn check_canonical_identical(
264266
(Canonical::Extension(lhs), Canonical::Extension(rhs)) => {
265267
check_extension_identical(lhs, rhs, ctx)
266268
}
267-
(Canonical::Variant(_), _) | (_, Canonical::Variant(_)) => {
268-
vortex_bail!("Variant arrays don't support AllNonDistinct")
269+
(Canonical::Variant(lhs), Canonical::Variant(rhs)) => {
270+
check_variant_identical(lhs, rhs, ctx)
269271
}
270272
_ => Err(vortex_err!(
271273
"Canonical type mismatch in AllNonDistinct: {:?} vs {:?}",

vortex-array/src/aggregate_fn/fns/all_non_distinct/primitive.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// SPDX-License-Identifier: Apache-2.0
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
use vortex_array::dtype::NativePType;
45
use vortex_error::VortexResult;
56

67
use crate::arrays::primitive::PrimitiveArrayExt;
@@ -12,6 +13,10 @@ where
1213
R: PrimitiveArrayExt,
1314
{
1415
match_each_native_ptype!(lhs.ptype(), |P| {
15-
Ok(lhs.as_slice::<P>() == rhs.as_slice::<P>())
16+
Ok(lhs
17+
.as_slice::<P>()
18+
.iter()
19+
.zip(rhs.as_slice::<P>())
20+
.all(|(l, r)| l.is_eq(*r)))
1621
})
1722
}
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use vortex_error::VortexResult;
5+
6+
use crate::ExecutionCtx;
7+
use crate::arrays::VariantArray;
8+
9+
/// Checks whether two canonical variant arrays are element-wise non-distinct.
10+
///
11+
/// Variant values cannot be routed back through [`all_non_distinct`]: canonicalizing a variant
12+
/// value array yields another canonical variant (with no shredded tree), which would recurse
13+
/// forever. The generic fallback therefore compares logical variant scalars row-by-row. Encodings
14+
/// that can compare their typed/value children more cheaply (e.g. `ParquetVariant`) register an
15+
/// aggregate kernel that intercepts the comparison before it reaches this fallback.
16+
///
17+
/// [`all_non_distinct`]: super::all_non_distinct
18+
pub(super) fn check_variant_identical(
19+
lhs: &VariantArray,
20+
rhs: &VariantArray,
21+
ctx: &mut ExecutionCtx,
22+
) -> VortexResult<bool> {
23+
if lhs.len() != rhs.len() {
24+
return Ok(false);
25+
}
26+
for idx in 0..lhs.len() {
27+
if lhs.execute_scalar(idx, ctx)? != rhs.execute_scalar(idx, ctx)? {
28+
return Ok(false);
29+
}
30+
}
31+
Ok(true)
32+
}

vortex-array/src/arrays/assertions.rs

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::ArrayRef;
1010
use crate::ExecutionCtx;
1111
use crate::IntoArray;
1212
use crate::RecursiveCanonical;
13+
use crate::aggregate_fn::fns::all_non_distinct::all_non_distinct;
1314

1415
fn format_indices<I: IntoIterator<Item = usize>>(indices: I) -> impl Display {
1516
indices.into_iter().format(",")
@@ -117,17 +118,23 @@ macro_rules! assert_arrays_eq {
117118
pub fn assert_arrays_eq_impl(left: &ArrayRef, right: &ArrayRef, ctx: &mut ExecutionCtx) {
118119
let executed = execute_to_canonical(left.clone(), ctx);
119120

120-
let left_right = find_mismatched_indices(left, right, ctx);
121-
let executed_right = find_mismatched_indices(&executed, right, ctx);
121+
let left_right_the_same =
122+
all_non_distinct(left, right, ctx).vortex_expect("failed to compare left and right");
123+
let executed_right_the_same = all_non_distinct(&executed, right, ctx)
124+
.vortex_expect("failed to compare executed left and right");
125+
126+
if !left_right_the_same || !executed_right_the_same {
127+
let left_right = find_mismatched_indices(left, right, ctx);
122128

123-
if !left_right.is_empty() || !executed_right.is_empty() {
124129
let mut msg = String::new();
125130
if !left_right.is_empty() {
126131
msg.push_str(&format!(
127132
"\n left != right at indices: {}",
128133
format_indices(left_right)
129134
));
130135
}
136+
137+
let executed_right = find_mismatched_indices(&executed, right, ctx);
131138
if !executed_right.is_empty() {
132139
msg.push_str(&format!(
133140
"\n executed != right at indices: {}",

vortex-ipc/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod stream;
1919
mod test {
2020
use std::sync::LazyLock;
2121

22+
use vortex_array::aggregate_fn::session::AggregateFnSession;
2223
use vortex_array::dtype::session::DTypeSession;
2324
use vortex_array::optimizer::kernels::KernelSession;
2425
use vortex_array::session::ArraySession;
@@ -29,6 +30,7 @@ mod test {
2930
.with::<DTypeSession>()
3031
.with::<ArraySession>()
3132
.with::<KernelSession>()
33+
.with::<AggregateFnSession>()
3234
.build()
3335
});
3436
}

0 commit comments

Comments
 (0)