@@ -64,8 +64,8 @@ use arrow::datatypes::{Schema, SchemaRef};
6464use datafusion_common:: display:: ToStringifiedPlan ;
6565use datafusion_common:: tree_node:: { TreeNode , TreeNodeRecursion , TreeNodeVisitor } ;
6666use datafusion_common:: {
67- exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema , DFSchemaRef ,
68- ScalarValue , Column , TableReference ,
67+ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema ,
68+ ScalarValue , Column ,
6969} ;
7070use datafusion_datasource:: memory:: MemorySourceConfig ;
7171use datafusion_expr:: dml:: { CopyTo , InsertOp , DmlStatement , WriteOp } ;
@@ -78,7 +78,7 @@ use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessar
7878use datafusion_expr:: {
7979 Analyze , DescribeTable , DmlStatement , Explain , ExplainFormat , Extension , FetchType ,
8080 Filter , JoinType , RecursiveQuery , SkipType , SortExpr , StringifiedPlan , WindowFrame ,
81- WindowFrameBound , WriteOp , SubqueryAlias ,
81+ WindowFrameBound , WriteOp , SubqueryAlias , LogicalPlanBuilder , BinaryExpr
8282} ;
8383use datafusion_execution:: FunctionRegistry ;
8484use datafusion_physical_expr:: aggregate:: { AggregateExprBuilder , AggregateFunctionExpr } ;
@@ -97,8 +97,7 @@ use itertools::{multiunzip, Itertools};
9797use log:: { debug, trace} ;
9898use sqlparser:: ast:: NullTreatment ;
9999use tokio:: sync:: Mutex ;
100-
101- use datafusion_sql:: transform_pivot_to_aggregate;
100+ use datafusion_expr_common:: operator:: Operator ;
102101
103102use datafusion_physical_plan:: collect;
104103
@@ -923,20 +922,17 @@ impl DefaultPhysicalPlanner {
923922 pivot. pivot_values . clone ( )
924923 } ;
925924
926- if !pivot_values. is_empty ( ) {
927- // Transform Pivot into Aggregate plan with the resolved pivot values
925+ return if !pivot_values. is_empty ( ) {
928926 let agg_plan = transform_pivot_to_aggregate (
929927 Arc :: new ( pivot. input . as_ref ( ) . clone ( ) ) ,
930928 & pivot. aggregate_expr ,
931929 & pivot. pivot_column ,
932- Some ( pivot_values) ,
933- None ,
930+ pivot_values,
934931 ) ?;
935932
936- // The schema information is already preserved in the agg_plan
937- return self . create_physical_plan ( & agg_plan, session_state) . await ;
933+ self . create_physical_plan ( & agg_plan, session_state) . await
938934 } else {
939- return plan_err ! ( "PIVOT operation requires at least one value to pivot on" ) ;
935+ plan_err ! ( "PIVOT operation requires at least one value to pivot on" )
940936 }
941937 }
942938 // 2 Children
@@ -1734,6 +1730,76 @@ pub use datafusion_physical_expr::{
17341730 create_physical_sort_expr, create_physical_sort_exprs,
17351731} ;
17361732
1733+ /// Transform a PIVOT operation into a more standard Aggregate + Projection plan
1734+ /// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions
1735+ ///
1736+ /// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
1737+ /// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
1738+ /// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
1739+ ///
1740+ pub fn transform_pivot_to_aggregate (
1741+ input : Arc < LogicalPlan > ,
1742+ aggregate_expr : & Expr ,
1743+ pivot_column : & Column ,
1744+ pivot_values : Vec < ScalarValue > ,
1745+ ) -> Result < LogicalPlan > {
1746+ let df_schema = input. schema ( ) ;
1747+
1748+ let all_columns: Vec < Column > = df_schema. columns ( ) ;
1749+
1750+ // Filter to include only columns we want for GROUP BY
1751+ // (exclude pivot column and aggregate expression columns)
1752+ let group_by_columns: Vec < Expr > = all_columns
1753+ . into_iter ( )
1754+ . filter ( |col| {
1755+ col. name != pivot_column. name
1756+ && !aggregate_expr. column_refs ( ) . iter ( ) . any ( |agg_col| agg_col. name == col. name )
1757+ } )
1758+ . map ( |col| Expr :: Column ( col) )
1759+ . collect ( ) ;
1760+
1761+ let builder = LogicalPlanBuilder :: from ( Arc :: unwrap_or_clone ( input. clone ( ) ) ) ;
1762+
1763+ let mut aggregate_exprs = Vec :: new ( ) ;
1764+
1765+ for value in & pivot_values {
1766+ let filter_condition = Expr :: BinaryExpr ( BinaryExpr :: new (
1767+ Box :: new ( Expr :: Column ( pivot_column. clone ( ) ) ) ,
1768+ Operator :: IsNotDistinctFrom ,
1769+ Box :: new ( Expr :: Literal ( value. clone ( ) ) )
1770+ ) ) ;
1771+
1772+ let filtered_agg = match aggregate_expr {
1773+ Expr :: AggregateFunction ( agg) => {
1774+ let mut new_params = agg. params . clone ( ) ;
1775+ new_params. filter = Some ( Box :: new ( filter_condition) ) ;
1776+ Expr :: AggregateFunction ( AggregateFunction {
1777+ func : agg. func . clone ( ) ,
1778+ params : new_params,
1779+ } )
1780+ } ,
1781+ _ => {
1782+ return plan_err ! ( "Unsupported aggregate expression should always be AggregateFunction" ) ;
1783+ }
1784+ } ;
1785+
1786+ // Use the pivot value as the column name
1787+ let field_name = value. to_string ( ) . trim_matches ( '\'' ) . to_string ( ) ;
1788+ let aliased_agg = Expr :: Alias ( Alias {
1789+ expr : Box :: new ( filtered_agg) ,
1790+ relation : None ,
1791+ name : field_name,
1792+ metadata : None ,
1793+ } ) ;
1794+
1795+ aggregate_exprs. push ( aliased_agg) ;
1796+ }
1797+
1798+ let aggregate_plan = builder. aggregate ( group_by_columns, aggregate_exprs) ?. build ( ) ?;
1799+
1800+ Ok ( aggregate_plan)
1801+ }
1802+
17371803impl DefaultPhysicalPlanner {
17381804 /// Handles capturing the various plans for EXPLAIN queries
17391805 ///
0 commit comments