Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 110 additions & 56 deletions datafusion/physical-plan/src/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub use crate::joins::{JoinOn, JoinOnRef};
use arrow::array::{
Array, ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray,
RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array,
builder::UInt64Builder, downcast_array, new_null_array,
builder::UInt64Builder, downcast_array, make_comparator, new_null_array,
};
use arrow::array::{
ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array,
Expand All @@ -53,13 +53,11 @@ use arrow::array::{
TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array,
};
use arrow::buffer::{BooleanBuffer, NullBuffer};
use arrow::compute::kernels::cmp::eq;
use arrow::compute::{self, FilterBuilder, and, take};
use arrow::compute::{self, take};
use arrow::datatypes::{
ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type,
};
use arrow_ord::cmp::not_distinct;
use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit};
use arrow_schema::{DataType, SortOptions, TimeUnit};
use datafusion_common::cast::as_boolean_array;
use datafusion_common::hash_utils::RandomState;
use datafusion_common::hash_utils::create_hashes;
Expand All @@ -68,7 +66,6 @@ use datafusion_common::{
DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult,
not_impl_err, plan_err,
};
use datafusion_expr::Operator;
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::utils::collect_columns;
Expand All @@ -77,7 +74,6 @@ use datafusion_physical_expr::{
add_offset_to_physical_sort_exprs,
};

use datafusion_physical_expr_common::datum::compare_op_for_nested;
use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays;
use futures::future::{BoxFuture, Shared};
use futures::{FutureExt, ready};
Expand Down Expand Up @@ -1767,59 +1763,63 @@ pub(super) fn equal_rows_arr(
right_arrays: &[ArrayRef],
null_equality: NullEquality,
) -> Result<(UInt64Array, UInt32Array)> {
let mut iter = left_arrays.iter().zip(right_arrays.iter());

let Some((first_left, first_right)) = iter.next() else {
return Ok((Vec::<u64>::new().into(), Vec::<u32>::new().into()));
};

let arr_left = take(first_left.as_ref(), indices_left, None)?;
let arr_right = take(first_right.as_ref(), indices_right, None)?;

let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?;

// Use map and try_fold to iterate over the remaining pairs of arrays.
// In each iteration, take is used on the pair of arrays and their equality is determined.
// The results are then folded (combined) using the and function to get a final equality result.
equal = iter
.map(|(left, right)| {
let arr_left = take(left.as_ref(), indices_left, None)?;
let arr_right = take(right.as_ref(), indices_right, None)?;
eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality)
})
.try_fold(equal, |acc, equal2| and(&acc, &equal2?))?;

let filter_builder = FilterBuilder::new(&equal).optimize().build();
let num_indices = indices_left.len();
if num_indices == 0 || left_arrays.is_empty() {
return Ok((
UInt64Array::from(Vec::<u64>::new()),
UInt32Array::from(Vec::<u32>::new()),
));
}

let left_filtered = filter_builder.filter(indices_left)?;
let right_filtered = filter_builder.filter(indices_right)?;
let mut comparators = Vec::with_capacity(left_arrays.len());
for (left, right) in left_arrays.iter().zip(right_arrays.iter()) {
comparators.push(make_comparator(
left.as_ref(),
right.as_ref(),
SortOptions::default(),
)?);
}

Ok((
downcast_array(left_filtered.as_ref()),
downcast_array(right_filtered.as_ref()),
))
}
let mut left_builder = UInt64Builder::with_capacity(num_indices);
let mut right_builder = UInt32Builder::with_capacity(num_indices);

