Skip to content

Commit ca260a7

Browse files
committed
handle wrapped lambdas
1 parent 5c0b41d commit ca260a7

2 files changed

Lines changed: 175 additions & 26 deletions

File tree

datafusion/physical-expr/src/expressions/lambda.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,16 @@ impl PhysicalExpr for LambdaExpr {
8787
self
8888
}
8989

90-
fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
91-
self.body.data_type(input_schema)
90+
fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
91+
Ok(DataType::Null)
9292
}
9393

94-
fn nullable(&self, input_schema: &Schema) -> Result<bool> {
95-
self.body.nullable(input_schema)
94+
fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
95+
Ok(true)
9696
}
9797

9898
fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
99-
internal_err!("Lambda::evaluate() should not be called")
99+
internal_err!("LambdaExpr::evaluate() should not be called")
100100
}
101101

102102
fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {

datafusion/physical-expr/src/higher_order_function.rs

Lines changed: 170 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ pub struct HigherOrderFunctionExpr {
5555
fun: Arc<dyn HigherOrderUDF>,
5656
name: String,
5757
args: Vec<Arc<dyn PhysicalExpr>>,
58+
lambda_positions: Vec<usize>,
5859
return_field: FieldRef,
5960
config_options: Arc<ConfigOptions>,
6061
}
@@ -65,30 +66,43 @@ impl Debug for HigherOrderFunctionExpr {
6566
.field("fun", &"<FUNC>")
6667
.field("name", &self.name)
6768
.field("args", &self.args)
69+
.field("lambda_positions", &self.lambda_positions)
6870
.field("return_field", &self.return_field)
6971
.finish()
7072
}
7173
}
7274

