Skip to content

Commit 089d888

Browse files
committed
feat: Propagate orderings through struct-producing projections (#39)
1 parent a6459ec commit 089d888

6 files changed

Lines changed: 318 additions & 7 deletions

File tree

datafusion/expr/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,9 @@ pub use udaf::{
125125
udaf_default_schema_name, udaf_default_window_function_display_name,
126126
udaf_default_window_function_schema_name,
127127
};
128-
pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
128+
pub use udf::{
129+
ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, StructFieldMapping,
130+
};
129131
pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl};
130132
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
131133

datafusion/expr/src/udf.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,25 @@ use std::fmt::Debug;
3636
use std::hash::{Hash, Hasher};
3737
use std::sync::Arc;
3838

39+
/// Describes how a struct-producing UDF's output fields correspond to its
40+
/// input arguments. This enables the optimizer to propagate orderings
41+
/// through struct projections (e.g., so that sorting by a struct field
42+
/// can be recognized as equivalent to sorting by the source column).
43+
///
44+
/// See [`ScalarUDFImpl::struct_field_mapping`] for details.
45+
pub struct StructFieldMapping {
46+
/// The UDF used to construct field access expressions on the output.
47+
/// For example, the `get_field` UDF for accessing struct fields.
48+
pub field_accessor: Arc<ScalarUDF>,
49+
/// For each output field: the literal arguments to pass to the
50+
/// `field_accessor` UDF (after the base expression), and the index
51+
/// of the corresponding input argument that produces the field's value.
52+
///
53+
/// For `named_struct('a', col1, 'b', col2)`, this would be:
54+
/// `[(["a"], 1), (["b"], 3)]` — field `"a"` comes from arg index 1.
55+
pub fields: Vec<(Vec<ScalarValue>, usize)>,
56+
}
57+
3958
/// Logical representation of a Scalar User Defined Function.
4059
///
4160
/// A scalar function produces a single row output for each row of input. This
@@ -344,6 +363,14 @@ impl ScalarUDF {
344363
self.inner.documentation()
345364
}
346365

366+
/// See [`ScalarUDFImpl::struct_field_mapping`] for more details.
367+
pub fn struct_field_mapping(
368+
&self,
369+
literal_args: &[Option<ScalarValue>],
370+
) -> Option<StructFieldMapping> {
371+
self.inner.struct_field_mapping(literal_args)
372+
}
373+
347374
/// Return true if this function is an async function
348375
pub fn as_async(&self) -> Option<&AsyncScalarUDF> {
349376
self.inner().as_any().downcast_ref::<AsyncScalarUDF>()
@@ -846,6 +873,25 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
846873
fn documentation(&self) -> Option<&Documentation> {
847874
None
848875
}
876+
877+
/// For struct-producing functions, return how output fields map to input
878+
/// arguments. This enables the optimizer to propagate orderings through
879+
/// struct projections.
880+
///
881+
/// `literal_args[i]` is `Some(value)` if argument `i` is a known literal,
882+
/// allowing extraction of field names from arguments like
883+
/// `named_struct('field_name', value, ...)`.
884+
///
885+
/// For example, `named_struct('a', col1, 'b', col2)` would return a
886+
/// mapping indicating that output field `'a'` (accessed via
887+
/// `get_field(output, 'a')`) corresponds to input argument `col1` at
888+
/// index 1, and field `'b'` corresponds to `col2` at index 3.
889+
fn struct_field_mapping(
890+
&self,
891+
_literal_args: &[Option<ScalarValue>],
892+
) -> Option<StructFieldMapping> {
893+
None
894+
}
849895
}
850896

851897
/// ScalarUDF that adds an alias to the underlying function. It is better to
@@ -964,6 +1010,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
9641010
fn documentation(&self) -> Option<&Documentation> {
9651011
self.inner.documentation()
9661012
}
1013+
1014+
fn struct_field_mapping(
1015+
&self,
1016+
literal_args: &[Option<ScalarValue>],
1017+
) -> Option<StructFieldMapping> {
1018+
self.inner.struct_field_mapping(literal_args)
1019+
}
9671020
}
9681021

9691022
#[cfg(test)]

datafusion/functions/src/core/named_struct.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,14 @@
1717

