Skip to content

Commit aee81f4

Browse files
committed
Polishing
1 parent 260dd47 commit aee81f4

10 files changed

Lines changed: 87 additions & 204 deletions

File tree

datafusion/core/src/execution/session_state.rs

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -562,28 +562,6 @@ impl SessionState {
562562

563563
/// Optimizes the logical plan by applying optimizer rules.
564564
pub fn optimize(&self, plan: &LogicalPlan) -> datafusion_common::Result<LogicalPlan> {
565-
// Special case for Pivot nodes with subqueries
566-
/*if let LogicalPlan::Pivot(p) = plan {
567-
if let Some(subquery) = &p.value_subquery {
568-
// Optimize the subquery first
569-
let optimized_subquery = self.optimizer.optimize(
570-
subquery.as_ref().clone(),
571-
self,
572-
|_, _| {},
573-
)?;
574-
575-
// Create a new Pivot with the optimized subquery
576-
return Ok(LogicalPlan::Pivot(datafusion_expr::Pivot {
577-
input: p.input.clone(),
578-
aggregate_expr: p.aggregate_expr.clone(),
579-
pivot_column: p.pivot_column.clone(),
580-
pivot_values: p.pivot_values.clone(),
581-
schema: p.schema.clone(),
582-
value_subquery: Some(Arc::new(optimized_subquery)),
583-
}));
584-
}
585-
}*/
586-
587565
if let LogicalPlan::Explain(e) = plan {
588566
let mut stringified_plans = e.stringified_plans.clone();
589567

@@ -673,9 +651,7 @@ impl SessionState {
673651
&self,
674652
logical_plan: &LogicalPlan,
675653
) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
676-
println!("logical_plan_beore_optPIVOT328: {:#?}", logical_plan);
677654
let logical_plan = self.optimize(logical_plan)?;
678-
println!("logical_plan_PIVOT329: {:#?}", logical_plan);
679655
self.query_planner
680656
.create_physical_plan(&logical_plan, self)
681657
.await

datafusion/core/src/physical_planner.rs

Lines changed: 78 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ use arrow::datatypes::{Schema, SchemaRef};
6464
use datafusion_common::display::ToStringifiedPlan;
6565
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor};
6666
use datafusion_common::{
67-
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, DFSchemaRef,
68-
ScalarValue, Column, TableReference,
67+
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema,
68+
ScalarValue, Column,
6969
};
7070
use datafusion_datasource::memory::MemorySourceConfig;
7171
use datafusion_expr::dml::{CopyTo, InsertOp, DmlStatement, WriteOp};
@@ -78,7 +78,7 @@ use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessar
7878
use datafusion_expr::{
7979
Analyze, DescribeTable, DmlStatement,Explain, ExplainFormat, Extension, FetchType,
8080
Filter, JoinType, RecursiveQuery, SkipType, SortExpr, StringifiedPlan, WindowFrame,
81-
WindowFrameBound, WriteOp, SubqueryAlias,
81+
WindowFrameBound, WriteOp, SubqueryAlias, LogicalPlanBuilder, BinaryExpr
8282
};
8383
use datafusion_execution::FunctionRegistry;
8484
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
@@ -97,8 +97,7 @@ use itertools::{multiunzip, Itertools};
9797
use log::{debug, trace};
9898
use sqlparser::ast::NullTreatment;
9999
use tokio::sync::Mutex;
100-
101-
use datafusion_sql::transform_pivot_to_aggregate;
100+
use datafusion_expr_common::operator::Operator;
102101

103102
use datafusion_physical_plan::collect;
104103

@@ -923,20 +922,17 @@ impl DefaultPhysicalPlanner {
923922
pivot.pivot_values.clone()
924923
};
925924

926-
if !pivot_values.is_empty() {
927-
// Transform Pivot into Aggregate plan with the resolved pivot values
925+
return if !pivot_values.is_empty() {
928926
let agg_plan = transform_pivot_to_aggregate(
929927
Arc::new(pivot.input.as_ref().clone()),
930928
&pivot.aggregate_expr,
931929
&pivot.pivot_column,
932-
Some(pivot_values),
933-
None,
930+
pivot_values,
934931
)?;
935932

936-
// The schema information is already preserved in the agg_plan
937-
return self.create_physical_plan(&agg_plan, session_state).await;
933+
self.create_physical_plan(&agg_plan, session_state).await
938934
} else {
939-
return plan_err!("PIVOT operation requires at least one value to pivot on");
935+
plan_err!("PIVOT operation requires at least one value to pivot on")
940936
}
941937
}
942938
// 2 Children
@@ -1734,6 +1730,76 @@ pub use datafusion_physical_expr::{
17341730
create_physical_sort_expr, create_physical_sort_exprs,
17351731
};
17361732

1733+
/// Transform a PIVOT operation into a more standard Aggregate + Projection plan
1734+
/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions
1735+
///
1736+
/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
1737+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
1738+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
1739+
///
1740+
pub fn transform_pivot_to_aggregate(
1741+
input: Arc<LogicalPlan>,
1742+
aggregate_expr: &Expr,
1743+
pivot_column: &Column,
1744+
pivot_values: Vec<ScalarValue>,
1745+
) -> Result<LogicalPlan> {
1746+
let df_schema = input.schema();
1747+
1748+
let all_columns: Vec<Column> = df_schema.columns();
1749+
1750+
// Filter to include only columns we want for GROUP BY
1751+
// (exclude pivot column and aggregate expression columns)
1752+
let group_by_columns: Vec<Expr> = all_columns
1753+
.into_iter()
1754+
.filter(|col| {
1755+
col.name != pivot_column.name
1756+
&& !aggregate_expr.column_refs().iter().any(|agg_col| agg_col.name == col.name)
1757+
})
1758+
.map(|col| Expr::Column(col))
1759+
.collect();
1760+
1761+
let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(input.clone()));
1762+
1763+
let mut aggregate_exprs = Vec::new();
1764+
1765+
for value in &pivot_values {
1766+
let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
1767+
Box::new(Expr::Column(pivot_column.clone())),
1768+
Operator::IsNotDistinctFrom,
1769+
Box::new(Expr::Literal(value.clone()))
1770+
));
1771+
1772+
let filtered_agg = match aggregate_expr {
1773+
Expr::AggregateFunction(agg) => {
1774+
let mut new_params = agg.params.clone();
1775+
new_params.filter = Some(Box::new(filter_condition));
1776+
Expr::AggregateFunction(AggregateFunction {
1777+
func: agg.func.clone(),
1778+
params: new_params,
1779+
})
1780+
},
1781+
_ => {
1782+
return plan_err!("Unsupported aggregate expression should always be AggregateFunction");
1783+
}
1784+
};
1785+
1786+
// Use the pivot value as the column name
1787+
let field_name = value.to_string().trim_matches('\'').to_string();
1788+
let aliased_agg = Expr::Alias(Alias {
1789+
expr: Box::new(filtered_agg),
1790+
relation: None,
1791+
name: field_name,
1792+
metadata: None,
1793+
});
1794+
1795+
aggregate_exprs.push(aliased_agg);
1796+
}
1797+
1798+
let aggregate_plan = builder.aggregate(group_by_columns, aggregate_exprs)?.build()?;
1799+
1800+
Ok(aggregate_plan)
1801+
}
1802+
17371803
impl DefaultPhysicalPlanner {
17381804
/// Handles capturing the various plans for EXPLAIN queries
17391805
///

datafusion/expr/src/logical_plan/display.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use crate::{
2424
expr_vec_fmt, Aggregate, DescribeTable, Distinct, DistinctOn, DmlStatement, Expr,
2525
Filter, Join, Limit, LogicalPlan, Partitioning, Projection, RecursiveQuery,
2626
Repartition, Sort, Subquery, SubqueryAlias, TableProviderFilterPushDown, TableScan,
27-
Unnest, Values, Window, Pivot,
27+
Unnest, Values, Window,
2828
};
2929

3030
use crate::dml::CopyTo;

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
use std::cmp::Ordering;
2121
use std::collections::{HashMap, HashSet};
2222
use std::fmt::{self, Debug, Display, Formatter};
23-
use std::fs::metadata;
2423
use std::hash::{Hash, Hasher};
2524
use std::str::FromStr;
2625
use std::sync::{Arc, LazyLock};
@@ -1179,10 +1178,10 @@ impl LogicalPlan {
11791178
Ok(new_plan)
11801179
}
11811180
LogicalPlan::Pivot(Pivot {
1182-
aggregate_expr,
1181+
aggregate_expr: _,
11831182
pivot_column,
11841183
pivot_values,
1185-
schema,
1184+
schema: _,
11861185
value_subquery,
11871186
..
11881187
}) => {
@@ -2345,7 +2344,6 @@ impl Pivot {
23452344
pivot_column: Column,
23462345
value_subquery: Arc<LogicalPlan>,
23472346
) -> Result<Self> {
2348-
// Create an empty schema - will be filled in when the subquery is executed
23492347
let schema = pivot_schema_without_values(
23502348
input.schema(),
23512349
&aggregate_expr,
@@ -2356,15 +2354,13 @@ impl Pivot {
23562354
input,
23572355
aggregate_expr,
23582356
pivot_column,
2359-
pivot_values: Vec::new(), // Will be populated during physical planning
2357+
pivot_values: Vec::new(),
23602358
schema: Arc::new(schema),
23612359
value_subquery: Some(value_subquery),
23622360
})
23632361
}
23642362
}
23652363

2366-
/// Create a pivot schema without knowing the pivot values
2367-
/// This is used when we have a subquery for pivot values
23682364
fn pivot_schema_without_values(
23692365
input_schema: &DFSchemaRef,
23702366
aggregate_expr: &Expr,
@@ -2387,7 +2383,6 @@ fn pivot_schema_without_values(
23872383
DFSchema::new_with_metadata(fields_with_table_ref, input_schema.metadata().clone())
23882384
}
23892385

2390-
/// Create a pivot schema with known pivot values
23912386
fn pivot_schema(
23922387
input_schema: &DFSchemaRef,
23932388
aggregate_expr: &Expr,
@@ -2396,14 +2391,12 @@ fn pivot_schema(
23962391
) -> Result<DFSchema> {
23972392
let mut fields = vec![];
23982393

2399-
// Include all fields except pivot and value columns
24002394
for field in input_schema.fields() {
24012395
if !aggregate_expr.column_refs().iter().any(|col| col.name() == field.name()) && field.name() != pivot_column.name() {
24022396
fields.push(field.clone());
24032397
}
24042398
}
24052399

2406-
// Add new fields for each pivot value
24072400
for pivot_value in pivot_values {
24082401
let field_name = format!("{}", pivot_value);
24092402
let data_type = aggregate_expr.get_type(input_schema)?;
@@ -5558,4 +5551,3 @@ digraph {
55585551
Ok(())
55595552
}
55605553
}
5561-

datafusion/optimizer/src/optimize_projections/mod.rs

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -382,23 +382,8 @@ fn optimize_projections(
382382
dependency_indices.clone(),
383383
)]
384384
}
385-
LogicalPlan::Pivot(pivot) => {
386-
return Ok(Transformed::no(plan)); // TODO: Implement this
387-
/*
388-
// For PIVOT operations, we need columns from the aggregate expression and the pivot column
389-
let mut pivot_indices = Vec::new();
390-
391-
// Add required indices for the pivot column
392-
if let Ok(idx) = pivot.input.schema().index_of_column(&pivot.pivot_column) {
393-
pivot_indices.push(idx);
394-
}
395-
396-
// Create RequiredIndices with these indices and add dependency from aggregate expression
397-
let required = RequiredIndices::new_from_indices(pivot_indices)
398-
.with_exprs(pivot.input.schema(), std::iter::once(&pivot.aggregate_expr))
399-
.with_projection_beneficial();
400-
401-
vec![required]*/
385+
LogicalPlan::Pivot(_) => {
386+
return Ok(Transformed::no(plan));
402387
},
403388

404389
};

datafusion/proto/src/logical_plan/mod.rs

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,19 +1838,6 @@ impl AsLogicalPlan for LogicalPlanNode {
18381838
})
18391839
}
18401840
LogicalPlan::Pivot(_) => {
1841-
/*let input =
1842-
LogicalPlanNode::try_from_logical_plan(pivot.input.as_ref(), extension_codec)?;
1843-
Ok(LogicalPlanNode {
1844-
logical_plan_type: Some(LogicalPlanType::Pivot(Box::new(
1845-
protobuf::PivotNode {
1846-
input: Some(Box::new(input)),
1847-
aggregate_expr: pivot.aggregate_expr.clone(),
1848-
pivot_column: pivot.pivot_column.clone(),
1849-
pivot_values: pivot.pivot_values.clone(),
1850-
schema: convert_required!(*pivot.schema)?,
1851-
},
1852-
))),
1853-
})*/
18541841
Err(proto_error(
18551842
"LogicalPlan serde is not yet implemented for Statement",
18561843
))

datafusion/sql/src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,3 @@ mod values;
6262
)]
6363
pub use datafusion_common::{ResolvedTableReference, TableReference};
6464
pub use sqlparser;
65-
66-
// Re-export the transform_pivot_to_aggregate function
67-
pub use relation::transform_pivot_to_aggregate;

0 commit comments

Comments
 (0)