@@ -28,20 +28,24 @@ pub fn contains_unsupported_functions(
2828 plan : & LogicalPlan ,
2929 sup : & FunctionSupport ,
3030) -> Result < bool , DataFusionError > {
31- plan . exists ( |plan| {
32- Ok ( plan. expressions ( ) . into_iter ( ) . any ( |expr | {
33- let mut found_unsupported = false ;
34- let _ = expr. apply ( |expr| {
31+ let mut found_unsupported = false ;
32+ plan. apply_with_subqueries ( |plan | {
33+ for expr in plan . expressions ( ) {
34+ expr. apply ( |expr| {
3535 if sup. supports ( expr) {
3636 Ok ( TreeNodeRecursion :: Continue )
3737 } else {
3838 found_unsupported = true ;
3939 Ok ( TreeNodeRecursion :: Stop )
4040 }
41- } ) ;
42- found_unsupported
43- } ) )
44- } )
41+ } ) ?;
42+ if found_unsupported {
43+ return Ok ( TreeNodeRecursion :: Stop ) ;
44+ }
45+ }
46+ Ok ( TreeNodeRecursion :: Continue )
47+ } ) ?;
48+ Ok ( found_unsupported)
4549}
4650
4751#[ derive( Clone , Debug ) ]
@@ -163,3 +167,193 @@ impl FunctionRestriction {
163167 }
164168 }
165169}
170+
171+ #[ cfg( test) ]
172+ mod tests {
173+ use super :: * ;
174+ use datafusion:: arrow:: datatypes:: { DataType , Field , Schema } ;
175+ use datafusion:: logical_expr:: builder:: LogicalTableSource ;
176+ use datafusion:: logical_expr:: expr:: ScalarFunction ;
177+ use datafusion:: logical_expr:: { create_udf, ColumnarValue , LogicalPlanBuilder , Subquery } ;
178+ use datafusion:: prelude:: col;
179+ use std:: sync:: Arc ;
180+
181+ fn stub_udf ( name : & str ) -> Arc < ScalarUDF > {
182+ Arc :: new ( create_udf (
183+ name,
184+ vec ! [ DataType :: Utf8 ] ,
185+ DataType :: Utf8 ,
186+ datafusion:: logical_expr:: Volatility :: Immutable ,
187+ Arc :: new ( |args : & [ ColumnarValue ] | Ok ( args[ 0 ] . clone ( ) ) ) ,
188+ ) )
189+ }
190+
191+ fn deny_support ( names : & [ & str ] ) -> FunctionSupport {
192+ FunctionSupport :: new (
193+ Some ( FunctionRestriction :: Deny (
194+ names. iter ( ) . map ( |s| s. to_string ( ) ) . collect ( ) ,
195+ ) ) ,
196+ None ,
197+ None ,
198+ )
199+ }
200+
201+ fn scan_plan ( table : & str ) -> LogicalPlan {
202+ let schema = Arc :: new ( Schema :: new ( vec ! [
203+ Field :: new( "id" , DataType :: Int32 , false ) ,
204+ Field :: new( "val" , DataType :: Utf8 , true ) ,
205+ ] ) ) ;
206+ let source = Arc :: new ( LogicalTableSource :: new ( schema) )
207+ as Arc < dyn datafusion:: logical_expr:: TableSource > ;
208+ LogicalPlanBuilder :: scan ( table, source, None )
209+ . expect ( "scan" )
210+ . build ( )
211+ . expect ( "build" )
212+ }
213+
214+ #[ test]
215+ fn detects_denied_function_in_top_level_projection ( ) {
216+ let udf = stub_udf ( "denied_fn" ) ;
217+ let plan = LogicalPlanBuilder :: from ( scan_plan ( "t" ) )
218+ . project ( vec ! [ Expr :: ScalarFunction ( ScalarFunction :: new_udf(
219+ udf,
220+ vec![ col( "val" ) ] ,
221+ ) ) ] )
222+ . expect ( "project" )
223+ . build ( )
224+ . expect ( "build" ) ;
225+
226+ let sup = deny_support ( & [ "denied_fn" ] ) ;
227+ assert ! (
228+ contains_unsupported_functions( & plan, & sup) . expect( "check" ) ,
229+ "should detect denied function in top-level projection"
230+ ) ;
231+ }
232+
233+ #[ test]
234+ fn allows_plan_without_denied_functions ( ) {
235+ let udf = stub_udf ( "allowed_fn" ) ;
236+ let plan = LogicalPlanBuilder :: from ( scan_plan ( "t" ) )
237+ . project ( vec ! [ Expr :: ScalarFunction ( ScalarFunction :: new_udf(
238+ udf,
239+ vec![ col( "val" ) ] ,
240+ ) ) ] )
241+ . expect ( "project" )
242+ . build ( )
243+ . expect ( "build" ) ;
244+
245+ let sup = deny_support ( & [ "denied_fn" ] ) ;
246+ assert ! (
247+ !contains_unsupported_functions( & plan, & sup) . expect( "check" ) ,
248+ "should allow plan with only non-denied functions"
249+ ) ;
250+ }
251+
252+ #[ test]
253+ fn detects_denied_function_inside_in_subquery ( ) {
254+ let udf = stub_udf ( "denied_fn" ) ;
255+
256+ // Build subquery: SELECT denied_fn(val) FROM inner_t
257+ let subquery_plan = LogicalPlanBuilder :: from ( scan_plan ( "inner_t" ) )
258+ . project ( vec ! [ Expr :: ScalarFunction ( ScalarFunction :: new_udf(
259+ udf,
260+ vec![ col( "val" ) ] ,
261+ ) )
262+ . alias( "result" ) ] )
263+ . expect ( "project" )
264+ . build ( )
265+ . expect ( "build" ) ;
266+
267+ // Build outer: SELECT id FROM t WHERE id IN (subquery)
268+ let outer = LogicalPlanBuilder :: from ( scan_plan ( "t" ) )
269+ . filter ( Expr :: InSubquery (
270+ datafusion:: logical_expr:: expr:: InSubquery :: new (
271+ Box :: new ( col ( "id" ) ) ,
272+ Subquery {
273+ subquery : Arc :: new ( subquery_plan) ,
274+ outer_ref_columns : vec ! [ ] ,
275+ spans : Default :: default ( ) ,
276+ } ,
277+ false ,
278+ ) ,
279+ ) )
280+ . expect ( "filter" )
281+ . build ( )
282+ . expect ( "build" ) ;
283+
284+ let sup = deny_support ( & [ "denied_fn" ] ) ;
285+ assert ! (
286+ contains_unsupported_functions( & outer, & sup) . expect( "check" ) ,
287+ "should detect denied function inside IN subquery"
288+ ) ;
289+ }
290+
291+ #[ test]
292+ fn detects_denied_function_inside_scalar_subquery ( ) {
293+ let udf = stub_udf ( "denied_fn" ) ;
294+
295+ // Build scalar subquery: SELECT denied_fn(val) FROM inner_t
296+ let subquery_plan = LogicalPlanBuilder :: from ( scan_plan ( "inner_t" ) )
297+ . project ( vec ! [ Expr :: ScalarFunction ( ScalarFunction :: new_udf(
298+ udf,
299+ vec![ col( "val" ) ] ,
300+ ) )
301+ . alias( "result" ) ] )
302+ . expect ( "project" )
303+ . build ( )
304+ . expect ( "build" ) ;
305+
306+ // Build outer: SELECT id FROM t WHERE id = (scalar subquery)
307+ let outer = LogicalPlanBuilder :: from ( scan_plan ( "t" ) )
308+ . filter ( col ( "id" ) . eq ( Expr :: ScalarSubquery ( Subquery {
309+ subquery : Arc :: new ( subquery_plan) ,
310+ outer_ref_columns : vec ! [ ] ,
311+ spans : Default :: default ( ) ,
312+ } ) ) )
313+ . expect ( "filter" )
314+ . build ( )
315+ . expect ( "build" ) ;
316+
317+ let sup = deny_support ( & [ "denied_fn" ] ) ;
318+ assert ! (
319+ contains_unsupported_functions( & outer, & sup) . expect( "check" ) ,
320+ "should detect denied function inside scalar subquery"
321+ ) ;
322+ }
323+
324+ #[ test]
325+ fn detects_denied_function_inside_exists_subquery ( ) {
326+ let udf = stub_udf ( "denied_fn" ) ;
327+
328+ // Build subquery: SELECT denied_fn(val) FROM inner_t
329+ let subquery_plan = LogicalPlanBuilder :: from ( scan_plan ( "inner_t" ) )
330+ . project ( vec ! [ Expr :: ScalarFunction ( ScalarFunction :: new_udf(
331+ udf,
332+ vec![ col( "val" ) ] ,
333+ ) )
334+ . alias( "result" ) ] )
335+ . expect ( "project" )
336+ . build ( )
337+ . expect ( "build" ) ;
338+
339+ // Build outer: SELECT id FROM t WHERE EXISTS (subquery)
340+ let outer = LogicalPlanBuilder :: from ( scan_plan ( "t" ) )
341+ . filter ( Expr :: Exists ( datafusion:: logical_expr:: expr:: Exists :: new (
342+ Subquery {
343+ subquery : Arc :: new ( subquery_plan) ,
344+ outer_ref_columns : vec ! [ ] ,
345+ spans : Default :: default ( ) ,
346+ } ,
347+ false ,
348+ ) ) )
349+ . expect ( "filter" )
350+ . build ( )
351+ . expect ( "build" ) ;
352+
353+ let sup = deny_support ( & [ "denied_fn" ] ) ;
354+ assert ! (
355+ contains_unsupported_functions( & outer, & sup) . expect( "check" ) ,
356+ "should detect denied function inside EXISTS subquery"
357+ ) ;
358+ }
359+ }
0 commit comments