Commit 322e5f6

Dandandan and claude committed
fix: propagate column statistics through CAST in join key expressions
When join keys contain CAST expressions (e.g. CAST(id AS Float64)), the cardinality estimator could not extract column statistics because it only handled plain Column references. This caused unknown stats, leading to poor join ordering (e.g. putting a 1.4M-row fact table on the hash join build side instead of a 5-row dimension table).

Extract the underlying column index through numeric CAST expressions, since casting can only reduce (never increase) distinct count, making the source column's stats a valid upper bound.

TPC-DS Q99: 10.4s → ~60ms.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
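The upper-bound argument in the commit message can be checked with a small standalone sketch. It is not part of the commit and uses only the Rust standard library; the helper names `distinct_i64` and `distinct_after_cast_to_f64` are hypothetical. A cast is a deterministic function of its input, so two equal inputs always produce equal outputs, and the distinct count can stay the same or shrink but never grow.

```rust
use std::collections::HashSet;

/// Number of distinct values in the original integer column.
fn distinct_i64(values: &[i64]) -> usize {
    values.iter().copied().collect::<HashSet<i64>>().len()
}

/// Number of distinct values after the equivalent of `CAST(v AS Float64)`.
/// `f64` does not implement `Hash`, so values are compared by bit pattern
/// (safe here because an integer cast never yields NaN or -0.0).
fn distinct_after_cast_to_f64(values: &[i64]) -> usize {
    values
        .iter()
        .map(|v| (*v as f64).to_bits())
        .collect::<HashSet<u64>>()
        .len()
}

fn main() {
    // Small ids survive the cast unchanged: the count stays the same.
    let ids = [1_i64, 2, 3, 3];
    assert_eq!(distinct_i64(&ids), 3);
    assert_eq!(distinct_after_cast_to_f64(&ids), 3);

    // 2^53 and 2^53 + 1 round to the same f64: the count shrinks.
    let big = [9_007_199_254_740_992_i64, 9_007_199_254_740_993];
    assert_eq!(distinct_i64(&big), 2);
    assert_eq!(distinct_after_cast_to_f64(&big), 1);
}
```

In both cases the source column's distinct count is at least the count after the cast, which is the property the estimator relies on.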
1 parent 5ba06ac commit 322e5f6


1 file changed: 28 additions, 7 deletions

datafusion/physical-plan/src/joins/utils.rs

@@ -460,6 +460,30 @@ pub(crate) fn estimate_join_statistics(
     })
 }
 
+/// Extract the column index from a join key expression for statistics lookup.
+/// Handles plain `Column` references and `CAST(column AS numeric_type)`
+/// expressions. Casting can only merge values (many-to-one), never split
+/// them, so the source column's distinct count is always a valid upper
+/// bound for the cast result's distinct count.
+fn column_index_for_stats(expr: &Arc<dyn PhysicalExpr>) -> Option<usize> {
+    use arrow::datatypes::DataType;
+
+    if let Some(col) = expr.as_any().downcast_ref::<Column>() {
+        return Some(col.index());
+    }
+    if let Some(cast) = expr
+        .as_any()
+        .downcast_ref::<datafusion_physical_expr::expressions::CastExpr>()
+        && let Some(col) = cast.expr.as_any().downcast_ref::<Column>()
+    {
+        let target = cast.cast_type();
+        if target.is_numeric() {
+            return Some(col.index());
+        }
+    }
+    None
+}
+
 // Estimate the cardinality for the given join with input statistics.
 fn estimate_join_cardinality(
     join_type: &JoinType,
@@ -470,13 +494,10 @@ fn estimate_join_cardinality(
             let (left_col_stats, right_col_stats) = on
                 .iter()
                 .map(|(left, right)| {
-                    match (
-                        left.as_any().downcast_ref::<Column>(),
-                        right.as_any().downcast_ref::<Column>(),
-                    ) {
-                        (Some(left), Some(right)) => (
-                            left_stats.column_statistics[left.index()].clone(),
-                            right_stats.column_statistics[right.index()].clone(),
+                    match (column_index_for_stats(left), column_index_for_stats(right)) {
+                        (Some(left_idx), Some(right_idx)) => (
+                            left_stats.column_statistics[left_idx].clone(),
+                            right_stats.column_statistics[right_idx].clone(),
                         ),
                         _ => (
                             ColumnStatistics::new_unknown(),
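
The lookup logic in this diff can be modelled without any DataFusion types. The following is a simplified sketch, not the actual implementation: the `Expr` enum, the `Option<usize>` distinct counts, and `distinct_counts_for_key` are all hypothetical stand-ins for `Arc<dyn PhysicalExpr>`, `ColumnStatistics`, and the closure body above. It also assumes the fallback arm yields unknown statistics for both sides, which the truncated hunk suggests but does not fully show.

```rust
/// Hypothetical, simplified stand-in for a physical join key expression.
enum Expr {
    Column { index: usize },
    Cast { input: Box<Expr>, to_numeric: bool },
    Other,
}

/// Mirrors the shape of `column_index_for_stats`: accept a plain column,
/// or a numeric cast directly wrapping a column; reject everything else.
fn column_index_for_stats(expr: &Expr) -> Option<usize> {
    match expr {
        Expr::Column { index } => Some(*index),
        Expr::Cast { input, to_numeric: true } => match &**input {
            Expr::Column { index } => Some(*index),
            _ => None,
        },
        _ => None,
    }
}

/// Look up per-side distinct counts for one join key pair.
/// `None` plays the role of unknown statistics (assumed fallback behavior).
fn distinct_counts_for_key(
    left: &Expr,
    right: &Expr,
    left_ndv: &[Option<usize>],
    right_ndv: &[Option<usize>],
) -> (Option<usize>, Option<usize>) {
    match (column_index_for_stats(left), column_index_for_stats(right)) {
        (Some(l), Some(r)) => (left_ndv[l], right_ndv[r]),
        _ => (None, None),
    }
}

fn main() {
    let left_ndv = [Some(1_400_000)]; // fact table key
    let right_ndv = [Some(5)]; // dimension table key

    // CAST(col AS numeric) on one side still resolves to the source column.
    let cast_key = Expr::Cast { input: Box::new(Expr::Column { index: 0 }), to_numeric: true };
    let plain_key = Expr::Column { index: 0 };
    assert_eq!(
        distinct_counts_for_key(&cast_key, &plain_key, &left_ndv, &right_ndv),
        (Some(1_400_000), Some(5))
    );

    // Any unsupported expression makes both sides unknown.
    assert_eq!(
        distinct_counts_for_key(&Expr::Other, &plain_key, &left_ndv, &right_ndv),
        (None, None)
    );
}
```

Note that, as in the diff, a non-numeric cast or a cast wrapping anything other than a bare column is deliberately not resolved.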
