Skip to content

Commit 7bfb602

Browse files
JanKaulclaude
andcommitted
Embucket changes
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 7acbe03 commit 7bfb602

43 files changed

Lines changed: 3350 additions & 145 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

Cargo.lock

Lines changed: 3 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,9 @@ regex = "1.12"
188188
rstest = "0.26.1"
189189
serde_json = "1"
190190
sha2 = "^0.10.9"
191-
sqlparser = { version = "0.61.0", default-features = false, features = ["std", "visitor"] }
191+
sqlparser = { git = "https://github.com/Embucket/datafusion-sqlparser-rs.git", branch = "embucket-sync-df53.0.0-parser0.61.0", features = [
192+
"visitor",
193+
] }
192194
strum = "0.28.0"
193195
strum_macros = "0.28.0"
194196
tempfile = "3"

FETCH_HEAD

Whitespace-only changes.

datafusion-cli/src/functions.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ pub struct ParquetMetadataFunc {}
324324
impl TableFunctionImpl for ParquetMetadataFunc {
325325
fn call_with_args(&self, args: TableFunctionArgs) -> Result<Arc<dyn TableProvider>> {
326326
let exprs = args.exprs();
327+
if exprs.is_empty() {
328+
return plan_err!("parquet_metadata requires string argument as its input");
329+
}
330+
327331
let filename = match exprs.first() {
328332
Some(Expr::Literal(ScalarValue::Utf8(Some(s)), _)) => s, // single quote: parquet_metadata('x.parquet')
329333
Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet")

datafusion/core/src/execution/session_state.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@ impl Session for SessionState {
319319
}
320320

321321
impl SessionState {
322-
pub(crate) fn resolve_table_ref(
322+
/// Resolve a [`TableReference`] into a [`ResolvedTableReference`] using
323+
/// the session's configured default catalog and schema.
324+
pub fn resolve_table_ref(
323325
&self,
324326
table_ref: impl Into<TableReference>,
325327
) -> ResolvedTableReference {
@@ -541,8 +543,9 @@ impl SessionState {
541543
query.statement_to_plan(statement)
542544
}
543545

546+
/// Get the parser options
544547
#[cfg(feature = "sql")]
545-
fn get_parser_options(&self) -> ParserOptions {
548+
pub fn get_parser_options(&self) -> ParserOptions {
546549
let sql_parser_options = &self.config.options().sql_parser;
547550

548551
ParserOptions {
@@ -1864,9 +1867,11 @@ impl From<SessionState> for SessionStateBuilder {
18641867
/// This is used so the SQL planner can access the state of the session without
18651868
/// having a direct dependency on the [`SessionState`] struct (and core crate)
18661869
#[cfg(feature = "sql")]
1867-
struct SessionContextProvider<'a> {
1868-
state: &'a SessionState,
1869-
tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
1870+
pub struct SessionContextProvider<'a> {
1871+
/// The session state
1872+
pub state: &'a SessionState,
1873+
/// The tables available in the session
1874+
pub tables: HashMap<ResolvedTableReference, Arc<dyn TableSource>>,
18701875
}
18711876

18721877
#[cfg(feature = "sql")]
@@ -1901,14 +1906,15 @@ impl ContextProvider for SessionContextProvider<'_> {
19011906
fn get_table_function_source(
19021907
&self,
19031908
name: &str,
1904-
args: Vec<Expr>,
1909+
args: Vec<(Expr, Option<String>)>,
19051910
) -> datafusion_common::Result<Arc<dyn TableSource>> {
19061911
use datafusion_catalog::TableFunctionArgs;
19071912

1913+
let name = name.to_ascii_lowercase();
19081914
let tbl_func = self
19091915
.state
19101916
.table_functions
1911-
.get(name)
1917+
.get(&name)
19121918
.cloned()
19131919
.ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?;
19141920
let simplify_context = SimplifyContext::builder()
@@ -1919,11 +1925,11 @@ impl ContextProvider for SessionContextProvider<'_> {
19191925
.build();
19201926
let simplifier = ExprSimplifier::new(simplify_context);
19211927
let schema = DFSchema::empty();
1922-
let args = args
1928+
let args: Vec<Expr> = args
19231929
.into_iter()
1924-
.map(|arg| {
1930+
.map(|(expr, _named_param)| {
19251931
simplifier
1926-
.coerce(arg, &schema)
1932+
.coerce(expr, &schema)
19271933
.and_then(|e| simplifier.simplify(e))
19281934
})
19291935
.collect::<datafusion_common::Result<Vec<_>>>()?;

datafusion/core/src/physical_planner.rs

Lines changed: 230 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ use datafusion_expr::expr_rewriter::unnormalize_cols;
8686
use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary;
8787
use datafusion_expr::utils::{expr_to_columns, split_conjunction};
8888
use datafusion_expr::{
89-
Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension,
90-
FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan,
91-
WindowFrame, WindowFrameBound, WriteOp,
89+
Analyze, BinaryExpr, Cast, DescribeTable, DmlStatement, Explain, ExplainFormat,
90+
Extension, FetchType, Filter, JoinType, LogicalPlanBuilder, Operator, RecursiveQuery,
91+
SkipType, StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp,
9292
};
9393
use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr};
9494
use datafusion_physical_expr::expressions::Literal;
@@ -110,6 +110,8 @@ use itertools::{Itertools, multiunzip};
110110
use log::debug;
111111
use tokio::sync::Mutex;
112112

113+
use datafusion_physical_plan::collect;
114+
113115
/// Physical query planner that converts a `LogicalPlan` to an
114116
/// `ExecutionPlan` suitable for execution.
115117
#[async_trait]
@@ -1253,7 +1255,62 @@ impl DefaultPhysicalPlanner {
12531255
options.clone(),
12541256
)?)
12551257
}
1258+
LogicalPlan::Pivot(pivot) => {
1259+
return if !pivot.pivot_values.is_empty() {
1260+
let input = Arc::new(pivot.input.as_ref().clone());
1261+
let agg_plan = transform_pivot_to_aggregate(
1262+
&input,
1263+
&pivot.aggregate_expr,
1264+
&pivot.pivot_column,
1265+
&pivot.pivot_values,
1266+
pivot.default_on_null_expr.as_ref(),
1267+
)?;
1268+
1269+
self.create_physical_plan(&agg_plan, session_state).await
1270+
} else if let Some(subquery) = &pivot.value_subquery {
1271+
let optimized_subquery = session_state.optimize(subquery.as_ref())?;
1272+
1273+
let subquery_physical_plan = self
1274+
.create_physical_plan(&optimized_subquery, session_state)
1275+
.await?;
12561276

1277+
let subquery_results = collect(
1278+
Arc::clone(&subquery_physical_plan),
1279+
session_state.task_ctx(),
1280+
)
1281+
.await?;
1282+
1283+
let mut pivot_values = Vec::new();
1284+
for batch in subquery_results.iter() {
1285+
if batch.num_columns() != 1 {
1286+
return plan_err!(
1287+
"Pivot subquery must return a single column"
1288+
);
1289+
}
1290+
1291+
let column = batch.column(0);
1292+
for row_idx in 0..batch.num_rows() {
1293+
if !column.is_null(row_idx) {
1294+
pivot_values
1295+
.push(ScalarValue::try_from_array(column, row_idx)?);
1296+
}
1297+
}
1298+
}
1299+
1300+
let input = Arc::new(pivot.input.as_ref().clone());
1301+
let agg_plan = transform_pivot_to_aggregate(
1302+
&input,
1303+
&pivot.aggregate_expr,
1304+
&pivot.pivot_column,
1305+
&pivot_values,
1306+
pivot.default_on_null_expr.as_ref(),
1307+
)?;
1308+
1309+
self.create_physical_plan(&agg_plan, session_state).await
1310+
} else {
1311+
plan_err!("PIVOT operation requires at least one value to pivot on")
1312+
};
1313+
}
12571314
// 2 Children
12581315
LogicalPlan::Join(Join {
12591316
left: original_left,
@@ -2142,7 +2199,8 @@ fn extract_dml_filters(
21422199
| LogicalPlan::Ddl(_)
21432200
| LogicalPlan::Copy(_)
21442201
| LogicalPlan::Unnest(_)
2145-
| LogicalPlan::RecursiveQuery(_) => {
2202+
| LogicalPlan::RecursiveQuery(_)
2203+
| LogicalPlan::Pivot(_) => {
21462204
// No filters to extract from leaf/meta plans
21472205
}
21482206
// Plans with inputs (may contain filters in children)
@@ -2498,6 +2556,148 @@ pub fn create_aggregate_expr_and_maybe_filter(
24982556
)
24992557
}
25002558

2559+
/// Transform a PIVOT operation into a more standard Aggregate + Projection plan
2560+
/// For known pivot values, we create a projection that includes "IS NOT DISTINCT FROM" conditions
2561+
///
2562+
/// For example, for SUM(amount) PIVOT(quarter FOR quarter in ('2023_Q1', '2023_Q2')), we create:
2563+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q1') AS "2023_Q1"
2564+
/// - SUM(amount) FILTER (WHERE quarter IS NOT DISTINCT FROM '2023_Q2') AS "2023_Q2"
2565+
///
2566+
/// If DEFAULT ON NULL is specified, each aggregate expression is wrapped with an outer projection that
2567+
/// applies COALESCE to the results.
2568+
pub fn transform_pivot_to_aggregate(
2569+
input: &Arc<LogicalPlan>,
2570+
aggregate_expr: &Expr,
2571+
pivot_column: &Column,
2572+
pivot_values: &[ScalarValue],
2573+
default_on_null_expr: Option<&Expr>,
2574+
) -> Result<LogicalPlan> {
2575+
let df_schema = input.schema();
2576+
2577+
let all_columns: Vec<Column> = df_schema.columns();
2578+
2579+
// Filter to include only columns we want for GROUP BY
2580+
// (exclude pivot column and aggregate expression columns)
2581+
let group_by_columns: Vec<Expr> = all_columns
2582+
.into_iter()
2583+
.filter(|col: &Column| {
2584+
col.name != pivot_column.name
2585+
&& !aggregate_expr
2586+
.column_refs()
2587+
.iter()
2588+
.any(|agg_col| agg_col.name == col.name)
2589+
})
2590+
.map(|col: Column| Expr::Column(col))
2591+
.collect();
2592+
2593+
let builder = LogicalPlanBuilder::from(Arc::unwrap_or_clone(Arc::clone(input)));
2594+
2595+
// Create the aggregate plan with filtered aggregates
2596+
let mut aggregate_exprs = Vec::new();
2597+
2598+
let input_schema = input.schema();
2599+
let pivot_col_idx = match input_schema.index_of_column(pivot_column) {
2600+
Ok(idx) => idx,
2601+
Err(_) => {
2602+
return plan_err!(
2603+
"Pivot column '{}' does not exist in input schema",
2604+
pivot_column
2605+
);
2606+
}
2607+
};
2608+
let pivot_col_type = input_schema.field(pivot_col_idx).data_type();
2609+
2610+
for value in pivot_values {
2611+
let filter_condition = Expr::BinaryExpr(BinaryExpr::new(
2612+
Box::new(Expr::Column(pivot_column.clone())),
2613+
Operator::IsNotDistinctFrom,
2614+
Box::new(Expr::Cast(Cast::new(
2615+
Box::new(Expr::Literal(value.clone(), None)),
2616+
pivot_col_type.clone(),
2617+
))),
2618+
));
2619+
2620+
let filtered_agg = match aggregate_expr {
2621+
Expr::AggregateFunction(agg) => {
2622+
let mut new_params = agg.params.clone();
2623+
new_params.filter = Some(Box::new(filter_condition));
2624+
Expr::AggregateFunction(AggregateFunction {
2625+
func: Arc::clone(&agg.func),
2626+
params: new_params,
2627+
})
2628+
}
2629+
_ => {
2630+
return plan_err!(
2631+
"Unsupported aggregate expression should always be AggregateFunction"
2632+
);
2633+
}
2634+
};
2635+
2636+
// Use the pivot value as the column name
2637+
let field_name = value.to_string().trim_matches('\'').to_string();
2638+
let aliased_agg = Expr::Alias(Alias {
2639+
expr: Box::new(filtered_agg),
2640+
relation: None,
2641+
name: field_name,
2642+
metadata: None,
2643+
});
2644+
2645+
aggregate_exprs.push(aliased_agg);
2646+
}
2647+
2648+
// Create the plan with the aggregate
2649+
let aggregate_plan = builder
2650+
.aggregate(group_by_columns, aggregate_exprs)?
2651+
.build()?;
2652+
2653+
// If DEFAULT ON NULL is specified, add a projection to apply COALESCE
2654+
if let Some(default_expr) = default_on_null_expr {
2655+
let schema = aggregate_plan.schema();
2656+
let mut projection_exprs = Vec::new();
2657+
2658+
for field in schema.fields() {
2659+
if !pivot_values
2660+
.iter()
2661+
.any(|v| field.name() == v.to_string().trim_matches('\''))
2662+
{
2663+
projection_exprs.push(Expr::Column(Column::from_name(field.name())));
2664+
}
2665+
}
2666+
2667+
// Apply COALESCE to aggregate columns
2668+
for value in pivot_values {
2669+
let field_name = value.to_string().trim_matches('\'').to_string();
2670+
let aggregate_col = Expr::Column(Column::from_name(&field_name));
2671+
2672+
// Create COALESCE expression using CASE: CASE WHEN col IS NULL THEN default_value ELSE col END
2673+
let coalesce_expr = Expr::Case(datafusion_expr::expr::Case {
2674+
expr: None,
2675+
when_then_expr: vec![(
2676+
Box::new(Expr::IsNull(Box::new(aggregate_col.clone()))),
2677+
Box::new(default_expr.clone()),
2678+
)],
2679+
else_expr: Some(Box::new(aggregate_col)),
2680+
});
2681+
2682+
let aliased_coalesce = Expr::Alias(Alias {
2683+
expr: Box::new(coalesce_expr),
2684+
relation: None,
2685+
name: field_name,
2686+
metadata: None,
2687+
});
2688+
2689+
projection_exprs.push(aliased_coalesce);
2690+
}
2691+
2692+
// Apply the projection
2693+
LogicalPlanBuilder::from(aggregate_plan)
2694+
.project(projection_exprs)?
2695+
.build()
2696+
} else {
2697+
Ok(aggregate_plan)
2698+
}
2699+
}
2700+
25012701
impl DefaultPhysicalPlanner {
25022702
/// Handles capturing the various plans for EXPLAIN queries
25032703
///
@@ -2887,6 +3087,32 @@ impl DefaultPhysicalPlanner {
28873087
.collect::<Result<Vec<_>>>()?;
28883088

28893089
let num_input_columns = input_exec.schema().fields().len();
3090+
// When we detect a PIVOT-derived plan with a value_subquery, ensure all generated columns are preserved
3091+
if let LogicalPlan::Pivot(pivot) = input.as_ref()
3092+
&& pivot.value_subquery.is_some()
3093+
&& let Some(agg_exec) = input_exec.downcast_ref::<AggregateExec>()
3094+
{
3095+
let schema = input_exec.schema();
3096+
let group_by_len = agg_exec.group_expr().expr().len();
3097+
3098+
if group_by_len < schema.fields().len() {
3099+
let mut all_exprs = physical_exprs.clone();
3100+
3101+
for (i, field) in schema.fields().iter().enumerate().skip(group_by_len) {
3102+
if !physical_exprs.iter().any(|(_, name)| name == field.name()) {
3103+
all_exprs.push((
3104+
Arc::new(datafusion_physical_expr::expressions::Column::new(
3105+
field.name(),
3106+
i,
3107+
)) as Arc<dyn PhysicalExpr>,
3108+
field.name().to_string(),
3109+
));
3110+
}
3111+
}
3112+
3113+
return Ok(Arc::new(ProjectionExec::try_new(all_exprs, input_exec)?));
3114+
}
3115+
}
28903116

28913117
match self.try_plan_async_exprs(
28923118
num_input_columns,

0 commit comments

Comments
 (0)