Skip to content

Commit a016ed6

Browse files
adriangb and LiaCastaneda
authored and committed
Refactor HashJoinExec to progressively accumulate dynamic filter bounds instead of computing them after data is accumulated (apache#17444)
(cherry picked from commit 5b833b9)
1 parent 013d4ad commit a016ed6

3 files changed

Lines changed: 159 additions & 34 deletions

File tree

Cargo.lock

Lines changed: 0 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/physical-plan/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ datafusion-common = { workspace = true, default-features = true }
5353
datafusion-common-runtime = { workspace = true, default-features = true }
5454
datafusion-execution = { workspace = true }
5555
datafusion-expr = { workspace = true }
56-
datafusion-functions-aggregate-common = { workspace = true }
56+
datafusion-functions-aggregate = { workspace = true }
5757
datafusion-functions-window-common = { workspace = true }
5858
datafusion-physical-expr = { workspace = true, default-features = true }
5959
datafusion-physical-expr-common = { workspace = true }

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 158 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ use arrow::datatypes::{Schema, SchemaRef};
7272
use arrow::error::ArrowError;
7373
use arrow::record_batch::RecordBatch;
7474
use arrow::util::bit_util;
75+
use arrow_schema::DataType;
7576
use datafusion_common::config::ConfigOptions;
7677
use datafusion_common::utils::memory::estimate_memory_size;
7778
use datafusion_common::{
@@ -80,8 +81,9 @@ use datafusion_common::{
8081
};
8182
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
8283
use datafusion_execution::TaskContext;
84+
use datafusion_expr::Accumulator;
8385
use datafusion_expr::Operator;
84-
use datafusion_functions_aggregate_common::min_max::{max_batch, min_batch};
86+
use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator};
8587
use datafusion_physical_expr::equivalence::{
8688
join_equivalence_properties, ProjectionMapping,
8789
};
@@ -1430,29 +1432,123 @@ impl ExecutionPlan for HashJoinExec {
14301432
}
14311433
}
14321434

