Skip to content

Commit f4f864d

Browse files
authored
Fix federation pushing denied functions inside subqueries to remote engines (#640)
1 parent 97ecd00 commit f4f864d

1 file changed

Lines changed: 202 additions & 8 deletions

File tree

core/src/util/supported_functions.rs

Lines changed: 202 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -28,20 +28,24 @@ pub fn contains_unsupported_functions(
2828
plan: &LogicalPlan,
2929
sup: &FunctionSupport,
3030
) -> Result<bool, DataFusionError> {
31-
plan.exists(|plan| {
32-
Ok(plan.expressions().into_iter().any(|expr| {
33-
let mut found_unsupported = false;
34-
let _ = expr.apply(|expr| {
31+
let mut found_unsupported = false;
32+
plan.apply_with_subqueries(|plan| {
33+
for expr in plan.expressions() {
34+
expr.apply(|expr| {
3535
if sup.supports(expr) {
3636
Ok(TreeNodeRecursion::Continue)
3737
} else {
3838
found_unsupported = true;
3939
Ok(TreeNodeRecursion::Stop)
4040
}
41-
});
42-
found_unsupported
43-
}))
44-
})
41+
})?;
42+
if found_unsupported {
43+
return Ok(TreeNodeRecursion::Stop);
44+
}
45+
}
46+
Ok(TreeNodeRecursion::Continue)
47+
})?;
48+
Ok(found_unsupported)
4549
}
4650

4751
#[derive(Clone, Debug)]
@@ -163,3 +167,193 @@ impl FunctionRestriction {
163167
}
164168
}
165169
}
170+
171+
#[cfg(test)]
172+
mod tests {
173+
use super::*;
174+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
175+
use datafusion::logical_expr::builder::LogicalTableSource;
176+
use datafusion::logical_expr::expr::ScalarFunction;
177+
use datafusion::logical_expr::{create_udf, ColumnarValue, LogicalPlanBuilder, Subquery};
178+
use datafusion::prelude::col;
179+
use std::sync::Arc;
180+
181+
fn stub_udf(name: &str) -> Arc<ScalarUDF> {
182+
Arc::new(create_udf(
183+
name,
184+
vec![DataType::Utf8],
185+
DataType::Utf8,
186+
datafusion::logical_expr::Volatility::Immutable,
187+
Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone())),
188+
))
189+
}
190+
191+
fn deny_support(names: &[&str]) -> FunctionSupport {
192+
FunctionSupport::new(
193+
Some(FunctionRestriction::Deny(
194+
names.iter().map(|s| s.to_string()).collect(),
195+
)),
196+
None,
197+
None,
198+
)
199+
}
200+
201+
fn scan_plan(table: &str) -> LogicalPlan {
202+
let schema = Arc::new(Schema::new(vec![
203+
Field::new("id", DataType::Int32, false),
204+
Field::new("val", DataType::Utf8, true),
205+
]));
206+
let source = Arc::new(LogicalTableSource::new(schema))
207+
as Arc<dyn datafusion::logical_expr::TableSource>;
208+
LogicalPlanBuilder::scan(table, source, None)
209+
.expect("scan")
210+
.build()
211+
.expect("build")
212+
}
213+
214+
#[test]
215+
fn detects_denied_function_in_top_level_projection() {
216+
let udf = stub_udf("denied_fn");
217+
let plan = LogicalPlanBuilder::from(scan_plan("t"))
218+
.project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
219+
udf,
220+
vec![col("val")],
221+
))])
222+
.expect("project")
223+
.build()
224+
.expect("build");
225+
226+
let sup = deny_support(&["denied_fn"]);
227+
assert!(
228+
contains_unsupported_functions(&plan, &sup).expect("check"),
229+
"should detect denied function in top-level projection"
230+
);
231+
}
232+
233+
#[test]
234+
fn allows_plan_without_denied_functions() {
235+
let udf = stub_udf("allowed_fn");
236+
let plan = LogicalPlanBuilder::from(scan_plan("t"))
237+
.project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
238+
udf,
239+
vec![col("val")],
240+
))])
241+
.expect("project")
242+
.build()
243+
.expect("build");
244+
245+
let sup = deny_support(&["denied_fn"]);
246+
assert!(
247+
!contains_unsupported_functions(&plan, &sup).expect("check"),
248+
"should allow plan with only non-denied functions"
249+
);
250+
}
251+
252+
#[test]
253+
fn detects_denied_function_inside_in_subquery() {
254+
let udf = stub_udf("denied_fn");
255+
256+
// Build subquery: SELECT denied_fn(val) FROM inner_t
257+
let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
258+
.project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
259+
udf,
260+
vec![col("val")],
261+
))
262+
.alias("result")])
263+
.expect("project")
264+
.build()
265+
.expect("build");
266+
267+
// Build outer: SELECT id FROM t WHERE id IN (subquery)
268+
let outer = LogicalPlanBuilder::from(scan_plan("t"))
269+
.filter(Expr::InSubquery(
270+
datafusion::logical_expr::expr::InSubquery::new(
271+
Box::new(col("id")),
272+
Subquery {
273+
subquery: Arc::new(subquery_plan),
274+
outer_ref_columns: vec![],
275+
spans: Default::default(),
276+
},
277+
false,
278+
),
279+
))
280+
.expect("filter")
281+
.build()
282+
.expect("build");
283+
284+
let sup = deny_support(&["denied_fn"]);
285+
assert!(
286+
contains_unsupported_functions(&outer, &sup).expect("check"),
287+
"should detect denied function inside IN subquery"
288+
);
289+
}
290+
291+
#[test]
292+
fn detects_denied_function_inside_scalar_subquery() {
293+
let udf = stub_udf("denied_fn");
294+
295+
// Build scalar subquery: SELECT denied_fn(val) FROM inner_t
296+
let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
297+
.project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
298+
udf,
299+
vec![col("val")],
300+
))
301+
.alias("result")])
302+
.expect("project")
303+
.build()
304+
.expect("build");
305+
306+
// Build outer: SELECT id FROM t WHERE id = (scalar subquery)
307+
let outer = LogicalPlanBuilder::from(scan_plan("t"))
308+
.filter(col("id").eq(Expr::ScalarSubquery(Subquery {
309+
subquery: Arc::new(subquery_plan),
310+
outer_ref_columns: vec![],
311+
spans: Default::default(),
312+
})))
313+
.expect("filter")
314+
.build()
315+
.expect("build");
316+
317+
let sup = deny_support(&["denied_fn"]);
318+
assert!(
319+
contains_unsupported_functions(&outer, &sup).expect("check"),
320+
"should detect denied function inside scalar subquery"
321+
);
322+
}
323+
324+
#[test]
325+
fn detects_denied_function_inside_exists_subquery() {
326+
let udf = stub_udf("denied_fn");
327+
328+
// Build subquery: SELECT denied_fn(val) FROM inner_t
329+
let subquery_plan = LogicalPlanBuilder::from(scan_plan("inner_t"))
330+
.project(vec![Expr::ScalarFunction(ScalarFunction::new_udf(
331+
udf,
332+
vec![col("val")],
333+
))
334+
.alias("result")])
335+
.expect("project")
336+
.build()
337+
.expect("build");
338+
339+
// Build outer: SELECT id FROM t WHERE EXISTS (subquery)
340+
let outer = LogicalPlanBuilder::from(scan_plan("t"))
341+
.filter(Expr::Exists(datafusion::logical_expr::expr::Exists::new(
342+
Subquery {
343+
subquery: Arc::new(subquery_plan),
344+
outer_ref_columns: vec![],
345+
spans: Default::default(),
346+
},
347+
false,
348+
)))
349+
.expect("filter")
350+
.build()
351+
.expect("build");
352+
353+
let sup = deny_support(&["denied_fn"]);
354+
assert!(
355+
contains_unsupported_functions(&outer, &sup).expect("check"),
356+
"should detect denied function inside EXISTS subquery"
357+
);
358+
}
359+
}

0 commit comments

Comments
 (0)