@@ -55,6 +55,7 @@ pub struct HigherOrderFunctionExpr {
5555 fun : Arc < dyn HigherOrderUDF > ,
5656 name : String ,
5757 args : Vec < Arc < dyn PhysicalExpr > > ,
58+ lambda_positions : Vec < usize > ,
5859 return_field : FieldRef ,
5960 config_options : Arc < ConfigOptions > ,
6061}
@@ -65,30 +66,43 @@ impl Debug for HigherOrderFunctionExpr {
6566 . field ( "fun" , & "<FUNC>" )
6667 . field ( "name" , & self . name )
6768 . field ( "args" , & self . args )
69+ . field ( "lambda_positions" , & self . lambda_positions )
6870 . field ( "return_field" , & self . return_field )
6971 . finish ( )
7072 }
7173}
7274
7375impl HigherOrderFunctionExpr {
7476 /// Create a new Higher Order function
77+ ///
78+ /// `lambda_positions` should contain the positions at `args` where
79+ /// lambda arguments can be found, wrapped or not. Note that any lambda wrapper
80+ /// [PhysicalExpr::evaluate] will not be called. The lambda *body* should be wrapped instead
81+ /// If any arg referenced by `lambda_positions` does not contain a lambda or contains a wrapper
82+ /// with multiple children before finding the lambda, the function evaluation will error
7583 pub fn new (
7684 name : impl Into < String > ,
7785 fun : Arc < dyn HigherOrderUDF > ,
7886 args : Vec < Arc < dyn PhysicalExpr > > ,
87+ lambda_positions : Vec < usize > ,
7988 return_field : FieldRef ,
8089 config_options : Arc < ConfigOptions > ,
8190 ) -> Self {
8291 Self {
8392 fun,
8493 name : name. into ( ) ,
8594 args,
95+ lambda_positions,
8696 return_field,
8797 config_options,
8898 }
8999 }
90100
91101 /// Create a new Higher Order function
102+ ///
103+ /// Note that lambda arguments must be present directly in args as [LambdaExpr],
104+ /// and not as a wrapped child of any arg. Use [HigherOrderFunctionExpr::new] to provide
105+ /// wrapped lambdas
92106 pub fn try_new (
93107 fun : Arc < dyn HigherOrderUDF > ,
94108 args : Vec < Arc < dyn PhysicalExpr > > ,
@@ -98,12 +112,11 @@ impl HigherOrderFunctionExpr {
98112 let name = fun. name ( ) . to_string ( ) ;
99113 let arg_fields = args
100114 . iter ( )
101- . map ( |e| {
102- let field = e. return_field ( schema) ?;
103- match e. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
104- Some ( _lambda) => Ok ( ValueOrLambda :: Lambda ( field) ) ,
105- None => Ok ( ValueOrLambda :: Value ( field) ) ,
115+ . map ( |e| match e. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
116+ Some ( lambda) => {
117+ Ok ( ValueOrLambda :: Lambda ( lambda. body ( ) . return_field ( schema) ?) )
106118 }
119+ None => Ok ( ValueOrLambda :: Value ( e. return_field ( schema) ?) ) ,
107120 } )
108121 . collect :: < Result < Vec < _ > > > ( ) ?;
109122
@@ -125,11 +138,23 @@ impl HigherOrderFunctionExpr {
125138 } ;
126139
127140 let return_field = fun. return_field_from_args ( ret_args) ?;
141+ let lambda_positions = args
142+ . iter ( )
143+ . enumerate ( )
144+ . filter_map ( |( i, arg) | {
145+ if arg. as_any ( ) . is :: < LambdaExpr > ( ) {
146+ Some ( i)
147+ } else {
148+ None
149+ }
150+ } )
151+ . collect ( ) ;
128152
129153 Ok ( Self {
130154 fun,
131155 name,
132156 args,
157+ lambda_positions,
133158 return_field,
134159 config_options,
135160 } )
@@ -169,6 +194,10 @@ impl HigherOrderFunctionExpr {
169194 pub fn config_options ( & self ) -> & ConfigOptions {
170195 & self . config_options
171196 }
197+
198+ pub fn lambda_positions ( & self ) -> & [ usize ] {
199+ & self . lambda_positions
200+ }
172201}
173202
174203impl fmt:: Display for HigherOrderFunctionExpr {
@@ -187,12 +216,14 @@ impl PartialEq for HigherOrderFunctionExpr {
187216 fun,
188217 name,
189218 args,
219+ lambda_positions,
190220 return_field,
191221 config_options,
192222 } = self ;
193223 fun. eq ( & o. fun )
194224 && name. eq ( & o. name )
195225 && args. eq ( & o. args )
226+ && lambda_positions. eq ( & o. lambda_positions )
196227 && return_field. eq ( & o. return_field )
197228 && ( Arc :: ptr_eq ( config_options, & o. config_options )
198229 || sorted_config_entries ( config_options)
@@ -206,12 +237,14 @@ impl Hash for HigherOrderFunctionExpr {
206237 fun,
207238 name,
208239 args,
240+ lambda_positions,
209241 return_field,
210242 config_options : _, // expensive to hash, and often equal
211243 } = self ;
212244 fun. hash ( state) ;
213245 name. hash ( state) ;
214246 args. hash ( state) ;
247+ lambda_positions. hash ( state) ;
215248 return_field. hash ( state) ;
216249 }
217250}
@@ -239,12 +272,16 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
239272 let arg_fields = self
240273 . args
241274 . iter ( )
242- . map ( |e| {
243- let field = e. return_field ( batch. schema_ref ( ) ) ?;
244-
245- match e. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
246- Some ( _lambda) => Ok ( ValueOrLambda :: Lambda ( field) ) ,
247- None => Ok ( ValueOrLambda :: Value ( field) ) ,
275+ . enumerate ( )
276+ . map ( |( i, e) | {
277+ if self . lambda_positions . contains ( & i) {
278+ let lambda = wrapped_lambda ( e) ?;
279+
280+ Ok ( ValueOrLambda :: Lambda (
281+ lambda. body ( ) . return_field ( batch. schema_ref ( ) ) ?,
282+ ) )
283+ } else {
284+ Ok ( ValueOrLambda :: Value ( e. return_field ( batch. schema_ref ( ) ) ?) )
248285 }
249286 } )
250287 . collect :: < Result < Vec < _ > > > ( ) ?;
@@ -282,8 +319,11 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
282319 let args = self
283320 . args
284321 . iter ( )
285- . map ( |arg| match arg. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
286- Some ( lambda) => {
322+ . enumerate ( )
323+ . map ( |( i, arg) | {
324+ if self . lambda_positions . contains ( & i) {
325+ let lambda = wrapped_lambda ( arg) ?;
326+
287327 let lambda_params = lambda_parameters. next ( ) . ok_or_else ( || {
288328 internal_datafusion_err ! (
289329 "params len should have been checked above"
@@ -292,7 +332,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
292332
293333 if lambda. params ( ) . len ( ) > lambda_params. len ( ) {
294334 return exec_err ! (
295- "lambda defined {} params but UDF support only {}" ,
335+ "lambda defined {} params but UDHOF support only {}" ,
296336 lambda. params( ) . len( ) ,
297337 lambda_params. len( )
298338 ) ;
@@ -306,8 +346,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
306346 params,
307347 Arc :: clone ( lambda. body ( ) ) ,
308348 ) ) )
309- }
310- None => {
349+ } else {
311350 let value = arg. evaluate ( batch) ?;
312351
313352 let value =
@@ -374,6 +413,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
374413 & self . name ,
375414 Arc :: clone ( & self . fun ) ,
376415 children,
416+ self . lambda_positions . clone ( ) ,
377417 Arc :: clone ( & self . return_field ) ,
378418 Arc :: clone ( & self . config_options ) ,
379419 ) ) )
@@ -395,15 +435,35 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
395435 }
396436}
397437
438+ fn wrapped_lambda ( expr : & Arc < dyn PhysicalExpr > ) -> Result < & LambdaExpr > {
439+ let mut current = expr;
440+
441+ loop {
442+ if let Some ( lambda) = current. as_any ( ) . downcast_ref :: < LambdaExpr > ( ) {
443+ return Ok ( lambda) ;
444+ }
445+
446+ match current. children ( ) . as_slice ( ) {
447+ [ single_child] => current = * single_child,
448+ _ => return exec_err ! ( "unable to unwrap lambda from {expr}" ) ,
449+ }
450+ }
451+ }
452+
398453#[ cfg( test) ]
399454mod tests {
400455 use std:: sync:: Arc ;
401456
402457 use super :: * ;
403458 use crate :: HigherOrderFunctionExpr ;
404459 use crate :: expressions:: Column ;
460+ use crate :: expressions:: lambda;
461+ use crate :: expressions:: not;
462+ use arrow:: array:: NullArray ;
463+ use arrow:: array:: RecordBatchOptions ;
405464 use arrow:: datatypes:: { DataType , Field , Schema } ;
406465 use datafusion_common:: Result ;
466+ use datafusion_common:: assert_contains;
407467 use datafusion_expr:: {
408468 HigherOrderFunctionArgs , HigherOrderSignature , HigherOrderUDF ,
409469 } ;
@@ -430,21 +490,30 @@ mod tests {
430490 & self ,
431491 _value_fields : & [ FieldRef ] ,
432492 ) -> Result < Vec < Vec < Field > > > {
433- unimplemented ! ( )
493+ Ok ( vec ! [ vec! [ Field :: new ( "" , DataType :: Null , true ) ] ] )
434494 }
435495
436496 fn return_field_from_args (
437497 & self ,
438- _args : HigherOrderReturnFieldArgs ,
498+ args : HigherOrderReturnFieldArgs ,
439499 ) -> Result < FieldRef > {
440- Ok ( Arc :: new ( Field :: new ( "" , DataType :: Int32 , false ) ) )
500+ match & args. arg_fields [ 0 ] {
501+ ValueOrLambda :: Lambda ( field) | ValueOrLambda :: Value ( field) => {
502+ Ok ( Arc :: clone ( field) )
503+ }
504+ }
441505 }
442506
443507 fn invoke_with_args (
444508 & self ,
445- _args : HigherOrderFunctionArgs ,
509+ args : HigherOrderFunctionArgs ,
446510 ) -> Result < ColumnarValue > {
447- Ok ( ColumnarValue :: Scalar ( ScalarValue :: Int32 ( Some ( 42 ) ) ) )
511+ match & args. args [ 0 ] {
512+ ValueOrLambda :: Lambda ( lambda) => {
513+ lambda. evaluate ( & [ & || Ok ( Arc :: new ( NullArray :: new ( args. number_rows ) ) ) ] )
514+ }
515+ ValueOrLambda :: Value ( value) => Ok ( value. clone ( ) ) ,
516+ }
448517 }
449518 }
450519
@@ -486,4 +555,84 @@ mod tests {
486555 let stable_arc: Arc < dyn PhysicalExpr > = Arc :: new ( stable_expr) ;
487556 assert ! ( !is_volatile( & stable_arc) ) ;
488557 }
558+
/// A lambda nested inside a single-child wrapper (`not`) is still found and
/// evaluated when its index is listed in `lambda_positions`.
#[test]
fn test_higher_order_function_wrapped_lambda() {
    let expected = ScalarValue::Int32(Some(42));

    let udf = Arc::new(MockHigherOrderUDF {
        signature: HigherOrderSignature::variadic_any(Volatility::Stable),
    });

    // Go through `try_new` with the bare lambda first so the return field and
    // the lambda positions are computed for us.
    let direct = HigherOrderFunctionExpr::try_new(
        udf,
        vec![lambda(["a"], Arc::new(Literal::new(expected.clone()))).unwrap()],
        &Schema::empty(),
        Arc::new(ConfigOptions::new()),
    )
    .unwrap();

    // Rebuild the expression with the same lambda hidden behind `not`.
    let wrapped_args = vec![not(Arc::clone(&direct.args[0])).unwrap()];
    let wrapped = HigherOrderFunctionExpr::new(
        direct.name,
        direct.fun,
        wrapped_args,
        direct.lambda_positions,
        direct.return_field,
        direct.config_options,
    );

    let empty_batch = RecordBatch::try_new_with_options(
        Arc::new(Schema::empty()),
        vec![],
        &RecordBatchOptions::new().with_row_count(Some(0)),
    )
    .unwrap();

    match wrapped.evaluate(&empty_batch).unwrap() {
        ColumnarValue::Scalar(actual) => assert_eq!(actual, expected),
        _ => unreachable!(),
    }
}
601+
/// `try_new` only recognises lambdas placed directly in `args`. A lambda hidden
/// behind a wrapper is therefore handled as a plain value, and evaluating the
/// expression must report the `LambdaExpr::evaluate()` error.
#[test]
fn test_higher_order_function_badly_wrapped_lambda() {
    let udf = Arc::new(MockHigherOrderUDF {
        signature: HigherOrderSignature::variadic_any(Volatility::Stable),
    });

    let bare_lambda =
        lambda(["a"], Arc::new(Literal::new(ScalarValue::Int32(Some(42))))).unwrap();
    let hidden_lambda = not(bare_lambda).unwrap();

    let hof = HigherOrderFunctionExpr::try_new(
        udf,
        vec![hidden_lambda],
        &Schema::empty(),
        Arc::new(ConfigOptions::new()),
    )
    .unwrap();

    let empty_batch = RecordBatch::try_new_with_options(
        Arc::new(Schema::empty()),
        vec![],
        &RecordBatchOptions::new().with_row_count(Some(0)),
    )
    .unwrap();

    let err = hof.evaluate(&empty_batch).unwrap_err();

    assert_contains!(
        err.to_string(),
        "LambdaExpr::evaluate() should not be called"
    );
}
489638}
0 commit comments