Skip to content

Commit 7b01907

Browse files
committed
Add default on null. Cargo fmt
1 parent 8187ae2 commit 7b01907

11 files changed

Lines changed: 307 additions & 236 deletions

File tree

datafusion/core/src/physical_planner.rs

Lines changed: 115 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ use datafusion_expr::expr::{
7676
use datafusion_expr::expr_rewriter::unnormalize_cols;
7777
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
7878
use datafusion_expr::{
79-
Analyze, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType,
80-
Filter, JoinType, RecursiveQuery, SkipType, StringifiedPlan, WindowFrame,
81-
WindowFrameBound, WriteOp, LogicalPlanBuilder, BinaryExpr
79+
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
80+
FetchType, Filter, JoinType, LogicalPlanBuilder, RecursiveQuery, SkipType,
81+
StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
8282
};
8383
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
8484
use datafusion_physical_expr::expressions::{Column, Literal};
@@ -91,12 +91,12 @@ use datafusion_physical_plan::unnest::ListUnnest;
9191
use crate::schema_equivalence::schema_satisfied_by;
9292
use async_trait::async_trait;
9393
use datafusion_datasource::file_groups::FileGroup;
94+
use datafusion_expr_common::operator::Operator;
9495
use futures::{StreamExt, TryStreamExt};
9596
use itertools::{multiunzip, Itertools};
9697
use log::{debug, trace};
9798
use sqlparser::ast::NullTreatment;
9899
use tokio::sync::Mutex;
99-
use datafusion_expr_common::operator::Operator;
100100

101101
use datafusion_physical_plan::collect;
102102

@@ -891,42 +891,50 @@ impl DefaultPhysicalPlanner {
891891
))
892892
}
893893
LogicalPlan::Pivot(pivot) => {
894-
let pivot_values = if let Some(subquery) = &pivot.value_subquery {
894+
return if !pivot.pivot_values.is_empty() {
895+
let agg_plan = transform_pivot_to_aggregate(
896+
Arc::new(pivot.input.as_ref().clone()),
897+
&pivot.aggregate_expr,
898+
&pivot.pivot_column,
899+
pivot.pivot_values.clone(),
900+
pivot.default_on_null_expr.as_ref(),
901+
)?;
902+
903+
self.create_physical_plan(&agg_plan, session_state).await
904+
} else if let Some(subquery) = &pivot.value_subquery {
895905
let optimized_subquery = session_state.optimize(subquery.as_ref())?;
896906

897-
let subquery_physical_plan = self.create_physical_plan(
898-
&optimized_subquery,
899-
session_state
900-
).await?;
907+
let subquery_physical_plan = self
908+
.create_physical_plan(&optimized_subquery, session_state)
909+
.await?;
901910

902-
let subquery_results = collect(subquery_physical_plan.clone(), session_state.task_ctx()).await?;
911+
let subquery_results =
912+
collect(subquery_physical_plan.clone(), session_state.task_ctx())
913+
.await?;
903914

904915
let mut pivot_values = Vec::new();
905916
for batch in subquery_results.iter() {
906917
if batch.num_columns() != 1 {
907-
return plan_err!("Pivot subquery must return a single column");
918+
return plan_err!(
919+
"Pivot subquery must return a single column"
920+
);
908921
}
909922

910923
let column = batch.column(0);
911924
for row_idx in 0..batch.num_rows() {
912925
if !column.is_null(row_idx) {
913-
pivot_values.push(
914-
ScalarValue::try_from_array(column, row_idx)?
915-
);
926+
pivot_values
927+
.push(ScalarValue::try_from_array(column, row_idx)?);
916928
}
917929
}
918930
}
919-
pivot_values
920-
} else {
921-
pivot.pivot_values.clone()
922-
};
923931

924-
return if !pivot_values.is_empty() {
925932
let agg_plan = transform_pivot_to_aggregate(
926933
Arc::new(pivot.input.as_ref().clone()),
927934
&pivot.aggregate_expr,
928935
&pivot.pivot_column,
929936
pivot_values,
937+
pivot.default_on_null_expr.as_ref(),
930938
)?;
931939

932940
self.create_physical_plan(&agg_plan, session_state).await
@@ -1736,11 +1744,14 @@ pub use datafusion_physical_expr::{
17361744
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
17371745
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
17381746
///
1747+
/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that
1748+
/// applies COALESCE to the results.
17391749
pub fn transform_pivot_to_aggregate(
17401750
input: Arc<LogicalPlan>,
17411751
aggregate_expr: &Expr,
17421752
pivot_column: &datafusion_common::Column,
17431753
pivot_values: Vec<ScalarValue>,
1754+
default_on_null_expr: Option<&Expr>,
17441755
) -> Result<LogicalPlan> {
17451756
let df_schema = input.schema();
17461757

@@ -1750,22 +1761,26 @@ pub fn transform_pivot_to_aggregate(
17501761
// (exclude pivot column and aggregate expression columns)
17511762
let group_by_columns: Vec<Expr> = all_columns
17521763
.into_iter()
1753-
.filter(|col: &datafusion_common::Column | {
1764+
.filter(|col: &datafusion_common::Column| {
17541765
col.name != pivot_column.name
1755-
&& !aggregate_expr.column_refs().iter().any(|agg_col| agg_col.name == col.name)
1766+
&& !aggregate_expr
1767+
.column_refs()
1768+
.iter()
1769+
.any(|agg_col| agg_col.name == col.name)
17561770
})
1757-
.map(|col: datafusion_common::Column | Expr::Column(col))
1771+
.map(|col: datafusion_common::Column| Expr::Column(col))
17581772
.collect();
17591773

17601774
let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(input.clone()));
17611775

1776+
// Create the aggregate plan with filtered aggregates
17621777
let mut aggregate_exprs = Vec::new();
17631778

17641779
for value in &pivot_values {
17651780
let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
17661781
Box::new(Expr::Column(pivot_column.clone())),
17671782
Operator::IsNotDistinctFrom,
1768-
Box::new(Expr::Literal(value.clone()))
1783+
Box::new(Expr::Literal(value.clone())),
17691784
));
17701785

17711786
let filtered_agg = match aggregate_expr {
@@ -1776,9 +1791,11 @@ pub fn transform_pivot_to_aggregate(
17761791
func: agg.func.clone(),
17771792
params: new_params,
17781793
})
1779-
},
1794+
}
17801795
_ => {
1781-
return plan_err!("Unsupported aggregate expression should always be AggregateFunction");
1796+
return plan_err!(
1797+
"Unsupported aggregate expression should always be AggregateFunction"
1798+
);
17821799
}
17831800
};
17841801

@@ -1794,9 +1811,60 @@ pub fn transform_pivot_to_aggregate(
17941811
aggregate_exprs.push(aliased_agg);
17951812
}
17961813

1797-
let aggregate_plan = builder.aggregate(group_by_columns, aggregate_exprs)?.build()?;
1814+
// Create the plan with the aggregate
1815+
let aggregate_plan = builder
1816+
.aggregate(group_by_columns, aggregate_exprs)?
1817+
.build()?;
1818+
1819+
// If DEFAULT ON NULL is specified, add a projection to apply COALESCE
1820+
if let Some(default_expr) = default_on_null_expr {
1821+
let schema = aggregate_plan.schema();
1822+
let mut projection_exprs = Vec::new();
17981823

1799-
Ok(aggregate_plan)
1824+
for field in schema.fields() {
1825+
if !pivot_values
1826+
.iter()
1827+
.any(|v| field.name() == v.to_string().trim_matches('\''))
1828+
{
1829+
projection_exprs.push(Expr::Column(
1830+
datafusion_common::Column::from_name(field.name()),
1831+
));
1832+
}
1833+
}
1834+
1835+
// Apply COALESCE to aggregate columns
1836+
for value in &pivot_values {
1837+
let field_name = value.to_string().trim_matches('\'').to_string();
1838+
let aggregate_col =
1839+
Expr::Column(datafusion_common::Column::from_name(&field_name));
1840+
1841+
// Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END
1842+
let coalesce_expr = Expr::Case(datafusion_expr::expr::Case {
1843+
expr: None,
1844+
when_then_expr: vec![(
1845+
Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))),
1846+
Box::new(default_expr.clone()),
1847+
)],
1848+
else_expr: Some(Box::new(aggregate_col)),
1849+
});
1850+
1851+
let aliased_coalesce = Expr::Alias(Alias {
1852+
expr: Box::new(coalesce_expr),
1853+
relation: None,
1854+
name: field_name,
1855+
metadata: None,
1856+
});
1857+
1858+
projection_exprs.push(aliased_coalesce);
1859+
}
1860+
1861+
// Apply the projection
1862+
LogicalPlanBuilder::from(aggregate_plan)
1863+
.project(projection_exprs)?
1864+
.build()
1865+
} else {
1866+
Ok(aggregate_plan)
1867+
}
18001868
}
18011869

