Skip to content

Commit c50a2a0

Browse files
committed
Add field accessors to projection mappings for struct-producing expressions
1 parent 4f20b1e commit c50a2a0

2 files changed

Lines changed: 182 additions & 6 deletions

File tree

datafusion/physical-expr/src/equivalence/properties/dependency.rs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,4 +1528,105 @@ mod tests {
15281528

15291529
Ok(())
15301530
}
1531+
1532+
/// Verifies that an input ordering survives a projection that packs
/// columns into a struct.
///
/// The projection under test is `named_struct('a', a, 'b', b) AS s`, applied
/// to an input that is ordered by `a ASC NULLS LAST`. Once the input
/// properties have been projected, they are asked whether
/// `get_field(s, 'a') ASC NULLS LAST` is satisfied. If it is, a sort on the
/// struct field `s.a` can be recognised as redundant with the ordering of
/// the original column.
#[test]
fn test_ordering_propagation_through_named_struct() -> Result<()> {
    use crate::expressions::Literal;
    use datafusion_common::ScalarValue;
    use datafusion_functions::core::{get_field, named_struct};

    // Small builders shared by the struct constructor and the field accessor.
    let utf8_lit = |text: &str| -> Arc<dyn PhysicalExpr> {
        Arc::new(Literal::new(ScalarValue::Utf8(Some(text.to_string()))))
    };
    let nullable_int = |name: &str| Field::new(name, DataType::Int32, true);
    let asc_nulls_last = SortOptions {
        descending: false,
        nulls_first: false,
    };

    // Input relation: two nullable Int32 columns, `a` and `b`.
    let input_schema =
        Arc::new(Schema::new(vec![nullable_int("a"), nullable_int("b")]));
    let col_a = col("a", &input_schema)?;
    let col_b = col("b", &input_schema)?;
    let config = Arc::new(ConfigOptions::new());

    // named_struct('a', col_a, 'b', col_b), returning Struct<a: Int32, b: Int32>.
    let struct_return_field = Arc::new(Field::new(
        "named_struct",
        DataType::Struct(vec![nullable_int("a"), nullable_int("b")].into()),
        true,
    ));
    let struct_args = vec![
        utf8_lit("a"),
        Arc::clone(&col_a),
        utf8_lit("b"),
        Arc::clone(&col_b),
    ];
    let named_struct_expr = Arc::new(ScalarFunctionExpr::new(
        "named_struct",
        named_struct(),
        struct_args,
        struct_return_field,
        Arc::clone(&config),
    )) as Arc<dyn PhysicalExpr>;

    // The projection exposes the struct as the single output column `s`.
    let projection_mapping = ProjectionMapping::try_new(
        vec![(named_struct_expr, "s".to_string())],
        &input_schema,
    )?;

    // Declare the input as ordered by [a ASC NULLS LAST].
    let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema));
    input_properties
        .add_orderings([vec![PhysicalSortExpr::new(Arc::clone(&col_a), asc_nulls_last)]]);

    // Push the input properties through the projection.
    let out_schema = output_schema(&projection_mapping, &input_schema)?;
    let out_properties = input_properties.project(&projection_mapping, out_schema);

    // get_field(s, 'a'): the expression an `ORDER BY s.a` would sort on.
    let col_s = Arc::new(Column::new("s", 0)) as Arc<dyn PhysicalExpr>;
    let get_field_expr = Arc::new(ScalarFunctionExpr::new(
        "get_field",
        get_field(),
        vec![col_s, utf8_lit("a")],
        Arc::new(nullable_int("a")),
        Arc::clone(&config),
    )) as Arc<dyn PhysicalExpr>;
    let sort_get_field_a = PhysicalSortExpr::new(get_field_expr, asc_nulls_last);

    // The projected properties must recognise the struct-field ordering.
    assert!(
        out_properties.ordering_satisfy(vec![sort_get_field_a])?,
        "Output should be ordered by get_field(s, 'a') since input is ordered by col_a"
    );

    Ok(())
}
15311632
}

datafusion/physical-expr/src/projection.rs

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
use std::ops::Deref;
1919
use std::sync::Arc;
2020

