Skip to content

Commit 09a6038

Browse files
committed
fix: preserve optimized DuckDB unparse alias scope
1 parent ddc157d commit 09a6038

6 files changed

Lines changed: 334 additions & 28 deletions

File tree

datafusion/core/tests/sql/unparser.rs

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,19 @@
3737
3838
use std::fs::ReadDir;
3939
use std::future::Future;
40+
use std::sync::Arc;
4041

4142
use arrow::array::RecordBatch;
43+
use arrow::datatypes::{DataType, Field, Schema};
4244
use datafusion::common::Result;
45+
use datafusion::datasource::empty::EmptyTable;
4346
use datafusion::prelude::{ParquetReadOptions, SessionContext};
47+
use datafusion_catalog::memory::MemorySchemaProvider;
48+
use datafusion_catalog::{CatalogProvider, MemoryCatalogProvider, SchemaProvider};
4449
use datafusion_common::Column;
4550
use datafusion_expr::Expr;
4651
use datafusion_sql::unparser::Unparser;
47-
use datafusion_sql::unparser::dialect::DefaultDialect;
52+
use datafusion_sql::unparser::dialect::{DefaultDialect, DuckDBDialect};
4853
use itertools::Itertools;
4954
use recursive::{set_minimum_stack_size, set_stack_allocation_size};
5055

@@ -218,6 +223,126 @@ async fn sort_batches(
218223
df.collect().await
219224
}
220225