for i in 0..num_indices {
let left_idx = indices_left.value(i) as usize;
let right_idx = indices_right.value(i) as usize;

let mut is_equal = true;
for (col_idx, cmp) in comparators.iter().enumerate() {
let left_arr = left_arrays.get(col_idx).unwrap();
let right_arr = right_arrays.get(col_idx).unwrap();
let left_null = left_arr.data_type() == &DataType::Null || left_arr.is_null(left_idx);
let right_null = right_arr.data_type() == &DataType::Null || right_arr.is_null(right_idx);

match (null_equality, left_null, right_null) {
(NullEquality::NullEqualsNull, true, true) => continue, // Nulls match
(NullEquality::NullEqualsNothing, true, _) | (NullEquality::NullEqualsNothing, _, true) => {
is_equal = false; // Nulls never match
break;
}
(_, true, false) | (_, false, true) => {
is_equal = false; // Different null states
break;
}
(_, false, false) => {
if cmp(left_idx, right_idx) != Ordering::Equal {
is_equal = false;
break;
}
}
}
}

// version of eq_dyn supporting equality on null arrays
fn eq_dyn_null(
left: &dyn Array,
right: &dyn Array,
null_equality: NullEquality,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() {
let op = match null_equality {
NullEquality::NullEqualsNothing => Operator::Eq,
NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom,
};
return Ok(compare_op_for_nested(op, &left, &right)?);
}
match null_equality {
NullEquality::NullEqualsNothing => eq(&left, &right),
NullEquality::NullEqualsNull => not_distinct(&left, &right),
if is_equal {
left_builder.append_value(indices_left.value(i));
right_builder.append_value(indices_right.value(i));
}
}

Ok((left_builder.finish(), right_builder.finish()))
}

/// Get comparison result of two rows of join arrays
Expand Down Expand Up @@ -2949,4 +2949,58 @@ mod tests {
let result = max_distinct_count(&num_rows, &stats);
assert_eq!(result, Exact(0));
}

/// Exercises `equal_rows_arr` on a single `Int32` key column: only index pairs
/// whose referenced values match are expected back, and a NULL/NULL pair is
/// expected back only under `NullEquality::NullEqualsNull`.
#[test]
fn test_equal_rows_arr() -> Result<()> {
    // Scenario 1: no NULLs. Positions 0..=3 hold identical values, while
    // position 4 differs (5 on the left, 6 on the right).
    let lhs_columns = [Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef];
    let rhs_columns = [Arc::new(Int32Array::from(vec![1, 2, 3, 4, 6])) as ArrayRef];
    let lhs_rows = UInt64Array::from(vec![0, 1, 2, 3, 4]);
    let rhs_rows = UInt32Array::from(vec![0, 1, 2, 3, 4]);

    let (kept_lhs, kept_rhs) = equal_rows_arr(
        &lhs_rows,
        &rhs_rows,
        &lhs_columns,
        &rhs_columns,
        NullEquality::NullEqualsNothing,
    )?;
    assert_eq!(kept_lhs, UInt64Array::from(vec![0, 1, 2, 3]));
    assert_eq!(kept_rhs, UInt32Array::from(vec![0, 1, 2, 3]));

    // Scenario 2: position 1 is NULL on both sides, position 2 differs (3 vs 4).
    let lhs_columns =
        [Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef];
    let rhs_columns =
        [Arc::new(Int32Array::from(vec![Some(1), None, Some(4)])) as ArrayRef];
    let lhs_rows = UInt64Array::from(vec![0, 1, 2]);
    let rhs_rows = UInt32Array::from(vec![0, 1, 2]);

    // Each entry: the null-equality mode and the index pairs expected to survive.
    let expectations = [
        // NULL != NULL, so only position 0 matches.
        (NullEquality::NullEqualsNothing, vec![0_u64], vec![0_u32]),
        // NULL == NULL, so position 1 matches as well.
        (NullEquality::NullEqualsNull, vec![0_u64, 1], vec![0_u32, 1]),
    ];

    for (null_equality, want_lhs, want_rhs) in expectations {
        let (kept_lhs, kept_rhs) = equal_rows_arr(
            &lhs_rows,
            &rhs_rows,
            &lhs_columns,
            &rhs_columns,
            null_equality,
        )?;
        assert_eq!(kept_lhs, UInt64Array::from(want_lhs));
        assert_eq!(kept_rhs, UInt32Array::from(want_rhs));
    }

    Ok(())
}
}
12 changes: 12 additions & 0 deletions test_null_join.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use datafusion::prelude::*;
use datafusion_common::Result;
use std::sync::Arc;

/// Runs an inner join between two single-row subqueries whose join keys are
/// both `null`, and prints the resulting rows.
#[tokio::main]
async fn main() -> Result<()> {
    // Both sides expose one row with a literal `null` join key.
    const QUERY: &str = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2";

    let session = SessionContext::new();
    // Plan the query, then execute it and print the result.
    session.sql(QUERY).await?.show().await?;
    Ok(())
}
Loading