Skip to content

Commit 944bac2

Browse files
rkrishn7 and alamb authored
feat: Propagate orderings through struct-producing projections (#21218)
## Which issue does this PR close?

- Closes #21217

## What changes are included in this PR?

- Adds `ScalarUDFImpl::struct_field_mapping`
- Adds logic in `ProjectionMapping` to decompose struct-producing functions into their field-level mapping entries so that orderings propagate through struct projections
- Adds unit tests/SLT

## Are these changes tested?

Yes.

## Are there any user-facing changes?

N/A

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent c74ed91 commit 944bac2

File tree

6 files changed

+312
-8
lines changed

6 files changed

+312
-8
lines changed

datafusion/expr/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ pub use udaf::{
126126
udaf_default_schema_name, udaf_default_window_function_display_name,
127127
udaf_default_window_function_schema_name,
128128
};
129-
pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
129+
pub use udf::{
130+
ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, StructFieldMapping,
131+
};
130132
pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl};
131133
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
132134

datafusion/expr/src/udf.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,25 @@ use std::fmt::Debug;
3838
use std::hash::{Hash, Hasher};
3939
use std::sync::Arc;
4040

41+
/// Describes how a struct-producing UDF's output fields correspond to its
42+
/// input arguments. This enables the optimizer to propagate orderings
43+
/// through struct projections (e.g., so that sorting by a struct field
44+
/// can be recognized as equivalent to sorting by the source column).
45+
///
46+
/// See [`ScalarUDFImpl::struct_field_mapping`] for details.
47+
pub struct StructFieldMapping {
48+
/// The UDF used to construct field access expressions on the output.
49+
/// For example, the `get_field` UDF for accessing struct fields.
50+
pub field_accessor: Arc<ScalarUDF>,
51+
/// For each output field: the literal arguments to pass to the
52+
/// `field_accessor` UDF (after the base expression), and the index
53+
/// of the corresponding input argument that produces the field's value.
54+
///
55+
/// For `named_struct('a', col1, 'b', col2)`, this would be:
56+
/// `[(["a"], 1), (["b"], 3)]` — field `"a"` comes from arg index 1.
57+
pub fields: Vec<(Vec<ScalarValue>, usize)>,
58+
}
59+
4160
/// Logical representation of a Scalar User Defined Function.
4261
///
4362
/// A scalar function produces a single row output for each row of input. This
@@ -305,6 +324,14 @@ impl ScalarUDF {
305324
self.inner.evaluate_bounds(inputs)
306325
}
307326

327+
/// See [`ScalarUDFImpl::struct_field_mapping`] for more details.
328+
pub fn struct_field_mapping(
329+
&self,
330+
literal_args: &[Option<ScalarValue>],
331+
) -> Option<StructFieldMapping> {
332+
self.inner.struct_field_mapping(literal_args)
333+
}
334+
308335
/// Updates bounds for child expressions, given a known interval for this
309336
/// function. This is used to propagate constraints down through an expression
310337
/// tree.
@@ -961,6 +988,25 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync + Any {
961988
not_impl_err!("Function {} does not implement coerce_types", self.name())
962989
}
963990

991+
/// For struct-producing functions, return how output fields map to input
992+
/// arguments. This enables the optimizer to propagate orderings through
993+
/// struct projections.
994+
///
995+
/// `literal_args[i]` is `Some(value)` if argument `i` is a known literal,
996+
/// allowing extraction of field names from arguments like
997+
/// `named_struct('field_name', value, ...)`.
998+
///
999+
/// For example, `named_struct('a', col1, 'b', col2)` would return a
1000+
/// mapping indicating that output field `'a'` (accessed via
1001+
/// `get_field(output, 'a')`) corresponds to input argument `col1` at
1002+
/// index 1, and field `'b'` corresponds to `col2` at index 3.
1003+
fn struct_field_mapping(
1004+
&self,
1005+
_literal_args: &[Option<ScalarValue>],
1006+
) -> Option<StructFieldMapping> {
1007+
None
1008+
}
1009+
9641010
/// Returns the documentation for this Scalar UDF.
9651011
///
9661012
/// Documentation can be accessed programmatically as well as generating
@@ -1109,6 +1155,13 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
11091155
self.inner.propagate_constraints(interval, inputs)
11101156
}
11111157

1158+
fn struct_field_mapping(
1159+
&self,
1160+
literal_args: &[Option<ScalarValue>],
1161+
) -> Option<StructFieldMapping> {
1162+
self.inner.struct_field_mapping(literal_args)
1163+
}
1164+
11121165
fn output_ordering(&self, inputs: &[ExprProperties]) -> Result<SortProperties> {
11131166
self.inner.output_ordering(inputs)
11141167
}