1818
use arrow::array::StructArray;
1919
use arrow::datatypes::{DataType, Field, FieldRef, Fields};
20-
use datafusion_common::{Result, exec_err, internal_err};
20+
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
2121
use datafusion_expr::{
22-
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
22+
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
23+
StructFieldMapping,
2324
};
2425
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
26+
27+
use super::getfield::GetFieldFunc;
2528
use datafusion_macros::user_doc;
2629
use std::any::Any;
2730
use std::sync::Arc;
@@ -177,4 +180,31 @@ impl ScalarUDFImpl for NamedStructFunc {
177180
fn documentation(&self) -> Option<&Documentation> {
178181
self.doc()
179182
}
183+
184+
fn struct_field_mapping(
185+
&self,
186+
literal_args: &[Option<ScalarValue>],
187+
) -> Option<StructFieldMapping> {
188+
if literal_args.is_empty() || !literal_args.len().is_multiple_of(2) {
189+
return None;
190+
}
191+
192+
let mut fields = Vec::with_capacity(literal_args.len() / 2);
193+
for (i, chunk) in literal_args.chunks(2).enumerate() {
194+
match chunk {
195+
[Some(ScalarValue::Utf8(Some(name))), _] => {
196+
fields.push((
197+
vec![ScalarValue::Utf8(Some(name.clone()))],
198+
i * 2 + 1, // index of the value argument
199+
));
200+
}
201+
_ => return None,
202+
}
203+
}
204+
205+
Some(StructFieldMapping {
206+
field_accessor: Arc::new(ScalarUDF::from(GetFieldFunc::new())),
207+
fields,
208+
})
209+
}
180210
}

datafusion/physical-expr/src/equivalence/properties/dependency.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1528,4 +1528,102 @@ mod tests {
15281528

15291529
Ok(())
15301530
}
1531+
1532+
/// Test that orderings propagate through struct-producing projections.
1533+
///
1534+
/// When a projection creates a struct via `named_struct('a', col_a, ...)`,
1535+
/// the output should preserve the ordering of `col_a` as an ordering on
1536+
/// `get_field(col("s"), "a")`. This enables sort elimination when the
1537+
/// framework sorts by a struct field that corresponds to an already-sorted
1538+
/// input column.
1539+
#[test]
1540+
fn test_ordering_propagation_through_named_struct() -> Result<()> {
1541+
use crate::expressions::Literal;
1542+
use datafusion_common::ScalarValue;
1543+
use datafusion_functions::core::{get_field, named_struct};
1544+
1545+
let input_schema = Arc::new(Schema::new(vec![
1546+
Field::new("a", DataType::Int32, true),
1547+
Field::new("b", DataType::Int32, true),
1548+
]));
1549+
1550+
let col_a = col("a", &input_schema)?;
1551+
let col_b = col("b", &input_schema)?;
1552+
let config = Arc::new(ConfigOptions::new());
1553+
1554+
// Build: named_struct('a', col_a, 'b', col_b) AS s
1555+
let named_struct_udf = named_struct();
1556+
let named_struct_expr = Arc::new(ScalarFunctionExpr::new(
1557+
"named_struct",
1558+
named_struct_udf,
1559+
vec![
1560+
Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))),
1561+
Arc::clone(&col_a),
1562+
Arc::new(Literal::new(ScalarValue::Utf8(Some("b".to_string())))),
1563+
Arc::clone(&col_b),
1564+
],
1565+
Arc::new(Field::new(
1566+
"named_struct",
1567+
DataType::Struct(
1568+
vec![
1569+
Field::new("a", DataType::Int32, true),
1570+
Field::new("b", DataType::Int32, true),
1571+
]
1572+
.into(),
1573+
),
1574+
true,
1575+
)),
1576+
Arc::clone(&config),
1577+
)) as Arc<dyn PhysicalExpr>;
1578+
1579+
// Projection: named_struct(...) AS s
1580+
let proj_exprs = vec![(named_struct_expr, "s".to_string())];
1581+
let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?;
1582+
1583+
// Input is ordered by [a ASC]
1584+
let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema));
1585+
let sort_a = PhysicalSortExpr::new(
1586+
Arc::clone(&col_a),
1587+
SortOptions {
1588+
descending: false,
1589+
nulls_first: false,
1590+
},
1591+
);
1592+
input_properties.add_orderings([vec![sort_a]]);
1593+
1594+
// Project through the named_struct
1595+
let out_schema = output_schema(&projection_mapping, &input_schema)?;
1596+
let out_properties = input_properties.project(&projection_mapping, out_schema);
1597+
1598+
// Build the sort expression: get_field(col("s"), "a")
1599+
// This is what the framework would generate for ORDER BY s.a
1600+
let get_field_udf = get_field();
1601+
let col_s = Arc::new(Column::new("s", 0)) as Arc<dyn PhysicalExpr>;
1602+
let get_field_expr = Arc::new(ScalarFunctionExpr::new(
1603+
"get_field",
1604+
get_field_udf,
1605+
vec![
1606+
Arc::clone(&col_s),
1607+
Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))),
1608+
],
1609+
Arc::new(Field::new("a", DataType::Int32, true)),
1610+
Arc::clone(&config),
1611+
)) as Arc<dyn PhysicalExpr>;
1612+
1613+
let sort_get_field_a = PhysicalSortExpr::new(
1614+
get_field_expr,
1615+
SortOptions {
1616+
descending: false,
1617+
nulls_first: false,
1618+
},
1619+
);
1620+
1621+
// The output should satisfy ordering by get_field(s, "a")
1622+
assert!(
1623+
out_properties.ordering_satisfy(vec![sort_get_field_a])?,
1624+
"Output should be ordered by get_field(s, 'a') since input is ordered by col_a"
1625+
);
1626+
1627+
Ok(())
1628+
}
15311629
}

