|
18 | 18 | use std::ops::Deref; |
19 | 19 | use std::sync::Arc; |
20 | 20 |
|
21 | | -use crate::expressions::Column; |
| 21 | +use crate::expressions::{Column, Literal}; |
| 22 | +use crate::scalar_function::ScalarFunctionExpr; |
22 | 23 | use crate::utils::collect_columns; |
23 | 24 | use crate::PhysicalExpr; |
24 | 25 |
|
25 | | -use arrow::datatypes::{Field, Schema, SchemaRef}; |
| 26 | +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; |
26 | 27 | use datafusion_common::stats::{ColumnStatistics, Precision}; |
27 | 28 | use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; |
28 | | -use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result}; |
| 29 | +use datafusion_common::{ScalarValue, internal_datafusion_err, internal_err, plan_err, Result}; |
29 | 30 |
|
30 | 31 | use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; |
31 | 32 | use indexmap::IndexMap; |
@@ -638,9 +639,81 @@ impl ProjectionMapping { |
638 | 639 | None => Ok(Transformed::no(e)), |
639 | 640 | }) |
640 | 641 | .data()?; |
641 | | - map.entry(source_expr) |
| 642 | + map.entry(source_expr.clone()) |
642 | 643 | .or_default() |
643 | | - .push((target_expr, expr_idx)); |
| 644 | + .push((Arc::clone(&target_expr), expr_idx)); |
| 645 | + |
| 646 | + // For struct-producing functions (e.g. named_struct), decompose |
| 647 | + // into field-level mapping entries so that orderings propagate |
| 648 | + // through struct projections. For example, if the projection has |
| 649 | + // `named_struct('ticker', p.ticker, ...) AS details`, this adds: |
| 650 | + // p.ticker → get_field(col("details"), "ticker") |
| 651 | + // enabling the optimizer to know that sorting by |
| 652 | + // `details.ticker` is equivalent to sorting by `p.ticker`. |
| 653 | + if let Some(func_expr) = |
| 654 | + source_expr.as_any().downcast_ref::<ScalarFunctionExpr>() |
| 655 | + { |
| 656 | + let literal_args: Vec<Option<ScalarValue>> = func_expr |
| 657 | + .args() |
| 658 | + .iter() |
| 659 | + .map(|arg| { |
| 660 | + arg.as_any() |
| 661 | + .downcast_ref::<Literal>() |
| 662 | + .map(|l| l.value().clone()) |
| 663 | + }) |
| 664 | + .collect(); |
| 665 | + |
| 666 | + if let Some(field_mapping) = |
| 667 | + func_expr.fun().struct_field_mapping(&literal_args) |
| 668 | + { |
| 669 | + if let DataType::Struct(struct_fields) = |
| 670 | + func_expr.return_type() |
| 671 | + { |
| 672 | + for (accessor_args, source_arg_idx) in &field_mapping.fields |
| 673 | + { |
| 674 | + let value_expr = |
| 675 | + func_expr.args()[*source_arg_idx].clone(); |
| 676 | + |
| 677 | + // Build accessor args: [target_col, ...field_name_literals] |
| 678 | + let mut accessor_fn_args: Vec<Arc<dyn PhysicalExpr>> = |
| 679 | + vec![Arc::clone(&target_expr)]; |
| 680 | + accessor_fn_args.extend(accessor_args.iter().map( |
| 681 | + |sv| { |
| 682 | + Arc::new(Literal::new(sv.clone())) |
| 683 | + as Arc<dyn PhysicalExpr> |
| 684 | + }, |
| 685 | + )); |
| 686 | + |
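| 687 | + // Find the accessed field by name (first accessor argument) among the fields of the function's struct return type; no mapping entry is added when no field matches |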
| 688 | + let return_field = accessor_args |
| 689 | + .first() |
| 690 | + .and_then(|sv| sv.try_as_str().flatten()) |
| 691 | + .and_then(|field_name| { |
| 692 | + struct_fields |
| 693 | + .iter() |
| 694 | + .find(|f| f.name() == field_name) |
| 695 | + .cloned() |
| 696 | + }); |
| 697 | + |
| 698 | + if let Some(return_field) = return_field { |
| 699 | + let field_access_expr = |
| 700 | + Arc::new(ScalarFunctionExpr::new( |
| 701 | + field_mapping.field_accessor.name(), |
| 702 | + Arc::clone(&field_mapping.field_accessor), |
| 703 | + accessor_fn_args, |
| 704 | + return_field, |
| 705 | + Arc::new(func_expr.config_options().clone()), |
| 706 | + )) |
| 707 | + as Arc<dyn PhysicalExpr>; |
| 708 | + |
| 709 | + map.entry(value_expr) |
| 710 | + .or_default() |
| 711 | + .push((field_access_expr, expr_idx)); |
| 712 | + } |
| 713 | + } |
| 714 | + } |
| 715 | + } |
| 716 | + } |
644 | 717 | } |
645 | 718 | Ok(Self { map }) |
646 | 719 | } |
@@ -795,8 +868,10 @@ pub(crate) mod tests { |
795 | 868 | let data_type = source.data_type(input_schema)?; |
796 | 869 | let nullable = source.nullable(input_schema)?; |
797 | 870 | for (target, _) in targets.iter() { |
| 871 | + // Skip non-Column targets (e.g. struct field decomposition |
| 872 | + // entries which are ScalarFunctionExpr targets). |
798 | 873 | let Some(column) = target.as_any().downcast_ref::<Column>() else { |
799 | | - return plan_err!("Expects to have column"); |
| 874 | + continue; |
800 | 875 | }; |
801 | 876 | fields.push(Field::new(column.name(), data_type.clone(), nullable)); |
802 | 877 | } |
|
0 commit comments