Skip to content

Commit 97e4f7a

Browse files
Vedinrampage644
authored andcommitted
Fixes after sync to 50.0.0
1 parent 61d8483 commit 97e4f7a

18 files changed

Lines changed: 1773 additions & 57 deletions

File tree

datafusion/core/src/execution/session_state.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,9 @@ impl Session for SessionState {
292292
}
293293

294294
impl SessionState {
295-
pub(crate) fn resolve_table_ref(
295+
/// Resolve a [`TableReference`] into a [`ResolvedTableReference`] using
296+
/// the session's configured default catalog and schema.
297+
pub fn resolve_table_ref(
296298
&self,
297299
table_ref: impl Into<TableReference>,
298300
) -> ResolvedTableReference {

datafusion/core/src/physical_planner.rs

Lines changed: 230 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ use datafusion_expr::expr_rewriter::unnormalize_cols;
8686
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
8787
use datafusion_expr::utils::{expr_to_columns, split_conjunction};
8888
use datafusion_expr::{
89-
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
90-
FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan,
91-
WindowFrame, WindowFrameBound, WriteOp,
89+
Analyze, BinaryExpr, Cast, DescribeTable, DmlStatement, Explain, ExplainFormat,
90+
Extension, FetchType, Filter, JoinType, LogicalPlanBuilder, Operator, RecursiveQuery,
91+
SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
9292
};
9393
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
94-
use datafusion_physical_expr::expressions::Literal;
94+
use datafusion_physical_expr::expressions::{Column as PhysicalColumn, Literal};
9595
use datafusion_physical_expr::{
9696
LexOrdering, PhysicalSortExpr, create_physical_sort_exprs,
9797
};
@@ -106,6 +106,7 @@ use datafusion_physical_plan::unnest::ListUnnest;
106106

107107
use async_trait::async_trait;
108108
use datafusion_physical_plan::async_func::{AsyncFuncExec, AsyncMapper};
109+
use datafusion_physical_plan::collect;
109110
use futures::{StreamExt, TryStreamExt};
110111
use itertools::{Itertools, multiunzip};
111112
use log::debug;
@@ -1256,6 +1257,56 @@ impl DefaultPhysicalPlanner {
12561257
)?)
12571258
}
12581259

1260+
LogicalPlan::Pivot(pivot) => {
1261+
return if !pivot.pivot_values.is_empty() {
1262+
let agg_plan = transform_pivot_to_aggregate(
1263+
Arc::new(pivot.input.as_ref().clone()),
1264+
&pivot.aggregate_expr,
1265+
&pivot.pivot_column,
1266+
pivot.pivot_values.clone(),
1267+
pivot.default_on_null_expr.as_ref(),
1268+
)?;
1269+
self.create_physical_plan(&agg_plan, session_state).await
1270+
} else if let Some(subquery) = &pivot.value_subquery {
1271+
let optimized_subquery = session_state.optimize(subquery.as_ref())?;
1272+
let subquery_physical_plan = self
1273+
.create_physical_plan(&optimized_subquery, session_state)
1274+
.await?;
1275+
let subquery_results = collect(
1276+
Arc::clone(&subquery_physical_plan),
1277+
session_state.task_ctx(),
1278+
)
1279+
.await?;
1280+
1281+
let mut pivot_values = Vec::new();
1282+
for batch in subquery_results.iter() {
1283+
if batch.num_columns() != 1 {
1284+
return plan_err!(
1285+
"Pivot subquery must return a single column"
1286+
);
1287+
}
1288+
let column = batch.column(0);
1289+
for row_idx in 0..batch.num_rows() {
1290+
if !column.is_null(row_idx) {
1291+
pivot_values
1292+
.push(ScalarValue::try_from_array(column, row_idx)?);
1293+
}
1294+
}
1295+
}
1296+
1297+
let agg_plan = transform_pivot_to_aggregate(
1298+
Arc::new(pivot.input.as_ref().clone()),
1299+
&pivot.aggregate_expr,
1300+
&pivot.pivot_column,
1301+
pivot_values,
1302+
pivot.default_on_null_expr.as_ref(),
1303+
)?;
1304+
self.create_physical_plan(&agg_plan, session_state).await
1305+
} else {
1306+
plan_err!("PIVOT operation requires at least one value to pivot on")
1307+
};
1308+
}
1309+
12591310
// 2 Children
12601311
LogicalPlan::Join(Join {
12611312
left: original_left,
@@ -2143,7 +2194,8 @@ fn extract_dml_filters(
21432194
| LogicalPlan::Ddl(_)
21442195
| LogicalPlan::Copy(_)
21452196
| LogicalPlan::Unnest(_)
2146-
| LogicalPlan::RecursiveQuery(_) => {
2197+
| LogicalPlan::RecursiveQuery(_)
2198+
| LogicalPlan::Pivot(_) => {
21472199
// No filters to extract from leaf/meta plans
21482200
}
21492201
// Plans with inputs (may contain filters in children)
@@ -2499,6 +2551,147 @@ pub fn create_aggregate_expr_and_maybe_filter(
24992551
)
25002552
}
25012553

2554+
/// Transform a PIVOT operation into a more standard Aggregate + Projection plan
2555+
/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions
2556+
///
2557+
/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
2558+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
2559+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
2560+
///
2561+
/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that
2562+
/// applies COALESCE to the results.
2563+
pub fn transform_pivot_to_aggregate(
2564+
input: Arc<LogicalPlan>,
2565+
aggregate_expr: &Expr,
2566+
pivot_column: &Column,
2567+
pivot_values: Vec<ScalarValue>,
2568+
default_on_null_expr: Option<&Expr>,
2569+
) -> Result<LogicalPlan> {
2570+
let df_schema = input.schema();
2571+
2572+
let all_columns: Vec<Column> = df_schema.columns();
2573+
2574+
// Filter to include only columns we want for GROUP BY
2575+
// (exclude pivot column and aggregate expression columns)
2576+
let group_by_columns: Vec<Expr> = all_columns
2577+
.into_iter()
2578+
.filter(|col: &Column| {
2579+
col.name != pivot_column.name
2580+
&& !aggregate_expr
2581+
.column_refs()
2582+
.iter()
2583+
.any(|agg_col| agg_col.name == col.name)
2584+
})
2585+
.map(|col: Column| Expr::Column(col))
2586+
.collect();
2587+
2588+
let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(Arc::clone(&input)));
2589+
2590+
// Create the aggregate plan with filtered aggregates
2591+
let mut aggregate_exprs = Vec::new();
2592+
2593+
let input_schema = input.schema();
2594+
let pivot_col_idx = match input_schema.index_of_column(pivot_column) {
2595+
Ok(idx) => idx,
2596+
Err(_) => {
2597+
return plan_err!(
2598+
"Pivot column '{}' does not exist in input schema",
2599+
pivot_column
2600+
);
2601+
}
2602+
};
2603+
let pivot_col_type = input_schema.field(pivot_col_idx).data_type();
2604+
2605+
for value in &pivot_values {
2606+
let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
2607+
Box::new(Expr::Column(pivot_column.clone())),
2608+
Operator::IsNotDistinctFrom,
2609+
Box::new(Expr::Cast(Cast::new(
2610+
Box::new(Expr::Literal(value.clone(), None)),
2611+
pivot_col_type.clone(),
2612+
))),
2613+
));
2614+
2615+
let filtered_agg = match aggregate_expr {
2616+
Expr::AggregateFunction(agg) => {
2617+
let mut new_params = agg.params.clone();
2618+
new_params.filter = Some(Box::new(filter_condition));
2619+
Expr::AggregateFunction(AggregateFunction {
2620+
func: Arc::clone(&agg.func),
2621+
params: new_params,
2622+
})
2623+
}
2624+
_ => {
2625+
return plan_err!(
2626+
"Unsupported aggregate expression should always be AggregateFunction"
2627+
);
2628+
}
2629+
};
2630+
2631+
// Use the pivot value as the column name
2632+
let field_name = value.to_string().trim_matches('\'').to_string();
2633+
let aliased_agg = Expr::Alias(Alias {
2634+
expr: Box::new(filtered_agg),
2635+
relation: None,
2636+
name: field_name,
2637+
metadata: None,
2638+
});
2639+
2640+
aggregate_exprs.push(aliased_agg);
2641+
}
2642+
2643+
// Create the plan with the aggregate
2644+
let aggregate_plan = builder
2645+
.aggregate(group_by_columns, aggregate_exprs)?
2646+
.build()?;
2647+
2648+
// If DEFAULT ON NULL is specified, add a projection to apply COALESCE
2649+
if let Some(default_expr) = default_on_null_expr {
2650+
let schema = aggregate_plan.schema();
2651+
let mut projection_exprs = Vec::new();
2652+
2653+
for field in schema.fields() {
2654+
if !pivot_values
2655+
.iter()
2656+
.any(|v| field.name() == v.to_string().trim_matches('\''))
2657+
{
2658+
projection_exprs.push(Expr::Column(Column::from_name(field.name())));
2659+
}
2660+
}
2661+
2662+
// Apply COALESCE to aggregate columns
2663+
for value in &pivot_values {
2664+
let field_name = value.to_string().trim_matches('\'').to_string();
2665+
let aggregate_col = Expr::Column(Column::from_name(&field_name));
2666+
2667+
// Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END
2668+
let coalesce_expr = Expr::Case(datafusion_expr::expr::Case {
2669+
expr: None,
2670+
when_then_expr: vec![(
2671+
Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))),
2672+
Box::new(default_expr.clone()),
2673+
)],
2674+
else_expr: Some(Box::new(aggregate_col)),
2675+
});
2676+
2677+
let aliased_coalesce = Expr::Alias(Alias {
2678+
expr: Box::new(coalesce_expr),
2679+
relation: None,
2680+
name: field_name,
2681+
metadata: None,
2682+
});
2683+
2684+
projection_exprs.push(aliased_coalesce);
2685+
}
2686+
2687+
// Apply the projection
2688+
LogicalPlanBuilder::from(aggregate_plan)
2689+
.project(projection_exprs)?
2690+
.build()
2691+
} else {
2692+
Ok(aggregate_plan)
2693+
}
2694+
}
25022695
impl DefaultPhysicalPlanner {
25032696
/// Handles capturing the various plans for EXPLAIN queries
25042697
///
@@ -2881,6 +3074,38 @@ impl DefaultPhysicalPlanner {
28813074
.collect::<Result<Vec<_>>>()?;
28823075

28833076
let num_input_columns = input_exec.schema().fields().len();
3077+
// When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
3078+
if let LogicalPlan::Pivot(pivot) = input.as_ref() {
3079+
if pivot.value_subquery.is_some()
3080+
&& input_exec
3081+
.as_any()
3082+
.downcast_ref::<AggregateExec>()
3083+
.is_some()
3084+
{
3085+
let agg_exec =
3086+
input_exec.as_any().downcast_ref::<AggregateExec>().unwrap();
3087+
let schema = input_exec.schema();
3088+
let group_by_len = agg_exec.group_expr().expr().len();
3089+
3090+
if group_by_len < schema.fields().len() {
3091+
let mut all_exprs = physical_exprs.clone();
3092+
3093+
for (i, field) in
3094+
schema.fields().iter().enumerate().skip(group_by_len)
3095+
{
3096+
if !physical_exprs.iter().any(|(_, name)| name == field.name()) {
3097+
all_exprs.push((
3098+
Arc::new(PhysicalColumn::new(field.name(), i))
3099+
as Arc<dyn PhysicalExpr>,
3100+
field.name().clone(),
3101+
));
3102+
}
3103+
}
3104+
3105+
return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?));
3106+
}
3107+
}
3108+
}
28843109

