@@ -76,9 +76,9 @@ use datafusion_expr::expr::{
7676use datafusion_expr:: expr_rewriter:: unnormalize_cols;
7777use datafusion_expr:: logical_plan:: builder:: wrap_projection_for_join_if_necessary;
7878use datafusion_expr:: {
79- Analyze , DescribeTable , DmlStatement , Explain , ExplainFormat , Extension , FetchType ,
80- Filter , JoinType , RecursiveQuery , SkipType , StringifiedPlan , WindowFrame ,
81- WindowFrameBound , WriteOp , LogicalPlanBuilder , BinaryExpr
79+ Analyze , BinaryExpr , DescribeTable , DmlStatement , Explain , ExplainFormat , Extension ,
80+ FetchType , Filter , JoinType , LogicalPlanBuilder , RecursiveQuery , SkipType ,
81+ StringifiedPlan , WindowFrame , WindowFrameBound , WriteOp ,
8282} ;
8383use datafusion_physical_expr:: aggregate:: { AggregateExprBuilder , AggregateFunctionExpr } ;
8484use datafusion_physical_expr:: expressions:: { Column , Literal } ;
@@ -91,12 +91,12 @@ use datafusion_physical_plan::unnest::ListUnnest;
9191use crate :: schema_equivalence:: schema_satisfied_by;
9292use async_trait:: async_trait;
9393use datafusion_datasource:: file_groups:: FileGroup ;
94+ use datafusion_expr_common:: operator:: Operator ;
9495use futures:: { StreamExt , TryStreamExt } ;
9596use itertools:: { multiunzip, Itertools } ;
9697use log:: { debug, trace} ;
9798use sqlparser:: ast:: NullTreatment ;
9899use tokio:: sync:: Mutex ;
99- use datafusion_expr_common:: operator:: Operator ;
100100
101101use datafusion_physical_plan:: collect;
102102
@@ -891,42 +891,50 @@ impl DefaultPhysicalPlanner {
891891 ) )
892892 }
893893 LogicalPlan :: Pivot ( pivot) => {
894- let pivot_values = if let Some ( subquery) = & pivot. value_subquery {
894+ return if !pivot. pivot_values . is_empty ( ) {
895+ let agg_plan = transform_pivot_to_aggregate (
896+ Arc :: new ( pivot. input . as_ref ( ) . clone ( ) ) ,
897+ & pivot. aggregate_expr ,
898+ & pivot. pivot_column ,
899+ pivot. pivot_values . clone ( ) ,
900+ pivot. default_on_null_expr . as_ref ( ) ,
901+ ) ?;
902+
903+ self . create_physical_plan ( & agg_plan, session_state) . await
904+ } else if let Some ( subquery) = & pivot. value_subquery {
895905 let optimized_subquery = session_state. optimize ( subquery. as_ref ( ) ) ?;
896906
897- let subquery_physical_plan = self . create_physical_plan (
898- & optimized_subquery,
899- session_state
900- ) . await ?;
907+ let subquery_physical_plan = self
908+ . create_physical_plan ( & optimized_subquery, session_state)
909+ . await ?;
901910
902- let subquery_results = collect ( subquery_physical_plan. clone ( ) , session_state. task_ctx ( ) ) . await ?;
911+ let subquery_results =
912+ collect ( subquery_physical_plan. clone ( ) , session_state. task_ctx ( ) )
913+ . await ?;
903914
904915 let mut pivot_values = Vec :: new ( ) ;
905916 for batch in subquery_results. iter ( ) {
906917 if batch. num_columns ( ) != 1 {
907- return plan_err ! ( "Pivot subquery must return a single column" ) ;
918+ return plan_err ! (
919+ "Pivot subquery must return a single column"
920+ ) ;
908921 }
909922
910923 let column = batch. column ( 0 ) ;
911924 for row_idx in 0 ..batch. num_rows ( ) {
912925 if !column. is_null ( row_idx) {
913- pivot_values. push (
914- ScalarValue :: try_from_array ( column, row_idx) ?
915- ) ;
926+ pivot_values
927+ . push ( ScalarValue :: try_from_array ( column, row_idx) ?) ;
916928 }
917929 }
918930 }
919- pivot_values
920- } else {
921- pivot. pivot_values . clone ( )
922- } ;
923931
924- return if !pivot_values. is_empty ( ) {
925932 let agg_plan = transform_pivot_to_aggregate (
926933 Arc :: new ( pivot. input . as_ref ( ) . clone ( ) ) ,
927934 & pivot. aggregate_expr ,
928935 & pivot. pivot_column ,
929936 pivot_values,
937+ pivot. default_on_null_expr . as_ref ( ) ,
930938 ) ?;
931939
932940 self . create_physical_plan ( & agg_plan, session_state) . await
@@ -1736,11 +1744,14 @@ pub use datafusion_physical_expr::{
17361744/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
17371745/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
17381746///
1747+ /// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that
1748+ /// applies COALESCE to the results.
17391749pub fn transform_pivot_to_aggregate (
17401750 input : Arc < LogicalPlan > ,
17411751 aggregate_expr : & Expr ,
17421752 pivot_column : & datafusion_common:: Column ,
17431753 pivot_values : Vec < ScalarValue > ,
1754+ default_on_null_expr : Option < & Expr > ,
17441755) -> Result < LogicalPlan > {
17451756 let df_schema = input. schema ( ) ;
17461757
@@ -1750,22 +1761,26 @@ pub fn transform_pivot_to_aggregate(
17501761 // (exclude pivot column and aggregate expression columns)
17511762 let group_by_columns: Vec < Expr > = all_columns
17521763 . into_iter ( )
1753- . filter ( |col : & datafusion_common:: Column | {
1764+ . filter ( |col : & datafusion_common:: Column | {
17541765 col. name != pivot_column. name
1755- && !aggregate_expr. column_refs ( ) . iter ( ) . any ( |agg_col| agg_col. name == col. name )
1766+ && !aggregate_expr
1767+ . column_refs ( )
1768+ . iter ( )
1769+ . any ( |agg_col| agg_col. name == col. name )
17561770 } )
1757- . map ( |col : datafusion_common:: Column | Expr :: Column ( col) )
1771+ . map ( |col : datafusion_common:: Column | Expr :: Column ( col) )
17581772 . collect ( ) ;
17591773
17601774 let builder = LogicalPlanBuilder :: from ( Arc :: unwrap_or_clone ( input. clone ( ) ) ) ;
17611775
1776+ // Create the aggregate plan with filtered aggregates
17621777 let mut aggregate_exprs = Vec :: new ( ) ;
17631778
17641779 for value in & pivot_values {
17651780 let filter_condition = Expr :: BinaryExpr ( BinaryExpr :: new (
17661781 Box :: new ( Expr :: Column ( pivot_column. clone ( ) ) ) ,
17671782 Operator :: IsNotDistinctFrom ,
1768- Box :: new ( Expr :: Literal ( value. clone ( ) ) )
1783+ Box :: new ( Expr :: Literal ( value. clone ( ) ) ) ,
17691784 ) ) ;
17701785
17711786 let filtered_agg = match aggregate_expr {
@@ -1776,9 +1791,11 @@ pub fn transform_pivot_to_aggregate(
17761791 func : agg. func . clone ( ) ,
17771792 params : new_params,
17781793 } )
1779- } ,
1794+ }
17801795 _ => {
1781- return plan_err ! ( "Unsupported aggregate expression should always be AggregateFunction" ) ;
1796+ return plan_err ! (
1797+ "Unsupported aggregate expression should always be AggregateFunction"
1798+ ) ;
17821799 }
17831800 } ;
17841801
@@ -1794,9 +1811,60 @@ pub fn transform_pivot_to_aggregate(
17941811 aggregate_exprs. push ( aliased_agg) ;
17951812 }
17961813
1797- let aggregate_plan = builder. aggregate ( group_by_columns, aggregate_exprs) ?. build ( ) ?;
1814+ // Create the plan with the aggregate
1815+ let aggregate_plan = builder
1816+ . aggregate ( group_by_columns, aggregate_exprs) ?
1817+ . build ( ) ?;
1818+
1819+ // If DEFAULT ON NULL is specified, add a projection to apply COALESCE
1820+ if let Some ( default_expr) = default_on_null_expr {
1821+ let schema = aggregate_plan. schema ( ) ;
1822+ let mut projection_exprs = Vec :: new ( ) ;
17981823
1799- Ok ( aggregate_plan)
1824+ for field in schema. fields ( ) {
1825+ if !pivot_values
1826+ . iter ( )
1827+ . any ( |v| field. name ( ) == v. to_string ( ) . trim_matches ( '\'' ) )
1828+ {
1829+ projection_exprs. push ( Expr :: Column (
1830+ datafusion_common:: Column :: from_name ( field. name ( ) ) ,
1831+ ) ) ;
1832+ }
1833+ }
1834+
1835+ // Apply COALESCE to aggregate columns
1836+ for value in & pivot_values {
1837+ let field_name = value. to_string ( ) . trim_matches ( '\'' ) . to_string ( ) ;
1838+ let aggregate_col =
1839+ Expr :: Column ( datafusion_common:: Column :: from_name ( & field_name) ) ;
1840+
1841+ // Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END
1842+ let coalesce_expr = Expr :: Case ( datafusion_expr:: expr:: Case {
1843+ expr : None ,
1844+ when_then_expr : vec ! [ (
1845+ Box :: new( Expr :: IsNull ( Box :: new( aggregate_col. clone( ) ) ) ) ,
1846+ Box :: new( default_expr. clone( ) ) ,
1847+ ) ] ,
1848+ else_expr : Some ( Box :: new ( aggregate_col) ) ,
1849+ } ) ;
1850+
1851+ let aliased_coalesce = Expr :: Alias ( Alias {
1852+ expr : Box :: new ( coalesce_expr) ,
1853+ relation : None ,
1854+ name : field_name,
1855+ metadata : None ,
1856+ } ) ;
1857+
1858+ projection_exprs. push ( aliased_coalesce) ;
1859+ }
1860+
1861+ // Apply the projection
1862+ LogicalPlanBuilder :: from ( aggregate_plan)
1863+ . project ( projection_exprs) ?
1864+ . build ( )
1865+ } else {
1866+ Ok ( aggregate_plan)
1867+ }
18001868}
18011869
18021870impl DefaultPhysicalPlanner {
@@ -2163,31 +2231,42 @@ impl DefaultPhysicalPlanner {
21632231 // When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
21642232 match input. as_ref ( ) {
21652233 LogicalPlan :: Pivot ( pivot) => {
2166- if pivot. value_subquery . is_some ( ) && input_exec. as_any ( ) . downcast_ref :: < AggregateExec > ( ) . is_some ( ) {
2167- let agg_exec = input_exec. as_any ( ) . downcast_ref :: < AggregateExec > ( ) . unwrap ( ) ;
2234+ if pivot. value_subquery . is_some ( )
2235+ && input_exec
2236+ . as_any ( )
2237+ . downcast_ref :: < AggregateExec > ( )
2238+ . is_some ( )
2239+ {
2240+ let agg_exec =
2241+ input_exec. as_any ( ) . downcast_ref :: < AggregateExec > ( ) . unwrap ( ) ;
21682242 let schema = input_exec. schema ( ) ;
21692243 let group_by_len = agg_exec. group_expr ( ) . expr ( ) . len ( ) ;
21702244
21712245 if group_by_len < schema. fields ( ) . len ( ) {
21722246 let mut all_exprs = physical_exprs. clone ( ) ;
21732247
2174- for ( i, field) in schema. fields ( ) . iter ( ) . enumerate ( ) . skip ( group_by_len) {
2175- if !physical_exprs. iter ( ) . any ( |( _, name) | name == field. name ( ) ) {
2248+ for ( i, field) in
2249+ schema. fields ( ) . iter ( ) . enumerate ( ) . skip ( group_by_len)
2250+ {
2251+ if !physical_exprs
2252+ . iter ( )
2253+ . any ( |( _, name) | name == field. name ( ) )
2254+ {
21762255 all_exprs. push ( (
2177- Arc :: new ( Column :: new ( field. name ( ) , i) ) as Arc < dyn PhysicalExpr > ,
2256+ Arc :: new ( Column :: new ( field. name ( ) , i) )
2257+ as Arc < dyn PhysicalExpr > ,
21782258 field. name ( ) . clone ( ) ,
21792259 ) ) ;
21802260 }
21812261 }
21822262
21832263 return Ok ( Arc :: new ( ProjectionExec :: try_new (
2184- all_exprs,
2185- input_exec,
2264+ all_exprs, input_exec,
21862265 ) ?) ) ;
21872266 }
21882267 }
2189- } ,
2190- _ => { }
2268+ }
2269+ _ => { }
21912270 }
21922271
21932272 Ok ( Arc :: new ( ProjectionExec :: try_new (
0 commit comments