1515// specific language governing permissions and limitations
1616// under the License.
1717
18- use std:: sync:: Arc ;
18+ use std:: sync:: { Arc , OnceLock } ;
1919
2020use arrow:: array:: {
2121 Array , BooleanArray , Capacities , MutableArrayData , Scalar , cast:: AsArray , make_array,
@@ -37,6 +37,9 @@ use datafusion_expr::{
3737} ;
3838use datafusion_macros:: user_doc;
3939
40+ use super :: named_struct:: NamedStructFunc ;
41+ use super :: r#struct:: StructFunc ;
42+
4043#[ user_doc(
4144 doc_section( label = "Other Functions" ) ,
4245 description = r#"Returns a field within a map or a struct with the given key.
@@ -249,6 +252,120 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
249252 }
250253}
251254
255+ /// The shared `get_field` UDF, reused whenever simplification needs to build a
256+ /// fresh `get_field` node (e.g. re-wrapping the remaining access path).
257+ fn get_field_udf ( ) -> Arc < ScalarUDF > {
258+ static GET_FIELD_UDF : OnceLock < Arc < ScalarUDF > > = OnceLock :: new ( ) ;
259+ Arc :: clone (
260+ GET_FIELD_UDF
261+ . get_or_init ( || Arc :: new ( ScalarUDF :: new_from_impl ( GetFieldFunc :: new ( ) ) ) ) ,
262+ )
263+ }
264+
265+ /// Try to simplify a `get_field` call whose base is an inline struct
266+ /// constructor by resolving the field access at plan time.
267+ ///
268+ /// Handles both struct constructors:
269+ /// * `named_struct('a', x, 'b', y)` — fields are looked up by name.
270+ /// * `struct(x, y)` — fields are positional and named `c0`, `c1`, ...
271+ ///
272+ /// For example:
273+ /// * `get_field(named_struct('min', a, 'max', b), 'max')` => `b`
274+ /// * `get_field(struct(a, b), 'c1')` => `b`
275+ ///
276+ /// `args` is the (already flattened) argument list of the `get_field` call:
277+ /// `[base, field_name, rest_of_path...]`. When extra path elements remain
278+ /// after resolving the first one (`get_field(named_struct('s', inner), 's', 'k')`),
279+ /// the resolved value is re-wrapped in a `get_field` call for the remaining
280+ /// path so the simplifier can recurse into it on the next pass.
281+ ///
282+ /// Returns `None` — leaving the expression untouched — whenever the rewrite
283+ /// cannot be proven safe, e.g. a non-literal field name, a `named_struct`
284+ /// with a non-literal field name (which might shadow the requested field at
285+ /// runtime), or a field the constructor does not produce.
286+ ///
287+ /// Replacing the access with the selected field expression drops the
288+ /// expressions for the other (unaccessed) fields, so they are no longer
289+ /// evaluated — e.g. `get_field(named_struct('a', 1/0, 'b', x), 'b')` becomes
290+ /// `x` and the `1/0` is never evaluated. This is intentional and matches the
291+ /// optimizer's contract for immutable expressions: a simplification may drop
292+ /// sub-expressions whose value is not observed.
293+ fn simplify_get_field_over_struct_constructor ( args : & [ Expr ] ) -> Option < Expr > {
294+ let [ base, field_name, rest @ ..] = args else {
295+ return None ;
296+ } ;
297+
298+ // The accessed field name must be a non-empty string literal.
299+ let Expr :: Literal ( field_name, _) = field_name else {
300+ return None ;
301+ } ;
302+ let field_name = field_name
303+ . try_as_str ( )
304+ . flatten ( )
305+ . filter ( |s| !s. is_empty ( ) ) ?;
306+
307+ let Expr :: ScalarFunction ( ScalarFunction {
308+ func,
309+ args : ctor_args,
310+ } ) = base
311+ else {
312+ return None ;
313+ } ;
314+
315+ let value = if func. inner ( ) . is :: < NamedStructFunc > ( ) {
316+ // named_struct(name1, value1, name2, value2, ...)
317+ if !ctor_args. len ( ) . is_multiple_of ( 2 ) {
318+ return None ;
319+ }
320+ let mut matched = None ;
321+ for pair in ctor_args. chunks_exact ( 2 ) {
322+ // Every name must be a literal string: a non-literal name appearing
323+ // *before* the first match could evaluate to `field_name` at runtime
324+ // and become the real first match (Arrow's `column_by_name` returns
325+ // the first match), so we cannot resolve the access.
326+ //
327+ // We conservatively bail on *any* non-literal name. Once a literal
328+ // match has been found, a later non-literal name is in fact harmless
329+ // — it can never precede the first match — so bailing there is a
330+ // deliberate approximation we accept to keep this check simple, not a
331+ // correctness requirement.
332+ let Expr :: Literal ( name, _) = & pair[ 0 ] else {
333+ return None ;
334+ } ;
335+ let name = name. try_as_str ( ) . flatten ( ) ?;
336+ // `column_by_name` resolves to the first match, so do the same.
337+ if matched. is_none ( ) && name == field_name {
338+ matched = Some ( & pair[ 1 ] ) ;
339+ }
340+ }
341+ matched?. clone ( )
342+ } else if func. inner ( ) . is :: < StructFunc > ( ) {
343+ // struct(value0, value1, ...) produces fields named c0, c1, ...
344+ let index: usize = field_name. strip_prefix ( 'c' ) ?. parse ( ) . ok ( ) ?;
345+ // Reject non-canonical spellings (e.g. "c01") that name no real field.
346+ if format ! ( "c{index}" ) != field_name {
347+ return None ;
348+ }
349+ ctor_args. get ( index) ?. clone ( )
350+ } else {
351+ return None ;
352+ } ;
353+
354+ if rest. is_empty ( ) {
355+ return Some ( value) ;
356+ }
357+
358+ // Remaining path elements: re-wrap as get_field(value, rest...) and let
359+ // the simplifier resolve the rest on a subsequent pass.
360+ let mut new_args = Vec :: with_capacity ( rest. len ( ) + 1 ) ;
361+ new_args. push ( value) ;
362+ new_args. extend_from_slice ( rest) ;
363+ Some ( Expr :: ScalarFunction ( ScalarFunction :: new_udf (
364+ get_field_udf ( ) ,
365+ new_args,
366+ ) ) )
367+ }
368+
252369impl GetFieldFunc {
253370 pub fn new ( ) -> Self {
254371 Self {
@@ -479,14 +596,12 @@ impl ScalarUDFImpl for GetFieldFunc {
479596
480597 // Flatten all nested get_field calls in a single pass
481598 // Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c)
482-
483- // Collect path arguments from all nested levels
484- let mut path_args_stack = Vec :: new ( ) ;
599+ //
600+ // `path_args_stack` collects each level's field-name arguments,
601+ // outermost first; it is reversed below to restore access order.
602+ let mut path_args_stack = vec ! [ & args[ 1 ..] ] ;
485603 let mut current_expr = & args[ 0 ] ;
486604
487- // Push the outermost path arguments first
488- path_args_stack. push ( & args[ 1 ..] ) ;
489-
490605 // Walk down the chain of nested get_field calls
491606 let base_expr = loop {
492607 if let Expr :: ScalarFunction ( ScalarFunction {
@@ -506,28 +621,30 @@ impl ScalarUDFImpl for GetFieldFunc {
506621 break current_expr;
507622 } ;
508623
509- // If no nested get_field calls were found, return original
510- if path_args_stack. len ( ) == args. len ( ) - 1 {
511- return Ok ( ExprSimplifyResult :: Original ( args) ) ;
512- }
624+ // Whether any nested get_field calls were collapsed above.
625+ let did_flatten = path_args_stack. len ( ) > 1 ;
513626
514- // If we found any nested get_field calls, flatten them
515- // Build merged args: [base, ...all_path_args_in_correct_order]
627+ // Build merged args: [base, ...all path args in access order].
628+ // The stack holds path slices outermost-first, so iterate in reverse.
516629 let mut merged_args = vec ! [ base_expr. clone( ) ] ;
517-
518- // Add path args in reverse order (innermost to outermost)
519- // Stack is: [outermost_paths, ..., innermost_paths]
520- // We want: [base, innermost_paths, ..., outermost_paths]
521630 for path_slice in path_args_stack. iter ( ) . rev ( ) {
522631 merged_args. extend_from_slice ( path_slice) ;
523632 }
524633
525- Ok ( ExprSimplifyResult :: Simplified ( Expr :: ScalarFunction (
526- ScalarFunction :: new_udf (
527- Arc :: new ( ScalarUDF :: new_from_impl ( GetFieldFunc :: new ( ) ) ) ,
528- merged_args,
529- ) ,
530- ) ) )
634+ // Resolve field accesses against an inline struct constructor:
635+ // get_field(named_struct('min', a, 'max', b), 'max') => b
636+ if let Some ( simplified) = simplify_get_field_over_struct_constructor ( & merged_args)
637+ {
638+ return Ok ( ExprSimplifyResult :: Simplified ( simplified) ) ;
639+ }
640+
641+ if did_flatten {
642+ return Ok ( ExprSimplifyResult :: Simplified ( Expr :: ScalarFunction (
643+ ScalarFunction :: new_udf ( get_field_udf ( ) , merged_args) ,
644+ ) ) ) ;
645+ }
646+
647+ Ok ( ExprSimplifyResult :: Original ( args) )
531648 }
532649
533650 fn coerce_types ( & self , arg_types : & [ DataType ] ) -> Result < Vec < DataType > > {
@@ -828,4 +945,187 @@ mod tests {
828945 let args = vec ! [ ExpressionPlacement :: Literal , ExpressionPlacement :: Literal ] ;
829946 assert_eq ! ( func. placement( & args) , ExpressionPlacement :: KeepInPlace ) ;
830947 }
948+
949+ // --- get_field over struct constructor simplification --------------------
950+
951+ use datafusion_common:: Column ;
952+ use datafusion_expr:: simplify:: SimplifyContext ;
953+
954+ /// A non-empty string literal expression.
955+ fn lit_str ( s : & str ) -> Expr {
956+ Expr :: Literal ( ScalarValue :: Utf8 ( Some ( s. to_string ( ) ) ) , None )
957+ }
958+
959+ /// A column reference expression.
960+ fn col ( name : & str ) -> Expr {
961+ Expr :: Column ( Column :: from_name ( name) )
962+ }
963+
964+ fn scalar_fn ( udf : ScalarUDF , args : Vec < Expr > ) -> Expr {
965+ Expr :: ScalarFunction ( ScalarFunction :: new_udf ( Arc :: new ( udf) , args) )
966+ }
967+
968+ /// `named_struct(name1, value1, name2, value2, ...)`.
969+ fn named_struct ( pairs : Vec < ( & str , Expr ) > ) -> Expr {
970+ let args = pairs
971+ . into_iter ( )
972+ . flat_map ( |( name, value) | [ lit_str ( name) , value] )
973+ . collect ( ) ;
974+ scalar_fn ( ScalarUDF :: new_from_impl ( NamedStructFunc :: new ( ) ) , args)
975+ }
976+
977+ /// `struct(value0, value1, ...)`.
978+ fn struct_fn ( values : Vec < Expr > ) -> Expr {
979+ scalar_fn ( ScalarUDF :: new_from_impl ( StructFunc :: new ( ) ) , values)
980+ }
981+
982+ /// `get_field(args...)`.
983+ fn get_field ( args : Vec < Expr > ) -> Expr {
984+ scalar_fn ( ScalarUDF :: new_from_impl ( GetFieldFunc :: new ( ) ) , args)
985+ }
986+
987+ /// Run `GetFieldFunc::simplify` once and return the rewritten expression,
988+ /// panicking if the input was left unchanged.
989+ fn simplified ( args : Vec < Expr > ) -> Expr {
990+ match GetFieldFunc :: new ( )
991+ . simplify ( args, & SimplifyContext :: default ( ) )
992+ . unwrap ( )
993+ {
994+ ExprSimplifyResult :: Simplified ( expr) => expr,
995+ ExprSimplifyResult :: Original ( args) => {
996+ panic ! ( "expected the expression to be simplified, got {args:?}" )
997+ }
998+ }
999+ }
1000+
1001+ /// Assert that `GetFieldFunc::simplify` leaves the arguments unchanged.
1002+ fn assert_not_simplified ( args : Vec < Expr > ) {
1003+ match GetFieldFunc :: new ( )
1004+ . simplify ( args. clone ( ) , & SimplifyContext :: default ( ) )
1005+ . unwrap ( )
1006+ {
1007+ ExprSimplifyResult :: Original ( unchanged) => assert_eq ! ( unchanged, args) ,
1008+ ExprSimplifyResult :: Simplified ( expr) => {
1009+ panic ! ( "expected no simplification, got {expr:?}" )
1010+ }
1011+ }
1012+ }
1013+
#[test]
fn simplify_get_field_named_struct_returns_matching_value() {
    // Accessing 'max' on named_struct('min', a, 'max', b) must yield `b`.
    let base = named_struct(vec![("min", col("a")), ("max", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("max")]);
    assert_eq!(rewritten, col("b"));
}
1023+
#[test]
fn simplify_get_field_named_struct_first_field() {
    // Accessing 'min' on named_struct('min', a, 'max', b) must yield `a`.
    let base = named_struct(vec![("min", col("a")), ("max", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("min")]);
    assert_eq!(rewritten, col("a"));
}
1033+
#[test]
fn simplify_get_field_named_struct_duplicate_names_picks_first() {
    // With two fields named 'k', the rewrite must pick the first one, the
    // same way Arrow's `column_by_name` resolves to the first match.
    let base = named_struct(vec![("k", col("a")), ("k", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("k")]);
    assert_eq!(rewritten, col("a"));
}
1043+
#[test]
fn simplify_get_field_struct_positional() {
    // Accessing 'c1' on struct(a, b) must yield the second value, `b`.
    let base = struct_fn(vec![col("a"), col("b")]);
    let rewritten = simplified(vec![base, lit_str("c1")]);
    assert_eq!(rewritten, col("b"));
}
1050+
#[test]
fn simplify_get_field_nested_named_struct() {
    // get_field(named_struct('s', named_struct('k', x)), 's', 'k') resolves
    // in two passes:
    //   pass 1 => get_field(named_struct('k', x), 'k')
    //   pass 2 => x
    let inner = named_struct(vec![("k", col("x"))]);
    let outer = named_struct(vec![("s", inner)]);

    let after_first_pass = simplified(vec![outer, lit_str("s"), lit_str("k")]);
    let second_pass_args = match after_first_pass {
        Expr::ScalarFunction(ScalarFunction { args, .. }) => args,
        _ => panic!("expected a get_field call after the first pass"),
    };

    assert_eq!(simplified(second_pass_args), col("x"));
}
1067+
#[test]
fn simplify_get_field_flattens_then_resolves_named_struct() {
    // get_field(get_field(named_struct('s', named_struct('k', x)), 's'), 'k')
    // is first flattened to get_field(named_struct(...), 's', 'k'); resolving
    // 's' then leaves get_field(named_struct('k', x), 'k').
    let outer = named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]);
    let nested_access = get_field(vec![outer, lit_str("s")]);

    let rewritten = simplified(vec![nested_access, lit_str("k")]);

    let expected = get_field(vec![named_struct(vec![("k", col("x"))]), lit_str("k")]);
    assert_eq!(rewritten, expected);
}
1082+
#[test]
fn simplify_get_field_dynamic_field_name_left_alone() {
    // The field name is a column reference, not a literal, so it cannot be
    // resolved at plan time.
    let base = named_struct(vec![("a", col("x"))]);
    let dynamic_field_name = col("field_name");
    assert_not_simplified(vec![base, dynamic_field_name]);
}
1089+
#[test]
fn simplify_get_field_null_field_name_left_alone() {
    // A NULL string literal names no field; the `try_as_str().flatten()`
    // guard must therefore leave the expression untouched.
    let base = named_struct(vec![("a", col("x"))]);
    let null_field_name = Expr::Literal(ScalarValue::Utf8(None), None);
    assert_not_simplified(vec![base, null_field_name]);
}
1098+
#[test]
fn simplify_get_field_dynamic_struct_name_left_alone() {
    // A named_struct whose field name is a column reference might shadow
    // the requested field at runtime, so the rewrite must bail out entirely.
    let ctor_args = vec![col("dynamic_name"), col("x")];
    let named_struct_with_dynamic_name =
        scalar_fn(ScalarUDF::new_from_impl(NamedStructFunc::new()), ctor_args);
    assert_not_simplified(vec![named_struct_with_dynamic_name, lit_str("a")]);
}
1110+
#[test]
fn simplify_get_field_missing_field_left_alone() {
    // named_struct('a', x) has no field called 'missing'.
    let base = named_struct(vec![("a", col("x"))]);
    assert_not_simplified(vec![base, lit_str("missing")]);
}
1117+
#[test]
fn simplify_get_field_non_canonical_struct_field_left_alone() {
    // `struct(...)` names its fields c0, c1, ...; 'c01' is not one of them.
    let base = struct_fn(vec![col("a"), col("b")]);
    assert_not_simplified(vec![base, lit_str("c01")]);
}
1124+
#[test]
fn simplify_get_field_column_base_left_alone() {
    // The base is a plain column, not a struct constructor call.
    let base = col("s");
    assert_not_simplified(vec![base, lit_str("a")]);
}
8311131}
0 commit comments