28853110
match self.try_plan_async_exprs(
28863111
num_input_columns,

datafusion/expr/src/logical_plan/display.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,12 @@ impl<'a, 'b> PgJsonVisitor<'a, 'b> {
313313
"Is Distinct": is_distinct,
314314
})
315315
}
316+
LogicalPlan::Pivot(pivot) => json!({
317+
"Node Type": "Pivot",
318+
"Aggregate Expr": format!("{}", pivot.aggregate_expr),
319+
"Pivot Column": format!("{}", pivot.pivot_column),
320+
"Pivot Values": pivot.pivot_values.iter().map(|v| v.to_string()).collect::<Vec<_>>(),
321+
}),
316322
LogicalPlan::Values(Values { values, .. }) => {
317323
let str_values = values
318324
.iter()
@@ -706,8 +712,11 @@ impl<'n> TreeNodeVisitor<'n> for PgJsonVisitor<'_, '_> {
706712

707713
#[cfg(test)]
708714
mod tests {
709-
use arrow::datatypes::{DataType, Field};
715+
use crate::EmptyRelation;
716+
use arrow::datatypes::{DataType, Field, Schema};
717+
use datafusion_common::{Column, DFSchema, ScalarValue};
710718
use insta::assert_snapshot;
719+
use std::sync::Arc;
711720

712721
use super::*;
713722

datafusion/expr/src/logical_plan/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ pub use dml::{DmlStatement, WriteOp};
4040
pub use plan::{
4141
Aggregate, Analyze, ColumnUnnestList, DescribeTable, Distinct, DistinctOn,
4242
EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, Join,
43-
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection,
44-
RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery,
43+
JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, Pivot, PlanType,
44+
Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery,
4545
SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window,
4646
projection_schema,
4747
};

0 commit comments

Comments
 (0)