7375
impl HigherOrderFunctionExpr {
7476
/// Create a new Higher Order function
77+
///
78+
/// `lambda_positions` should contain the positions in `args` where
79+
/// lambda arguments can be found, wrapped or not. Note that a lambda wrapper's
80+
/// [PhysicalExpr::evaluate] will not be called; the lambda *body* should be wrapped instead.
81+
/// If any arg referenced by `lambda_positions` does not contain a lambda, or contains a wrapper
82+
/// with multiple children before the lambda is found, the function evaluation will error.
7583
pub fn new(
7684
name: impl Into<String>,
7785
fun: Arc<dyn HigherOrderUDF>,
7886
args: Vec<Arc<dyn PhysicalExpr>>,
87+
lambda_positions: Vec<usize>,
7988
return_field: FieldRef,
8089
config_options: Arc<ConfigOptions>,
8190
) -> Self {
8291
Self {
8392
fun,
8493
name: name.into(),
8594
args,
95+
lambda_positions,
8696
return_field,
8797
config_options,
8898
}
8999
}
90100

91101
/// Create a new Higher Order function
102+
///
103+
/// Note that lambda arguments must be present directly in args as [LambdaExpr],
104+
/// and not as a wrapped child of any arg. Use [HigherOrderFunctionExpr::new] to provide
105+
/// wrapped lambdas.
92106
pub fn try_new(
93107
fun: Arc<dyn HigherOrderUDF>,
94108
args: Vec<Arc<dyn PhysicalExpr>>,
@@ -98,12 +112,11 @@ impl HigherOrderFunctionExpr {
98112
let name = fun.name().to_string();
99113
let arg_fields = args
100114
.iter()
101-
.map(|e| {
102-
let field = e.return_field(schema)?;
103-
match e.as_any().downcast_ref::<LambdaExpr>() {
104-
Some(_lambda) => Ok(ValueOrLambda::Lambda(field)),
105-
None => Ok(ValueOrLambda::Value(field)),
115+
.map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
116+
Some(lambda) => {
117+
Ok(ValueOrLambda::Lambda(lambda.body().return_field(schema)?))
106118
}
119+
None => Ok(ValueOrLambda::Value(e.return_field(schema)?)),
107120
})
108121
.collect::<Result<Vec<_>>>()?;
109122

@@ -125,11 +138,23 @@ impl HigherOrderFunctionExpr {
125138
};
126139

127140
let return_field = fun.return_field_from_args(ret_args)?;
141+
let lambda_positions = args
142+
.iter()
143+
.enumerate()
144+
.filter_map(|(i, arg)| {
145+
if arg.as_any().is::<LambdaExpr>() {
146+
Some(i)
147+
} else {
148+
None
149+
}
150+
})
151+
.collect();
128152

129153
Ok(Self {
130154
fun,
131155
name,
132156
args,
157+
lambda_positions,
133158
return_field,
134159
config_options,
135160
})
@@ -169,6 +194,10 @@ impl HigherOrderFunctionExpr {
169194
pub fn config_options(&self) -> &ConfigOptions {
170195
&self.config_options
171196
}
197+
198+
pub fn lambda_positions(&self) -> &[usize] {
199+
&self.lambda_positions
200+
}
172201
}
173202

174203
impl fmt::Display for HigherOrderFunctionExpr {
@@ -187,12 +216,14 @@ impl PartialEq for HigherOrderFunctionExpr {
187216
fun,
188217
name,
189218
args,
219+
lambda_positions,
190220
return_field,
191221
config_options,
192222
} = self;
193223
fun.eq(&o.fun)
194224
&& name.eq(&o.name)
195225
&& args.eq(&o.args)
226+
&& lambda_positions.eq(&o.lambda_positions)
196227
&& return_field.eq(&o.return_field)
197228
&& (Arc::ptr_eq(config_options, &o.config_options)
198229
|| sorted_config_entries(config_options)
@@ -206,12 +237,14 @@ impl Hash for HigherOrderFunctionExpr {
206237
fun,
207238
name,
208239
args,
240+
lambda_positions,
209241
return_field,
210242
config_options: _, // expensive to hash, and often equal
211243
} = self;
212244
fun.hash(state);
213245
name.hash(state);
214246
args.hash(state);
247+
lambda_positions.hash(state);
215248
return_field.hash(state);
216249
}
217250
}
@@ -239,12 +272,16 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
239272
let arg_fields = self
240273
.args
241274
.iter()
242-
.map(|e| {
243-
let field = e.return_field(batch.schema_ref())?;
244-
245-
match e.as_any().downcast_ref::<LambdaExpr>() {
246-
Some(_lambda) => Ok(ValueOrLambda::Lambda(field)),
247-
None => Ok(ValueOrLambda::Value(field)),
275+
.enumerate()
276+
.map(|(i, e)| {
277+
if self.lambda_positions.contains(&i) {
278+
let lambda = wrapped_lambda(e)?;
279+
280+
Ok(ValueOrLambda::Lambda(
281+
lambda.body().return_field(batch.schema_ref())?,
282+
))
283+
} else {
284+
Ok(ValueOrLambda::Value(e.return_field(batch.schema_ref())?))
248285
}
249286
})
250287
.collect::<Result<Vec<_>>>()?;
@@ -282,8 +319,11 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
282319
let args = self
283320
.args
284321
.iter()
285-
.map(|arg| match arg.as_any().downcast_ref::<LambdaExpr>() {
286-
Some(lambda) => {
322+
.enumerate()
323+
.map(|(i, arg)| {
324+
if self.lambda_positions.contains(&i) {
325+
let lambda = wrapped_lambda(arg)?;
326+
287327
let lambda_params = lambda_parameters.next().ok_or_else(|| {
288328
internal_datafusion_err!(
289329
"params len should have been checked above"
@@ -292,7 +332,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
292332

293333
if lambda.params().len() > lambda_params.len() {
294334
return exec_err!(
295-
"lambda defined {} params but UDF support only {}",
335+
"lambda defined {} params but UDHOF support only {}",
296336
lambda.params().len(),
297337
lambda_params.len()
298338
);
@@ -306,8 +346,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
306346
params,
307347
Arc::clone(lambda.body()),
308348
)))
309-
}
310-
None => {
349+
} else {
311350
let value = arg.evaluate(batch)?;
312351

313352
let value =
@@ -374,6 +413,7 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
374413
&self.name,
375414
Arc::clone(&self.fun),
376415
children,
416+
self.lambda_positions.clone(),
377417
Arc::clone(&self.return_field),
378418
Arc::clone(&self.config_options),
379419
)))
@@ -395,15 +435,35 @@ impl PhysicalExpr for HigherOrderFunctionExpr {
395435
}
396436
}
397437

438+
fn wrapped_lambda(expr: &Arc<dyn PhysicalExpr>) -> Result<&LambdaExpr> {
439+
let mut current = expr;
440+
441+
loop {
442+
if let Some(lambda) = current.as_any().downcast_ref::<LambdaExpr>() {
443+
return Ok(lambda);
444+
}
445+
446+
match current.children().as_slice() {
447+
[single_child] => current = *single_child,
448+
_ => return exec_err!("unable to unwrap lambda from {expr}"),
449+
}
450+
}
451+
}
452+
398453
#[cfg(test)]
399454
mod tests {
400455
use std::sync::Arc;
401456

402457
use super::*;
403458
use crate::HigherOrderFunctionExpr;
404459
use crate::expressions::Column;
460+
use crate::expressions::lambda;
461+
use crate::expressions::not;
462+
use arrow::array::NullArray;
463+
use arrow::array::RecordBatchOptions;
405464
use arrow::datatypes::{DataType, Field, Schema};
406465
use datafusion_common::Result;
466+
use datafusion_common::assert_contains;
407467
use datafusion_expr::{
408468
HigherOrderFunctionArgs, HigherOrderSignature, HigherOrderUDF,
409469
};
@@ -430,21 +490,30 @@ mod tests {
430490
&self,
431491
_value_fields: &[FieldRef],
432492
) -> Result<Vec<Vec<Field>>> {
433-
unimplemented!()
493+
Ok(vec![vec![Field::new("", DataType::Null, true)]])
434494
}
435495

436496
fn return_field_from_args(
437497
&self,
438-
_args: HigherOrderReturnFieldArgs,
498+
args: HigherOrderReturnFieldArgs,
439499
) -> Result<FieldRef> {
440-
Ok(Arc::new(Field::new("", DataType::Int32, false)))
500+
match &args.arg_fields[0] {
501+
ValueOrLambda::Lambda(field) | ValueOrLambda::Value(field) => {
502+
Ok(Arc::clone(field))
503+
}
504+
}
441505
}
442506

443507
fn invoke_with_args(
444508
&self,
445-
_args: HigherOrderFunctionArgs,
509+
args: HigherOrderFunctionArgs,
446510
) -> Result<ColumnarValue> {
447-
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(42))))
511+
match &args.args[0] {
512+
ValueOrLambda::Lambda(lambda) => {
513+
lambda.evaluate(&[&|| Ok(Arc::new(NullArray::new(args.number_rows)))])
514+
}
515+
ValueOrLambda::Value(value) => Ok(value.clone()),
516+
}
448517
}
449518
}
450519

@@ -486,4 +555,84 @@ mod tests {
486555
let stable_arc: Arc<dyn PhysicalExpr> = Arc::new(stable_expr);
487556
assert!(!is_volatile(&stable_arc));
488557
}
558+
559+
#[test]
560+
fn test_higher_order_function_wrapped_lambda() {
561+
let fun = Arc::new(MockHigherOrderUDF {
562+
signature: HigherOrderSignature::variadic_any(Volatility::Stable),
563+
});
564+
565+
let expected = ScalarValue::Int32(Some(42));
566+
567+
let hof = HigherOrderFunctionExpr::try_new(
568+
fun,
569+
vec![lambda(["a"], Arc::new(Literal::new(expected.clone()))).unwrap()],
570+
&Schema::empty(),
571+
Arc::new(ConfigOptions::new()),
572+
)
573+
.unwrap();
574+
575+
let wrapped = HigherOrderFunctionExpr::new(
576+
hof.name,
577+
hof.fun,
578+
vec![not(Arc::clone(&hof.args[0])).unwrap()],
579+
hof.lambda_positions,
580+
hof.return_field,
581+
hof.config_options,
582+
);
583+
584+
let result = wrapped
585+
.evaluate(
586+
&RecordBatch::try_new_with_options(
587+
Arc::new(Schema::empty()),
588+
vec![],
589+
&RecordBatchOptions::new().with_row_count(Some(0)),
590+
)
591+
.unwrap(),
592+
)
593+
.unwrap();
594+
595+
let ColumnarValue::Scalar(result) = result else {
596+
unreachable!()
597+
};
598+
599+
assert_eq!(result, expected);
600+
}
601+
602+
#[test]
603+
fn test_higher_order_function_badly_wrapped_lambda() {
604+
let fun = Arc::new(MockHigherOrderUDF {
605+
signature: HigherOrderSignature::variadic_any(Volatility::Stable),
606+
});
607+
608+
let hof = HigherOrderFunctionExpr::try_new(
609+
fun,
610+
vec![
611+
not(
612+
lambda(["a"], Arc::new(Literal::new(ScalarValue::Int32(Some(42)))))
613+
.unwrap(),
614+
)
615+
.unwrap(),
616+
],
617+
&Schema::empty(),
618+
Arc::new(ConfigOptions::new()),
619+
)
620+
.unwrap();
621+
622+
let result = hof
623+
.evaluate(
624+
&RecordBatch::try_new_with_options(
625+
Arc::new(Schema::empty()),
626+
vec![],
627+
&RecordBatchOptions::new().with_row_count(Some(0)),
628+
)
629+
.unwrap(),
630+
)
631+
.unwrap_err();
632+
633+
assert_contains!(
634+
result.to_string(),
635+
"LambdaExpr::evaluate() should not be called"
636+
);
637+
}
489638
}

0 commit comments

Comments
 (0)