21-
use crate::expressions::Column;
21+
use crate::expressions::{Column, Literal};
22+
use crate::scalar_function::ScalarFunctionExpr;
2223
use crate::utils::collect_columns;
2324
use crate::PhysicalExpr;
2425

25-
use arrow::datatypes::{Field, Schema, SchemaRef};
26+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2627
use datafusion_common::stats::{ColumnStatistics, Precision};
2728
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
28-
use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
29+
use datafusion_common::{ScalarValue, internal_datafusion_err, internal_err, plan_err, Result};
2930

3031
use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
3132
use indexmap::IndexMap;
@@ -638,9 +639,81 @@ impl ProjectionMapping {
638639
None => Ok(Transformed::no(e)),
639640
})
640641
.data()?;
641-
map.entry(source_expr)
642+
map.entry(source_expr.clone())
642643
.or_default()
643-
.push((target_expr, expr_idx));
644+
.push((Arc::clone(&target_expr), expr_idx));
645+
646+
// For struct-producing functions (e.g. named_struct), decompose
647+
// into field-level mapping entries so that orderings propagate
648+
// through struct projections. For example, if the projection has
649+
// `named_struct('ticker', p.ticker, ...) AS details`, this adds:
650+
// p.ticker → get_field(col("details"), "ticker")
651+
// enabling the optimizer to know that sorting by
652+
// `details.ticker` is equivalent to sorting by `p.ticker`.
653+
if let Some(func_expr) =
654+
source_expr.as_any().downcast_ref::<ScalarFunctionExpr>()
655+
{
656+
let literal_args: Vec<Option<ScalarValue>> = func_expr
657+
.args()
658+
.iter()
659+
.map(|arg| {
660+
arg.as_any()
661+
.downcast_ref::<Literal>()
662+
.map(|l| l.value().clone())
663+
})
664+
.collect();
665+
666+
if let Some(field_mapping) =
667+
func_expr.fun().struct_field_mapping(&literal_args)
668+
{
669+
if let DataType::Struct(struct_fields) =
670+
func_expr.return_type()
671+
{
672+
for (accessor_args, source_arg_idx) in &field_mapping.fields
673+
{
674+
let value_expr =
675+
func_expr.args()[*source_arg_idx].clone();
676+
677+
// Build accessor args: [target_col, ...field_name_literals]
678+
let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> =
679+
vec![Arc::clone(&target_expr)];
680+
accessor_fn_args.extend(accessor_args.iter().map(
681+
|sv| {
682+
Arc::new(Literal::new(sv.clone()))
683+
as Arc<dyn PhysicalExpr>
684+
},
685+
));
686+
687+
// Look up the field's return type from the struct schema
688+
let return_field = accessor_args
689+
.first()
690+
.and_then(|sv| sv.try_as_str().flatten())
691+
.and_then(|field_name| {
692+
struct_fields
693+
.iter()
694+
.find(|f| f.name() == field_name)
695+
.cloned()
696+
});
697+
698+
if let Some(return_field) = return_field {
699+
let field_access_expr =
700+
Arc::new(ScalarFunctionExpr::new(
701+
field_mapping.field_accessor.name(),
702+
Arc::clone(&field_mapping.field_accessor),
703+
accessor_fn_args,
704+
return_field,
705+
Arc::new(func_expr.config_options().clone()),
706+
))
707+
as Arc<dyn PhysicalExpr>;
708+
709+
map.entry(value_expr)
710+
.or_default()
711+
.push((field_access_expr, expr_idx));
712+
}
713+
}
714+
}
715+
}
716+
}
644717
}
645718
Ok(Self { map })
646719
}
@@ -795,8 +868,10 @@ pub(crate) mod tests {
795868
let data_type = source.data_type(input_schema)?;
796869
let nullable = source.nullable(input_schema)?;
797870
for (target, _) in targets.iter() {
871+
// Skip non-Column targets (e.g. struct field decomposition
872+
// entries which are ScalarFunctionExpr targets).
798873
let Some(column) = target.as_any().downcast_ref::<Column>() else {
799-
return plan_err!("Expects to have column");
874+
continue;
800875
};
801876
fields.push(Field::new(column.name(), data_type.clone(), nullable));
802877
}

0 commit comments

Comments
 (0)