1433-
/// Compute min/max bounds for each column in the given arrays
1434-
fn compute_bounds(arrays: &[ArrayRef]) -> Result<Vec<ColumnBounds>> {
1435-
arrays
1436-
.iter()
1437-
.map(|array| {
1438-
if array.is_empty() {
1439-
// Return NULL values for empty arrays
1440-
return Ok(ColumnBounds::new(
1441-
ScalarValue::try_from(array.data_type())?,
1442-
ScalarValue::try_from(array.data_type())?,
1443-
));
1435+
/// Accumulator for collecting min/max bounds from build-side data during hash join.
1436+
///
1437+
/// This struct encapsulates the logic for progressively computing column bounds
1438+
/// (minimum and maximum values) for a specific join key expression as batches
1439+
/// are processed during the build phase of a hash join.
1440+
///
1441+
/// The bounds are used for dynamic filter pushdown optimization, where filters
1442+
/// based on the actual data ranges can be pushed down to the probe side to
1443+
/// eliminate unnecessary data early.
1444+
struct CollectLeftAccumulator {
1445+
/// The physical expression to evaluate for each batch
1446+
expr: Arc<dyn PhysicalExpr>,
1447+
/// Accumulator for tracking the minimum value across all batches
1448+
min: MinAccumulator,
1449+
/// Accumulator for tracking the maximum value across all batches
1450+
max: MaxAccumulator,
1451+
}
1452+
1453+
impl CollectLeftAccumulator {
1454+
/// Creates a new accumulator for tracking bounds of a join key expression.
1455+
///
1456+
/// # Arguments
1457+
/// * `expr` - The physical expression to track bounds for
1458+
/// * `schema` - The schema of the input data
1459+
///
1460+
/// # Returns
1461+
/// A new `CollectLeftAccumulator` instance configured for the expression's data type
1462+
fn try_new(expr: Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> Result<Self> {
1463+
/// Recursively unwraps dictionary types to get the underlying value type.
1464+
fn dictionary_value_type(data_type: &DataType) -> DataType {
1465+
match data_type {
1466+
DataType::Dictionary(_, value_type) => {
1467+
dictionary_value_type(value_type.as_ref())
1468+
}
1469+
_ => data_type.clone(),
14441470
}
1471+
}
1472+
1473+
let data_type = expr
1474+
.data_type(schema)
1475+
// Min/Max can operate on dictionary data but expect to be initialized with the underlying value type
1476+
.map(|dt| dictionary_value_type(&dt))?;
1477+
Ok(Self {
1478+
expr,
1479+
min: MinAccumulator::try_new(&data_type)?,
1480+
max: MaxAccumulator::try_new(&data_type)?,
1481+
})
1482+
}
14451483

1446-
// Use Arrow kernels for efficient min/max computation
1447-
let min_val = min_batch(array)?;
1448-
let max_val = max_batch(array)?;
1484+
/// Updates the accumulators with values from a new batch.
1485+
///
1486+
/// Evaluates the expression on the batch and updates both min and max
1487+
/// accumulators with the resulting values.
1488+
///
1489+
/// # Arguments
1490+
/// * `batch` - The record batch to process
1491+
///
1492+
/// # Returns
1493+
/// Ok(()) if the update succeeds, or an error if expression evaluation fails
1494+
fn update_batch(&mut self, batch: &RecordBatch) -> Result<()> {
1495+
let array = self.expr.evaluate(batch)?.into_array(batch.num_rows())?;
1496+
self.min.update_batch(std::slice::from_ref(&array))?;
1497+
self.max.update_batch(std::slice::from_ref(&array))?;
1498+
Ok(())
1499+
}
14491500

1450-
Ok(ColumnBounds::new(min_val, max_val))
1501+
/// Finalizes the accumulation and returns the computed bounds.
1502+
///
1503+
/// Consumes self to extract the final min and max values from the accumulators.
1504+
///
1505+
/// # Returns
1506+
/// The `ColumnBounds` containing the minimum and maximum values observed
1507+
fn evaluate(mut self) -> Result<ColumnBounds> {
1508+
Ok(ColumnBounds::new(
1509+
self.min.evaluate()?,
1510+
self.max.evaluate()?,
1511+
))
1512+
}
1513+
}
1514+
1515+
/// State for collecting the build-side data during hash join
1516+
struct BuildSideState {
1517+
batches: Vec<RecordBatch>,
1518+
num_rows: usize,
1519+
metrics: BuildProbeJoinMetrics,
1520+
reservation: MemoryReservation,
1521+
bounds_accumulators: Option<Vec<CollectLeftAccumulator>>,
1522+
}
1523+
1524+
impl BuildSideState {
1525+
/// Create a new BuildSideState with optional accumulators for bounds computation
1526+
fn try_new(
1527+
metrics: BuildProbeJoinMetrics,
1528+
reservation: MemoryReservation,
1529+
on_left: Vec<Arc<dyn PhysicalExpr>>,
1530+
schema: &SchemaRef,
1531+
should_compute_bounds: bool,
1532+
) -> Result<Self> {
1533+
Ok(Self {
1534+
batches: Vec::new(),
1535+
num_rows: 0,
1536+
metrics,
1537+
reservation,
1538+
bounds_accumulators: should_compute_bounds
1539+
.then(|| {
1540+
on_left
1541+
.iter()
1542+
.map(|expr| {
1543+
CollectLeftAccumulator::try_new(Arc::clone(expr), schema)
1544+
})
1545+
.collect::<Result<Vec<_>>>()
1546+
})
1547+
.transpose()?,
14511548
})
1452-
.collect()
1549+
}
14531550
}
14541551

1455-
#[expect(clippy::too_many_arguments)]
14561552
/// Collects all batches from the left (build) side stream and creates a hash map for joining.
14571553
///
14581554
/// This function is responsible for:
@@ -1481,6 +1577,7 @@ fn compute_bounds(arrays: &[ArrayRef]) -> Result<Vec<ColumnBounds>> {
14811577
/// # Returns
14821578
/// `JoinLeftData` containing the hash map, consolidated batch, join key values,
14831579
/// visited indices bitmap, and computed bounds (if requested).
1580+
#[allow(clippy::too_many_arguments)]
14841581
async fn collect_left_input(
14851582
random_state: RandomState,
14861583
left_stream: SendableRecordBatchStream,
@@ -1496,24 +1593,48 @@ async fn collect_left_input(
14961593
// This operation performs 2 steps at once:
14971594
// 1. creates a [JoinHashMap] of all batches from the stream
14981595
// 2. stores the batches in a vector.
1499-
let initial = (Vec::new(), 0, metrics, reservation);
1500-
let (batches, num_rows, metrics, mut reservation) = left_stream
1501-
.try_fold(initial, |mut acc, batch| async {
1596+
let initial = BuildSideState::try_new(
1597+
metrics,
1598+
reservation,
1599+
on_left.clone(),
1600+
&schema,
1601+
should_compute_bounds,
1602+
)?;
1603+
1604+
let state = left_stream
1605+
.try_fold(initial, |mut state, batch| async move {
1606+
// Update accumulators if computing bounds
1607+
if let Some(ref mut accumulators) = state.bounds_accumulators {
1608+
for accumulator in accumulators {
1609+
accumulator.update_batch(&batch)?;
1610+
}
1611+
}
1612+
1613+
// Decide if we spill or not
15021614
let batch_size = get_record_batch_memory_size(&batch);
15031615
// Reserve memory for incoming batch
1504-
acc.3.try_grow(batch_size)?;
1616+
state.reservation.try_grow(batch_size)?;
15051617
// Update metrics
1506-
acc.2.build_mem_used.add(batch_size);
1507-
acc.2.build_input_batches.add(1);
1508-
acc.2.build_input_rows.add(batch.num_rows());
1618+
state.metrics.build_mem_used.add(batch_size);
1619+
state.metrics.build_input_batches.add(1);
1620+
state.metrics.build_input_rows.add(batch.num_rows());
15091621
// Update row count
1510-
acc.1 += batch.num_rows();
1622+
state.num_rows += batch.num_rows();
15111623
// Push batch to output
1512-
acc.0.push(batch);
1513-
Ok(acc)
1624+
state.batches.push(batch);
1625+
Ok(state)
15141626
})
15151627
.await?;
15161628

1629+
// Extract fields from state
1630+
let BuildSideState {
1631+
batches,
1632+
num_rows,
1633+
metrics,
1634+
mut reservation,
1635+
bounds_accumulators,
1636+
} = state;
1637+
15171638
// Estimation of memory size, required for hashtable, prior to allocation.
15181639
// Final result can be verified using `RawTable.allocation_info()`
15191640
let fixed_size_u32 = size_of::<JoinHashMapU32>();
@@ -1580,10 +1701,15 @@ async fn collect_left_input(
15801701
.collect::<Result<Vec<_>>>()?;
15811702

15821703
// Compute bounds for dynamic filter if enabled
1583-
let bounds = if should_compute_bounds && num_rows > 0 {
1584-
Some(compute_bounds(&left_values)?)
1585-
} else {
1586-
None
1704+
let bounds = match bounds_accumulators {
1705+
Some(accumulators) if num_rows > 0 => {
1706+
let bounds = accumulators
1707+
.into_iter()
1708+
.map(CollectLeftAccumulator::evaluate)
1709+
.collect::<Result<Vec<_>>>()?;
1710+
Some(bounds)
1711+
}
1712+
_ => None,
15871713
};
15881714

15891715
let data = JoinLeftData::new(

0 commit comments

Comments (0)