18021870
impl DefaultPhysicalPlanner {
@@ -2163,31 +2231,42 @@ impl DefaultPhysicalPlanner {
21632231
// When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
21642232
match input.as_ref() {
21652233
LogicalPlan::Pivot(pivot) => {
2166-
if pivot.value_subquery.is_some() && input_exec.as_any().downcast_ref::<AggregateExec>().is_some() {
2167-
let agg_exec = input_exec.as_any().downcast_ref::<AggregateExec>().unwrap();
2234+
if pivot.value_subquery.is_some()
2235+
&& input_exec
2236+
.as_any()
2237+
.downcast_ref::<AggregateExec>()
2238+
.is_some()
2239+
{
2240+
let agg_exec =
2241+
input_exec.as_any().downcast_ref::<AggregateExec>().unwrap();
21682242
let schema = input_exec.schema();
21692243
let group_by_len = agg_exec.group_expr().expr().len();
21702244

21712245
if group_by_len < schema.fields().len() {
21722246
let mut all_exprs = physical_exprs.clone();
21732247

2174-
for (i, field) in schema.fields().iter().enumerate().skip(group_by_len) {
2175-
if !physical_exprs.iter().any(|(_, name)| name == field.name()) {
2248+
for (i, field) in
2249+
schema.fields().iter().enumerate().skip(group_by_len)
2250+
{
2251+
if !physical_exprs
2252+
.iter()
2253+
.any(|(_, name)| name == field.name())
2254+
{
21762255
all_exprs.push((
2177-
Arc::new(Column::new(field.name(), i)) as Arc<dyn PhysicalExpr>,
2256+
Arc::new(Column::new(field.name(), i))
2257+
as Arc<dyn PhysicalExpr>,
21782258
field.name().clone(),
21792259
));
21802260
}
21812261
}
21822262

21832263
return Ok(Arc::new(ProjectionExec::try_new(
2184-
all_exprs,
2185-
input_exec,
2264+
all_exprs, input_exec,
21862265
)?));
21872266
}
21882267
}
2189-
},
2190-
_ => {}
2268+
}
2269+
_ => {}
21912270
}
21922271

21932272
Ok(Arc::new(ProjectionExec::try_new(

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ use crate::expr_rewriter::{
3232
};
3333
use crate::logical_plan::{
3434
Aggregate, Analyze, Distinct, DistinctOn, EmptyRelation, Explain, Filter, Join,
35-
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare,
35+
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType, Prepare,
3636
Projection, Repartition, Sort, SubqueryAlias, TableScan, Union, Unnest, Values,
37-
Window, Pivot,
37+
Window,
3838
};
3939
use crate::select_expr::SelectExpr;
4040
use crate::utils::{
@@ -1433,12 +1433,14 @@ impl LogicalPlanBuilder {
14331433
aggregate_expr: Expr,
14341434
pivot_column: Column,
14351435
pivot_values: Vec<ScalarValue>,
1436+
default_on_null: Option<Expr>,
14361437
) -> Result<Self> {
14371438
let pivot_plan = Pivot::try_new(
14381439
self.plan,
14391440
aggregate_expr,
14401441
pivot_column,
14411442
pivot_values,
1443+
default_on_null,
14421444
)?;
14431445
Ok(Self::new(LogicalPlan::Pivot(pivot_plan)))
14441446
}
@@ -2847,7 +2849,7 @@ mod tests {
28472849
Field::new("product", DataType::Utf8, false),
28482850
Field::new("sales", DataType::Int32, false),
28492851
]);
2850-
2852+
28512853
let plan = LogicalPlanBuilder::scan("sales", table_source(&schema), None)?
28522854
.pivot(
28532855
col("sales"),
@@ -2856,6 +2858,7 @@ mod tests {
28562858
ScalarValue::Utf8(Some("widget".to_string())),
28572859
ScalarValue::Utf8(Some("gadget".to_string())),
28582860
],
2861+
None,
28592862
)?
28602863
.build()?;
28612864

0 commit comments

Comments
 (0)