226+
/// SQL text used by the `issue_22961` regression test in this file.
///
/// The query selects everything from three derived tables aliased `oi`, `o`
/// and `p`, joined pairwise with `USING`; the `p` derived table itself
/// contains a `LEFT JOIN ... USING (category_id)`. All base tables are
/// referenced through the fully qualified `"warehouse"."main"` path.
const ISSUE_22961_QUERY: &str = r#"
227+
SELECT * FROM
228+
(
229+
SELECT
230+
item_id,
231+
order_id,
232+
product_id,
233+
quantity,
234+
unit_price,
235+
quantity * unit_price AS line_total
236+
FROM
237+
"warehouse"."main"."order_items"
238+
) oi
239+
JOIN (
240+
SELECT
241+
order_id,
242+
customer_id,
243+
order_date,
244+
lower(STATUS) AS STATUS,
245+
lower(channel) AS channel,
246+
coalesce(discount_pct, 0) AS discount_pct,
247+
coalesce(shipping_cost, 0) AS shipping_cost,
248+
STATUS IN ('completed', 'shipped') AS is_fulfilled
249+
FROM
250+
"warehouse"."main"."orders"
251+
) o USING (order_id)
252+
JOIN (
253+
SELECT
254+
p.product_id,
255+
p.category_id,
256+
p.sku,
257+
p.name AS product_name,
258+
p.price,
259+
p.cost,
260+
p.weight_kg,
261+
p.is_active,
262+
p.stock_qty,
263+
round(p.price - p.cost, 2) AS gross_margin,
264+
round((p.price - p.cost) / nullif(p.price, 0), 4) AS margin_pct,
265+
c.name AS category_name
266+
FROM
267+
"warehouse"."main"."products" p
268+
LEFT JOIN "warehouse"."main"."categories" c USING (category_id)
269+
) p USING (product_id)
270+
"#;
271+
272+
fn issue_22961_context() -> Result<SessionContext> {
273+
let ctx = SessionContext::new();
274+
275+
let schema_provider = Arc::new(MemorySchemaProvider::new());
276+
schema_provider.register_table(
277+
"order_items".to_string(),
278+
Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![
279+
Field::new("item_id", DataType::Int32, false),
280+
Field::new("order_id", DataType::Int32, true),
281+
Field::new("product_id", DataType::Int32, true),
282+
Field::new("quantity", DataType::Int32, true),
283+
Field::new("unit_price", DataType::Decimal128(10, 2), true),
284+
])))),
285+
)?;
286+
schema_provider.register_table(
287+
"orders".to_string(),
288+
Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![
289+
Field::new("order_id", DataType::Int32, false),
290+
Field::new("customer_id", DataType::Int32, true),
291+
Field::new("order_date", DataType::Date32, true),
292+
Field::new("status", DataType::Utf8, true),
293+
Field::new("channel", DataType::Utf8, true),
294+
Field::new("discount_pct", DataType::Decimal128(5, 2), true),
295+
Field::new("shipping_cost", DataType::Decimal128(8, 2), true),
296+
])))),
297+
)?;
298+
schema_provider.register_table(
299+
"products".to_string(),
300+
Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![
301+
Field::new("product_id", DataType::Int32, false),
302+
Field::new("category_id", DataType::Int32, true),
303+
Field::new("sku", DataType::Utf8, true),
304+
Field::new("name", DataType::Utf8, true),
305+
Field::new("price", DataType::Decimal128(10, 2), true),
306+
Field::new("cost", DataType::Decimal128(10, 2), true),
307+
Field::new("weight_kg", DataType::Decimal128(6, 3), true),
308+
Field::new("is_active", DataType::Boolean, true),
309+
Field::new("stock_qty", DataType::Int32, true),
310+
])))),
311+
)?;
312+
schema_provider.register_table(
313+
"categories".to_string(),
314+
Arc::new(EmptyTable::new(Arc::new(Schema::new(vec![
315+
Field::new("category_id", DataType::Int32, false),
316+
Field::new("name", DataType::Utf8, true),
317+
Field::new("parent_id", DataType::Int32, true),
318+
Field::new("display_rank", DataType::Int32, true),
319+
])))),
320+
)?;
321+
322+
let catalog = Arc::new(MemoryCatalogProvider::new());
323+
catalog.register_schema("main", schema_provider)?;
324+
ctx.register_catalog("warehouse", catalog);
325+
326+
Ok(ctx)
327+
}
328+
329+
#[tokio::test]
330+
async fn optimized_duckdb_unparse_preserves_derived_table_scope() -> Result<()> {
331+
let ctx = issue_22961_context()?;
332+
let plan = ctx.sql(ISSUE_22961_QUERY).await?.into_optimized_plan()?;
333+
let dialect = DuckDBDialect::new();
334+
let unparser = Unparser::new(&dialect);
335+
let sql = unparser.plan_to_sql(&plan)?.to_string();
336+
337+
assert!(!sql.contains(r#""o"."__common_expr_1""#));
338+
assert!(!sql.contains(r#""o"."__common_expr_2""#));
339+
assert!(sql.contains(
340+
r#"ON "oi"."order_id" = "o"."order_id" INNER JOIN (SELECT "p"."product_id""#
341+
));
342+
343+
Ok(())
344+
}
345+
221346
/// The outcome of running a single roundtrip test.
222347
///
223348
/// A successful test produces [`TestCaseResult::Success`].

datafusion/sql/src/unparser/ast.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,9 @@ impl SelectBuilder {
264264
/// Pops the last entry off `self.from` and returns whatever `pop()` yields
/// (`None` presumably meaning there was no entry to remove).
pub fn pop_from(&mut self) -> Option<TableWithJoinsBuilder> {
265265
self.from.pop()
266266
}
267+
/// Returns `true` when this builder's `selection` field is `Some`, i.e. a
/// selection has been set on it.
pub fn has_selection(&self) -> bool {
268+
self.selection.is_some()
269+
}
267270
pub fn lateral_views(&mut self, value: Vec<ast::LateralView>) -> &mut Self {
268271
self.lateral_views = value;
269272
self
@@ -483,6 +486,7 @@ pub struct RelationBuilder {
483486
enum TableFactorBuilder {
484487
Table(TableRelationBuilder),
485488
Derived(DerivedRelationBuilder),
489+
NestedJoin(ast::TableWithJoins, Option<ast::TableAlias>),
486490
Unnest(UnnestRelationBuilder),
487491
Flatten(FlattenRelationBuilder),
488492
Empty,
@@ -501,6 +505,15 @@ impl RelationBuilder {
501505
self
502506
}
503507

508+
pub fn nested_join(
509+
&mut self,
510+
value: ast::TableWithJoins,
511+
alias: Option<ast::TableAlias>,
512+
) -> &mut Self {
513+
self.relation = Some(TableFactorBuilder::NestedJoin(value, alias));
514+
self
515+
}
516+
504517
pub fn unnest(&mut self, value: UnnestRelationBuilder) -> &mut Self {
505518
self.relation = Some(TableFactorBuilder::Unnest(value));
506519
self
@@ -524,6 +537,9 @@ impl RelationBuilder {
524537
Some(TableFactorBuilder::Derived(ref mut rel_builder)) => {
525538
rel_builder.alias = value;
526539
}
540+
Some(TableFactorBuilder::NestedJoin(_, ref mut alias)) => {
541+
*alias = value;
542+
}
527543
Some(TableFactorBuilder::Unnest(ref mut rel_builder)) => {
528544
rel_builder.alias = value;
529545
}
@@ -539,6 +555,12 @@ impl RelationBuilder {
539555
Ok(match self.relation {
540556
Some(TableFactorBuilder::Table(ref value)) => Some(value.build()?),
541557
Some(TableFactorBuilder::Derived(ref value)) => Some(value.build()?),
558+
Some(TableFactorBuilder::NestedJoin(ref table_with_joins, ref alias)) => {
559+
Some(ast::TableFactor::NestedJoin {
560+
table_with_joins: Box::new(table_with_joins.clone()),
561+
alias: alias.clone(),
562+
})
563+
}
542564
Some(TableFactorBuilder::Unnest(ref value)) => Some(value.build()?),
543565
Some(TableFactorBuilder::Flatten(ref value)) => Some(value.build()?),
544566
Some(TableFactorBuilder::Empty) => None,

datafusion/sql/src/unparser/plan.rs

Lines changed: 95 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ use crate::unparser::{
4242
};
4343
use crate::utils::UNNEST_PLACEHOLDER;
4444
use datafusion_common::{
45-
Column, DataFusionError, Result, ScalarValue, TableReference, assert_or_internal_err,
46-
internal_datafusion_err, internal_err, not_impl_err,
45+
Column, DFSchema, DataFusionError, Result, ScalarValue, TableReference,
46+
assert_or_internal_err, internal_datafusion_err, internal_err, not_impl_err,
4747
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion},
4848
utils::combine_limit,
4949
};
@@ -571,8 +571,9 @@ impl Unparser<'_> {
571571

572572
let input_schema = window.input.schema();
573573
let mut alias_rewriter = TableAliasRewriter {
574-
table_schema: input_schema.as_arrow(),
574+
table_schema: input_schema.as_ref(),
575575
alias_name: TableReference::bare(input_alias),
576+
rewrite_unqualified: true,
576577
};
577578
let window_expr = window
578579
.window_expr
@@ -1160,6 +1161,11 @@ impl Unparser<'_> {
11601161
}
11611162
None => Arc::clone(left_plan),
11621163
};
1164+
let left_plan = if already_projected {
1165+
Self::unwrap_qualified_passthrough_join_projection(left_plan)
1166+
} else {
1167+
left_plan
1168+
};
11631169

11641170
self.select_to_sql_recursively(
11651171
left_plan.as_ref(),
@@ -1185,13 +1191,22 @@ impl Unparser<'_> {
11851191
};
11861192

11871193
let mut right_relation = RelationBuilder::default();
1188-
1189-
self.select_to_sql_recursively(
1190-
right_plan.as_ref(),
1191-
query,
1192-
select,
1193-
&mut right_relation,
1194-
)?;
1194+
if already_projected
1195+
&& let Some(nested_relation) = self
1196+
.qualified_passthrough_join_projection_to_nested_relation(
1197+
right_plan.as_ref(),
1198+
query,
1199+
)?
1200+
{
1201+
right_relation = nested_relation;
1202+
} else {
1203+
self.select_to_sql_recursively(
1204+
right_plan.as_ref(),
1205+
query,
1206+
select,
1207+
&mut right_relation,
1208+
)?;
1209+
}
11951210

11961211
let (join_filters, where_filters) = Self::split_join_on_and_where_filters(
11971212
join.join_type,
@@ -1910,6 +1925,68 @@ impl Unparser<'_> {
19101925
)
19111926
}
19121927

1928+
fn is_qualified_passthrough_projection(projection: &Projection) -> bool {
1929+
projection
1930+
.expr
1931+
.iter()
1932+
.all(|expr| matches!(expr, Expr::Column(column) if column.relation.is_some()))
1933+
}
1934+
1935+
fn unwrap_qualified_passthrough_join_projection(
1936+
plan: Arc<LogicalPlan>,
1937+
) -> Arc<LogicalPlan> {
1938+
if let LogicalPlan::Projection(projection) = plan.as_ref()
1939+
&& matches!(projection.input.as_ref(), LogicalPlan::Join(_))
1940+
&& Self::is_qualified_passthrough_projection(projection)
1941+
{
1942+
Arc::clone(&projection.input)
1943+
} else {
1944+
plan
1945+
}
1946+
}
1947+
1948+
fn qualified_passthrough_join_projection_to_nested_relation(
1949+
&self,
1950+
plan: &LogicalPlan,
1951+
query: &mut Option<QueryBuilder>,
1952+
) -> Result<Option<RelationBuilder>> {
1953+
let LogicalPlan::Projection(projection) = plan else {
1954+
return Ok(None);
1955+
};
1956+
if !matches!(projection.input.as_ref(), LogicalPlan::Join(_))
1957+
|| !Self::is_qualified_passthrough_projection(projection)
1958+
{
1959+
return Ok(None);
1960+
}
1961+
1962+
let original_query = query.clone();
1963+
let mut nested_select = SelectBuilder::default();
1964+
nested_select.push_from(TableWithJoinsBuilder::default());
1965+
let mut nested_relation = RelationBuilder::default();
1966+
self.select_to_sql_recursively(
1967+
projection.input.as_ref(),
1968+
query,
1969+
&mut nested_select,
1970+
&mut nested_relation,
1971+
)?;
1972+
if nested_select.has_selection() {
1973+
*query = original_query;
1974+
return Ok(None);
1975+
}
1976+
1977+
let Some(mut nested_from) = nested_select.pop_from() else {
1978+
return internal_err!("Failed to build nested join relation");
1979+
};
1980+
nested_from.relation(nested_relation);
1981+
let Some(table_with_joins) = nested_from.build()? else {
1982+
return internal_err!("Failed to build nested join relation");
1983+
};
1984+
1985+
let mut relation = RelationBuilder::default();
1986+
relation.nested_join(table_with_joins, None);
1987+
Ok(Some(relation))
1988+
}
1989+
19131990
/// Try to unparse a table scan with pushdown operations into a new subquery plan.
19141991
/// If the table scan is without any pushdown operations, return None.
19151992
fn unparse_table_scan_pushdown(
@@ -1924,10 +2001,15 @@ impl Unparser<'_> {
19242001
return Ok(None);
19252002
}
19262003
let table_schema = table_scan.source.schema();
2004+
let filter_schema = DFSchema::try_from_qualified_schema(
2005+
table_scan.table_name.clone(),
2006+
table_schema.as_ref(),
2007+
)?;
19272008
let mut filter_alias_rewriter =
19282009
alias.as_ref().map(|alias_name| TableAliasRewriter {
1929-
table_schema: &table_schema,
2010+
table_schema: &filter_schema,
19302011
alias_name: alias_name.clone(),
2012+
rewrite_unqualified: true,
19312013
});
19322014

19332015
let mut builder = LogicalPlanBuilder::scan(
@@ -2037,8 +2119,9 @@ impl Unparser<'_> {
20372119
let exprs = if alias.is_some() {
20382120
let mut alias_rewriter =
20392121
alias.as_ref().map(|alias_name| TableAliasRewriter {
2040-
table_schema: plan.schema().as_arrow(),
2122+
table_schema: plan.schema().as_ref(),
20412123
alias_name: alias_name.clone(),
2124+
rewrite_unqualified: false,
20422125
});
20432126
projection
20442127
.expr

0 commit comments

Comments
 (0)