diff --git a/datafusion/physical-plan/src/joins/utils.rs b/datafusion/physical-plan/src/joins/utils.rs index d3c8ccc11bcb9..ea95c1a10bd2d 100644 --- a/datafusion/physical-plan/src/joins/utils.rs +++ b/datafusion/physical-plan/src/joins/utils.rs @@ -43,7 +43,7 @@ pub use crate::joins::{JoinOn, JoinOnRef}; use arrow::array::{ Array, ArrowPrimitiveType, BooleanBufferBuilder, NativeAdapter, PrimitiveArray, RecordBatch, RecordBatchOptions, UInt32Array, UInt32Builder, UInt64Array, - builder::UInt64Builder, downcast_array, new_null_array, + builder::UInt64Builder, downcast_array, make_comparator, new_null_array, }; use arrow::array::{ ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, @@ -53,13 +53,11 @@ use arrow::array::{ TimestampNanosecondArray, TimestampSecondArray, UInt8Array, UInt16Array, }; use arrow::buffer::{BooleanBuffer, NullBuffer}; -use arrow::compute::kernels::cmp::eq; -use arrow::compute::{self, FilterBuilder, and, take}; +use arrow::compute::{self, take}; use arrow::datatypes::{ ArrowNativeType, Field, Schema, SchemaBuilder, UInt32Type, UInt64Type, }; -use arrow_ord::cmp::not_distinct; -use arrow_schema::{ArrowError, DataType, SortOptions, TimeUnit}; +use arrow_schema::{DataType, SortOptions, TimeUnit}; use datafusion_common::cast::as_boolean_array; use datafusion_common::hash_utils::RandomState; use datafusion_common::hash_utils::create_hashes; @@ -68,7 +66,6 @@ use datafusion_common::{ DataFusionError, JoinSide, JoinType, NullEquality, Result, SharedResult, not_impl_err, plan_err, }; -use datafusion_expr::Operator; use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; @@ -77,7 +74,6 @@ use datafusion_physical_expr::{ add_offset_to_physical_sort_exprs, }; -use datafusion_physical_expr_common::datum::compare_op_for_nested; use datafusion_physical_expr_common::utils::evaluate_expressions_to_arrays; use futures::future::{BoxFuture, Shared}; use futures::{FutureExt, ready}; @@ -1767,59 +1763,63 @@ pub(super) fn equal_rows_arr( right_arrays: &[ArrayRef], null_equality: NullEquality, ) -> Result<(UInt64Array, UInt32Array)> { - let mut iter = left_arrays.iter().zip(right_arrays.iter()); - - let Some((first_left, first_right)) = iter.next() else { - return Ok((Vec::::new().into(), Vec::::new().into())); - }; - - let arr_left = take(first_left.as_ref(), indices_left, None)?; - let arr_right = take(first_right.as_ref(), indices_right, None)?; - - let mut equal: BooleanArray = eq_dyn_null(&arr_left, &arr_right, null_equality)?; - - // Use map and try_fold to iterate over the remaining pairs of arrays. - // In each iteration, take is used on the pair of arrays and their equality is determined. - // The results are then folded (combined) using the and function to get a final equality result. - equal = iter - .map(|(left, right)| { - let arr_left = take(left.as_ref(), indices_left, None)?; - let arr_right = take(right.as_ref(), indices_right, None)?; - eq_dyn_null(arr_left.as_ref(), arr_right.as_ref(), null_equality) - }) - .try_fold(equal, |acc, equal2| and(&acc, &equal2?))?; - - let filter_builder = FilterBuilder::new(&equal).optimize().build(); + let num_indices = indices_left.len(); + if num_indices == 0 || left_arrays.is_empty() { + return Ok(( + UInt64Array::from(Vec::::new()), + UInt32Array::from(Vec::::new()), + )); + } - let left_filtered = filter_builder.filter(indices_left)?; - let right_filtered = filter_builder.filter(indices_right)?; + let mut comparators = Vec::with_capacity(left_arrays.len()); + for (left, right) in left_arrays.iter().zip(right_arrays.iter()) { + comparators.push(make_comparator( + left.as_ref(), + right.as_ref(), + SortOptions::default(), + )?); + } - Ok(( - downcast_array(left_filtered.as_ref()), - downcast_array(right_filtered.as_ref()), - )) -} + let mut left_builder = UInt64Builder::with_capacity(num_indices); + let mut right_builder = UInt32Builder::with_capacity(num_indices); + + for i in 0..num_indices { + let left_idx = indices_left.value(i) as usize; + let right_idx = indices_right.value(i) as usize; + + let mut is_equal = true; + for (col_idx, cmp) in comparators.iter().enumerate() { + let left_arr = left_arrays.get(col_idx).unwrap(); + let right_arr = right_arrays.get(col_idx).unwrap(); + let left_null = left_arr.data_type() == &DataType::Null || left_arr.is_null(left_idx); + let right_null = right_arr.data_type() == &DataType::Null || right_arr.is_null(right_idx); + + match (null_equality, left_null, right_null) { + (NullEquality::NullEqualsNull, true, true) => continue, // Nulls match + (NullEquality::NullEqualsNothing, true, _) | (NullEquality::NullEqualsNothing, _, true) => { + is_equal = false; // Nulls never match + break; + } + (_, true, false) | (_, false, true) => { + is_equal = false; // Different null states + break; + } + (_, false, false) => { + if cmp(left_idx, right_idx) != Ordering::Equal { + is_equal = false; + break; + } + } + } + } -// version of eq_dyn supporting equality on null arrays -fn eq_dyn_null( - left: &dyn Array, - right: &dyn Array, - null_equality: NullEquality, -) -> Result { - // Nested datatypes cannot use the underlying not_distinct/eq function and must use a special - // implementation - // - if left.data_type().is_nested() { - let op = match null_equality { - NullEquality::NullEqualsNothing => Operator::Eq, - NullEquality::NullEqualsNull => Operator::IsNotDistinctFrom, - }; - return Ok(compare_op_for_nested(op, &left, &right)?); - } - match null_equality { - NullEquality::NullEqualsNothing => eq(&left, &right), - NullEquality::NullEqualsNull => not_distinct(&left, &right), + if is_equal { + left_builder.append_value(indices_left.value(i)); + right_builder.append_value(indices_right.value(i)); + } } + + Ok((left_builder.finish(), right_builder.finish())) } /// Get comparison result of two rows of join arrays @@ -2949,4 +2949,58 @@ mod tests { let result = max_distinct_count(&num_rows, &stats); assert_eq!(result, Exact(0)); } + + #[test] + fn test_equal_rows_arr() -> Result<()> { + let left_col = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef; + let right_col = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 6])) as ArrayRef; + + let indices_left = UInt64Array::from(vec![0, 1, 2, 3, 4]); + let indices_right = UInt32Array::from(vec![0, 1, 2, 3, 4]); + + // Test NullEqualsNothing + let (res_left, res_right) = equal_rows_arr( + &indices_left, + &indices_right, + &[Arc::clone(&left_col)], + &[Arc::clone(&right_col)], + NullEquality::NullEqualsNothing, + )?; + + assert_eq!(res_left, UInt64Array::from(vec![0, 1, 2, 3])); + assert_eq!(res_right, UInt32Array::from(vec![0, 1, 2, 3])); + + // Test with NULLs + let left_col = + Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef; + let right_col = + Arc::new(Int32Array::from(vec![Some(1), None, Some(4)])) as ArrayRef; + + let indices_left = UInt64Array::from(vec![0, 1, 2]); + let indices_right = UInt32Array::from(vec![0, 1, 2]); + + // NullEqualsNothing: NULL != NULL + let (res_left, res_right) = equal_rows_arr( + &indices_left, + &indices_right, + &[Arc::clone(&left_col)], + &[Arc::clone(&right_col)], + NullEquality::NullEqualsNothing, + )?; + assert_eq!(res_left, UInt64Array::from(vec![0])); + assert_eq!(res_right, UInt32Array::from(vec![0])); + + // NullEqualsNull: NULL == NULL + let (res_left, res_right) = equal_rows_arr( + &indices_left, + &indices_right, + &[Arc::clone(&left_col)], + &[Arc::clone(&right_col)], + NullEquality::NullEqualsNull, + )?; + assert_eq!(res_left, UInt64Array::from(vec![0, 1])); + assert_eq!(res_right, UInt32Array::from(vec![0, 1])); + + Ok(()) + } } diff --git a/test_null_join.rs b/test_null_join.rs new file mode 100644 index 0000000000000..8b2bb77254bdd --- /dev/null +++ b/test_null_join.rs @@ -0,0 +1,12 @@ +use datafusion::prelude::*; +use datafusion_common::Result; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = SessionContext::new(); + let sql = "SELECT * FROM (SELECT null AS id1) t1 INNER JOIN (SELECT null AS id2) t2 ON id1 = id2"; + let df = ctx.sql(sql).await?; + df.show().await?; + Ok(()) +}