datafusion/physical-expr/src/projection.rs

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use std::sync::Arc;
2222

2323
use crate::PhysicalExpr;
2424
use crate::expressions::{Column, Literal};
25+
use crate::scalar_function::ScalarFunctionExpr;
2526
use crate::utils::collect_columns;
2627

2728
use arrow::array::{RecordBatch, RecordBatchOptions};
28-
use arrow::datatypes::{Field, Schema, SchemaRef};
29+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2930
use datafusion_common::stats::{ColumnStatistics, Precision};
3031
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3132
use datafusion_common::{
@@ -952,9 +953,72 @@ impl ProjectionMapping {
952953
None => Ok(Transformed::no(e)),
953954
})
954955
.data()?;
955-
map.entry(source_expr)
956+
map.entry(Arc::clone(&source_expr))
956957
.or_default()
957-
.push((target_expr, expr_idx));
958+
.push((Arc::clone(&target_expr), expr_idx));
959+
960+
// For struct-producing functions (e.g. named_struct), decompose
961+
// into field-level mapping entries so that orderings propagate
962+
// through struct projections. For example, if the projection has
963+
// `named_struct('ticker', p.ticker, ...) AS details`, this adds:
964+
// p.ticker → get_field(col("details"), "ticker")
965+
// enabling the optimizer to know that sorting by
966+
// `details.ticker` is equivalent to sorting by `p.ticker`.
967+
if let Some(func_expr) =
968+
source_expr.as_any().downcast_ref::<ScalarFunctionExpr>()
969+
{
970+
let literal_args: Vec<Option<ScalarValue>> = func_expr
971+
.args()
972+
.iter()
973+
.map(|arg| {
974+
arg.as_any()
975+
.downcast_ref::<Literal>()
976+
.map(|l| l.value().clone())
977+
})
978+
.collect();
979+
980+
if let Some(field_mapping) =
981+
func_expr.fun().struct_field_mapping(&literal_args)
982+
&& let DataType::Struct(struct_fields) = func_expr.return_type()
983+
{
984+
for (accessor_args, source_arg_idx) in &field_mapping.fields {
985+
let value_expr = Arc::clone(&func_expr.args()[*source_arg_idx]);
986+
987+
// Build accessor args: [target_col, ...field_name_literals]
988+
let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> =
989+
vec![Arc::clone(&target_expr)];
990+
accessor_fn_args.extend(accessor_args.iter().map(|sv| {
991+
Arc::new(Literal::new(sv.clone())) as Arc<dyn PhysicalExpr>
992+
}));
993+
994+
// Look up the field's return type from the struct schema
995+
let return_field = accessor_args
996+
.first()
997+
.and_then(|sv| sv.try_as_str().flatten())
998+
.and_then(|field_name| {
999+
struct_fields
1000+
.iter()
1001+
.find(|f| f.name() == field_name)
1002+
.cloned()
1003+
});
1004+
1005+
if let Some(return_field) = return_field {
1006+
let field_access_expr = Arc::new(ScalarFunctionExpr::new(
1007+
field_mapping.field_accessor.name(),
1008+
Arc::clone(&field_mapping.field_accessor),
1009+
accessor_fn_args,
1010+
return_field,
1011+
Arc::new(func_expr.config_options().clone()),
1012+
))
1013+
as Arc<dyn PhysicalExpr>;
1014+
1015+
map.entry(value_expr)
1016+
.or_default()
1017+
.push((field_access_expr, expr_idx));
1018+
}
1019+
}
1020+
}
1021+
}
9581022
}
9591023
Ok(Self { map })
9601024
}
@@ -1110,8 +1174,10 @@ pub(crate) mod tests {
11101174
let data_type = source.data_type(input_schema)?;
11111175
let nullable = source.nullable(input_schema)?;
11121176
for (target, _) in targets.iter() {
1177+
// Skip non-Column targets (e.g. struct field decomposition
1178+
// entries which are ScalarFunctionExpr targets).
11131179
let Some(column) = target.as_any().downcast_ref::<Column>() else {
1114-
return plan_err!("Expects to have column");
1180+
continue;
11151181
};
11161182
fields.push(Field::new(column.name(), data_type.clone(), nullable));
11171183
}

0 commit comments

Comments
 (0)