datafusion/functions/src/core/named_struct.rs

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use super::getfield::GetFieldFunc;
1819
use arrow::array::StructArray;
1920
use arrow::datatypes::{DataType, Field, FieldRef, Fields};
20-
use datafusion_common::{Result, exec_err, internal_err};
21+
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
2122
use datafusion_expr::{
22-
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
23+
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
24+
StructFieldMapping,
2325
};
2426
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2527
use datafusion_macros::user_doc;
@@ -174,4 +176,31 @@ impl ScalarUDFImpl for NamedStructFunc {
174176
fn documentation(&self) -> Option<&Documentation> {
175177
self.doc()
176178
}
179+
180+
fn struct_field_mapping(
181+
&self,
182+
literal_args: &[Option<ScalarValue>],
183+
) -> Option<StructFieldMapping> {
184+
if literal_args.is_empty() || !literal_args.len().is_multiple_of(2) {
185+
return None;
186+
}
187+
188+
let mut fields = Vec::with_capacity(literal_args.len() / 2);
189+
for (i, chunk) in literal_args.chunks(2).enumerate() {
190+
match chunk {
191+
[Some(ScalarValue::Utf8(Some(name))), _] => {
192+
fields.push((
193+
vec![ScalarValue::Utf8(Some(name.clone()))],
194+
i * 2 + 1, // index of the value argument
195+
));
196+
}
197+
_ => return None,
198+
}
199+
}
200+
201+
Some(StructFieldMapping {
202+
field_accessor: Arc::new(ScalarUDF::from(GetFieldFunc::new())),
203+
fields,
204+
})
205+
}
177206
}

datafusion/physical-expr/src/equivalence/properties/dependency.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,4 +1564,102 @@ mod tests {
15641564

15651565
Ok(())
15661566
}
1567+
1568+
/// Test that orderings propagate through struct-producing projections.
1569+
///
1570+
/// When a projection creates a struct via `named_struct('a', col_a, ...)`,
1571+
/// the output should preserve the ordering of `col_a` as an ordering on
1572+
/// `get_field(col("s"), "a")`. This enables sort elimination when the
1573+
/// framework sorts by a struct field that corresponds to an already-sorted
1574+
/// input column.
1575+
#[test]
1576+
fn test_ordering_propagation_through_named_struct() -> Result<()> {
1577+
use crate::expressions::Literal;
1578+
use datafusion_common::ScalarValue;
1579+
use datafusion_functions::core::{get_field, named_struct};
1580+
1581+
let input_schema = Arc::new(Schema::new(vec![
1582+
Field::new("a", DataType::Int32, true),
1583+
Field::new("b", DataType::Int32, true),
1584+
]));
1585+
1586+
let col_a = col("a", &input_schema)?;
1587+
let col_b = col("b", &input_schema)?;
1588+
let config = Arc::new(ConfigOptions::new());
1589+
1590+
// Build: named_struct('a', col_a, 'b', col_b) AS s
1591+
let named_struct_udf = named_struct();
1592+
let named_struct_expr = Arc::new(ScalarFunctionExpr::new(
1593+
"named_struct",
1594+
named_struct_udf,
1595+
vec![
1596+
Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))),
1597+
Arc::clone(&col_a),
1598+
Arc::new(Literal::new(ScalarValue::Utf8(Some("b".to_string())))),
1599+
Arc::clone(&col_b),
1600+
],
1601+
Arc::new(Field::new(
1602+
"named_struct",
1603+
DataType::Struct(
1604+
vec![
1605+
Field::new("a", DataType::Int32, true),
1606+
Field::new("b", DataType::Int32, true),
1607+
]
1608+
.into(),
1609+
),
1610+
true,
1611+
)),
1612+
Arc::clone(&config),
1613+
)) as Arc<dyn PhysicalExpr>;
1614+
1615+
// Projection: named_struct(...) AS s
1616+
let proj_exprs = vec![(named_struct_expr, "s".to_string())];
1617+
let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?;
1618+
1619+
// Input is ordered by [a ASC]
1620+
let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema));
1621+
let sort_a = PhysicalSortExpr::new(
1622+
Arc::clone(&col_a),
1623+
SortOptions {
1624+
descending: false,
1625+
nulls_first: false,
1626+
},
1627+
);
1628+
input_properties.add_orderings([vec![sort_a]]);
1629+
1630+
// Project through the named_struct
1631+
let out_schema = output_schema(&projection_mapping, &input_schema)?;
1632+
let out_properties = input_properties.project(&projection_mapping, out_schema);
1633+
1634+
// Build the sort expression: get_field(col("s"), "a")
1635+
// This is what the framework would generate for ORDER BY s.a
1636+
let get_field_udf = get_field();
1637+
let col_s = Arc::new(Column::new("s", 0)) as Arc<dyn PhysicalExpr>;
1638+
let get_field_expr = Arc::new(ScalarFunctionExpr::new(
1639+
"get_field",
1640+
get_field_udf,
1641+
vec![
1642+
Arc::clone(&col_s),
1643+
Arc::new(Literal::new(ScalarValue::Utf8(Some("a".to_string())))),
1644+
],
1645+
Arc::new(Field::new("a", DataType::Int32, true)),
1646+
Arc::clone(&config),
1647+
)) as Arc<dyn PhysicalExpr>;
1648+
1649+
let sort_get_field_a = PhysicalSortExpr::new(
1650+
get_field_expr,
1651+
SortOptions {
1652+
descending: false,
1653+
nulls_first: false,
1654+
},
1655+
);
1656+
1657+
// The output should satisfy ordering by get_field(s, "a")
1658+
assert!(
1659+
out_properties.ordering_satisfy(vec![sort_get_field_a])?,
1660+
"Output should be ordered by get_field(s, 'a') since input is ordered by col_a"
1661+
);
1662+
1663+
Ok(())
1664+
}
15671665
}

