|
20 | 20 | use std::ops::Deref; |
21 | 21 | use std::sync::Arc; |
22 | 22 |
|
| 23 | +use crate::scalar_function::ScalarFunctionExpr; |
23 | 24 | use crate::PhysicalExpr; |
24 | 25 | use crate::expressions::{Column, Literal}; |
25 | 26 | use crate::utils::collect_columns; |
26 | 27 |
|
27 | 28 | use arrow::array::{RecordBatch, RecordBatchOptions}; |
28 | | -use arrow::datatypes::{Field, Schema, SchemaRef}; |
| 29 | +use arrow::datatypes::{Field, Schema, SchemaRef, DataType}; |
29 | 30 | use datafusion_common::stats::{ColumnStatistics, Precision}; |
30 | 31 | use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
31 | 32 | use datafusion_common::{ |
@@ -952,9 +953,73 @@ impl ProjectionMapping { |
952 | 953 | None => Ok(Transformed::no(e)), |
953 | 954 | }) |
954 | 955 | .data()?; |
955 | | - map.entry(source_expr) |
| 956 | + map.entry(Arc::clone(&source_expr)) |
956 | 957 | .or_default() |
957 | | - .push((target_expr, expr_idx)); |
| 958 | + .push((Arc::clone(&target_expr), expr_idx)); |
| 959 | + |
| 960 | + // For struct-producing functions (e.g. named_struct), decompose |
| 961 | + // into field-level mapping entries so that orderings propagate |
| 962 | + // through struct projections. For example, if the projection has |
| 963 | + // `named_struct('ticker', p.ticker, ...) AS details`, this adds: |
| 964 | + // p.ticker → get_field(col("details"), "ticker") |
| 965 | + // enabling the optimizer to know that sorting by |
| 966 | + // `details.ticker` is equivalent to sorting by `p.ticker`. |
| 967 | + if let Some(func_expr) = |
| 968 | + source_expr.as_any().downcast_ref::<ScalarFunctionExpr>() |
| 969 | + { |
| 970 | + let literal_args: Vec<Option<ScalarValue>> = func_expr |
| 971 | + .args() |
| 972 | + .iter() |
| 973 | + .map(|arg| { |
| 974 | + arg.as_any() |
| 975 | + .downcast_ref::<Literal>() |
| 976 | + .map(|l| l.value().clone()) |
| 977 | + }) |
| 978 | + .collect(); |
| 979 | + |
| 980 | + if let Some(field_mapping) = |
| 981 | + func_expr.fun().struct_field_mapping(&literal_args) |
| 982 | + && let DataType::Struct(struct_fields) = func_expr.return_type() { |
| 983 | + for (accessor_args, source_arg_idx) in &field_mapping.fields { |
| 984 | + let value_expr = |
| 985 | + Arc::clone(&func_expr.args()[*source_arg_idx]); |
| 986 | + |
| 987 | + // Build accessor args: [target_col, ...field_name_literals] |
| 988 | + let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> = |
| 989 | + vec![Arc::clone(&target_expr)]; |
| 990 | + accessor_fn_args.extend(accessor_args.iter().map(|sv| { |
| 991 | + Arc::new(Literal::new(sv.clone())) |
| 992 | + as Arc<dyn PhysicalExpr> |
| 993 | + })); |
| 994 | + |
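| 995 | + // Resolve the accessor's return field: match the first accessor argument (the field name, as a string) against the struct's fields; if it is not a string or no field matches, no entry is added |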
| 996 | + let return_field = accessor_args |
| 997 | + .first() |
| 998 | + .and_then(|sv| sv.try_as_str().flatten()) |
| 999 | + .and_then(|field_name| { |
| 1000 | + struct_fields |
| 1001 | + .iter() |
| 1002 | + .find(|f| f.name() == field_name) |
| 1003 | + .cloned() |
| 1004 | + }); |
| 1005 | + |
| 1006 | + if let Some(return_field) = return_field { |
| 1007 | + let field_access_expr = Arc::new(ScalarFunctionExpr::new( |
| 1008 | + field_mapping.field_accessor.name(), |
| 1009 | + Arc::clone(&field_mapping.field_accessor), |
| 1010 | + accessor_fn_args, |
| 1011 | + return_field, |
| 1012 | + Arc::new(func_expr.config_options().clone()), |
| 1013 | + )) |
| 1014 | + as Arc<dyn PhysicalExpr>; |
| 1015 | + |
| 1016 | + map.entry(value_expr) |
| 1017 | + .or_default() |
| 1018 | + .push((field_access_expr, expr_idx)); |
| 1019 | + } |
| 1020 | + } |
| 1021 | + } |
| 1022 | + } |
958 | 1023 | } |
959 | 1024 | Ok(Self { map }) |
960 | 1025 | } |
@@ -1110,8 +1175,10 @@ pub(crate) mod tests { |
1110 | 1175 | let data_type = source.data_type(input_schema)?; |
1111 | 1176 | let nullable = source.nullable(input_schema)?; |
1112 | 1177 | for (target, _) in targets.iter() { |
| 1178 | + // Skip non-Column targets (e.g. struct field decomposition |
| 1179 | + // entries which are ScalarFunctionExpr targets). |
1113 | 1180 | let Some(column) = target.as_any().downcast_ref::<Column>() else { |
1114 | | - return plan_err!("Expects to have column"); |
| 1181 | + continue; |
1115 | 1182 | }; |
1116 | 1183 | fields.push(Field::new(column.name(), data_type.clone(), nullable)); |
1117 | 1184 | } |
|
0 commit comments