datafusion/physical-expr/src/projection.rs

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ use std::sync::Arc;
2222

2323
use crate::PhysicalExpr;
2424
use crate::expressions::{Column, Literal};
25+
use crate::scalar_function::ScalarFunctionExpr;
2526
use crate::utils::collect_columns;
2627

2728
use arrow::array::{RecordBatch, RecordBatchOptions};
28-
use arrow::datatypes::{Field, Schema, SchemaRef};
29+
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2930
use datafusion_common::stats::{ColumnStatistics, Precision};
3031
use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
3132
use datafusion_common::{
@@ -1063,9 +1064,72 @@ impl ProjectionMapping {
10631064
None => Ok(Transformed::no(e)),
10641065
})
10651066
.data()?;
1066-
map.entry(source_expr)
1067+
map.entry(Arc::clone(&source_expr))
10671068
.or_default()
1068-
.push((target_expr, expr_idx));
1069+
.push((Arc::clone(&target_expr), expr_idx));
1070+
1071+
// For struct-producing functions (e.g. named_struct), decompose
1072+
// into field-level mapping entries so that orderings propagate
1073+
// through struct projections. For example, if the projection has
1074+
// `named_struct('ticker', p.ticker, ...) AS details`, this adds:
1075+
// p.ticker → get_field(col("details"), "ticker")
1076+
// enabling the optimizer to know that sorting by
1077+
// `details.ticker` is equivalent to sorting by `p.ticker`.
1078+
if let Some(func_expr) =
1079+
source_expr.as_any().downcast_ref::<ScalarFunctionExpr>()
1080+
{
1081+
let literal_args: Vec<Option<ScalarValue>> = func_expr
1082+
.args()
1083+
.iter()
1084+
.map(|arg| {
1085+
arg.as_any()
1086+
.downcast_ref::<Literal>()
1087+
.map(|l| l.value().clone())
1088+
})
1089+
.collect();
1090+
1091+
if let Some(field_mapping) =
1092+
func_expr.fun().struct_field_mapping(&literal_args)
1093+
&& let DataType::Struct(struct_fields) = func_expr.return_type()
1094+
{
1095+
for (accessor_args, source_arg_idx) in &field_mapping.fields {
1096+
let value_expr = Arc::clone(&func_expr.args()[*source_arg_idx]);
1097+
1098+
// Build accessor args: [target_col, ...field_name_literals]
1099+
let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> =
1100+
vec![Arc::clone(&target_expr)];
1101+
accessor_fn_args.extend(accessor_args.iter().map(|sv| {
1102+
Arc::new(Literal::new(sv.clone())) as Arc<dyn PhysicalExpr>
1103+
}));
1104+
1105+
// Look up the field's return type from the struct schema
1106+
let return_field = accessor_args
1107+
.first()
1108+
.and_then(|sv| sv.try_as_str().flatten())
1109+
.and_then(|field_name| {
1110+
struct_fields
1111+
.iter()
1112+
.find(|f| f.name() == field_name)
1113+
.cloned()
1114+
});
1115+
1116+
if let Some(return_field) = return_field {
1117+
let field_access_expr = Arc::new(ScalarFunctionExpr::new(
1118+
field_mapping.field_accessor.name(),
1119+
Arc::clone(&field_mapping.field_accessor),
1120+
accessor_fn_args,
1121+
return_field,
1122+
Arc::new(func_expr.config_options().clone()),
1123+
))
1124+
as Arc<dyn PhysicalExpr>;
1125+
1126+
map.entry(value_expr)
1127+
.or_default()
1128+
.push((field_access_expr, expr_idx));
1129+
}
1130+
}
1131+
}
1132+
}
10691133
}
10701134
Ok(Self { map })
10711135
}
@@ -1219,8 +1283,10 @@ pub(crate) mod tests {
12191283
let data_type = source.data_type(input_schema)?;
12201284
let nullable = source.nullable(input_schema)?;
12211285
for (target, _) in targets.iter() {
1286+
// Skip non-Column targets (e.g. struct field decomposition
1287+
// entries which are ScalarFunctionExpr targets).
12221288
let Some(column) = target.as_any().downcast_ref::<Column>() else {
1223-
return plan_err!("Expects to have column");
1289+
continue;
12241290
};
12251291
fields.push(Field::new(column.name(), data_type.clone(), nullable));
12261292
}

0 commit comments

Comments
 (0)