@@ -335,6 +335,11 @@ mod tests {
335335 use arrow:: record_batch:: RecordBatch ;
336336 use datafusion_physical_expr:: scalar_subquery:: ScalarSubqueryExpr ;
337337
338+ enum ExpectedSubqueryResult {
339+ Value ( ScalarValue ) ,
340+ Error ( & ' static str ) ,
341+ }
342+
338343 #[ derive( Debug ) ]
339344 struct CountingExec {
340345 inner : Arc < dyn ExecutionPlan > ,
@@ -406,8 +411,53 @@ mod tests {
406411 TestMemoryExec :: try_new_exec ( & [ batches] , schema, None ) . unwrap ( )
407412 }
408413
409- fn make_results ( n : usize ) -> ScalarSubqueryResults {
410- ScalarSubqueryResults :: new ( n)
414+ fn int32_batch ( values : Vec < i32 > ) -> RecordBatch {
415+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ) ;
416+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( Int32Array :: from( values) ) ] ) . unwrap ( )
417+ }
418+
419+ fn empty_int64_batch ( ) -> RecordBatch {
420+ let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] ) ) ;
421+ RecordBatch :: try_new ( schema, vec ! [ Arc :: new( Int64Array :: from( vec![ ] as Vec <i64 >) ) ] )
422+ . unwrap ( )
423+ }
424+
425+ fn placeholder_input ( ) -> Arc < dyn ExecutionPlan > {
426+ Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
427+ test:: aggr_test_schema ( ) ,
428+ ) )
429+ }
430+
431+ fn single_subquery_exec (
432+ input : Arc < dyn ExecutionPlan > ,
433+ subquery_plan : Arc < dyn ExecutionPlan > ,
434+ results : ScalarSubqueryResults ,
435+ ) -> ScalarSubqueryExec {
436+ ScalarSubqueryExec :: new (
437+ input,
438+ vec ! [ ScalarSubqueryLink {
439+ plan: subquery_plan,
440+ index: 0 ,
441+ } ] ,
442+ results,
443+ )
444+ }
445+
446+ fn scalar_subquery_projection_input (
447+ results : ScalarSubqueryResults ,
448+ ) -> Result < Arc < dyn ExecutionPlan > > {
449+ Ok ( Arc :: new ( ProjectionExec :: try_new (
450+ vec ! [ ProjectionExpr {
451+ expr: Arc :: new( ScalarSubqueryExpr :: new(
452+ DataType :: Int32 ,
453+ false ,
454+ 0 ,
455+ results,
456+ ) ) ,
457+ alias: "sq" . to_string( ) ,
458+ } ] ,
459+ placeholder_input ( ) ,
460+ ) ?) )
411461 }
412462
413463 fn extract_single_int32_value ( batches : & [ RecordBatch ] ) -> i32 {
@@ -422,91 +472,40 @@ mod tests {
422472 }
423473
424474 #[ tokio:: test]
425- async fn test_single_row_subquery ( ) -> Result < ( ) > {
426- let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ) ;
427- let batch = RecordBatch :: try_new (
428- Arc :: clone ( & schema) ,
429- vec ! [ Arc :: new( Int32Array :: from( vec![ 42 ] ) ) ] ,
430- ) ?;
431-
432- let results = make_results ( 1 ) ;
433- let subquery_plan = make_subquery_plan ( vec ! [ batch] ) ;
434- let sq = ScalarSubqueryLink {
435- plan : subquery_plan,
436- index : 0 ,
437- } ;
438-
439- let main_input = Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
440- test:: aggr_test_schema ( ) ,
441- ) ) ;
442- let exec = ScalarSubqueryExec :: new ( main_input, vec ! [ sq] , results. clone ( ) ) ;
443-
444- let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
445- let stream = exec. execute ( 0 , ctx) ?;
446- let _batches = crate :: common:: collect ( stream) . await ?;
447-
448- assert_eq ! ( results. get( 0 ) , Some ( ScalarValue :: Int32 ( Some ( 42 ) ) ) ) ;
449- Ok ( ( ) )
450- }
451-
452- #[ tokio:: test]
453- async fn test_zero_row_subquery_returns_null ( ) -> Result < ( ) > {
454- let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , true ) ] ) ) ;
455- let batch = RecordBatch :: try_new (
456- Arc :: clone ( & schema) ,
457- vec ! [ Arc :: new( Int64Array :: from( vec![ ] as Vec <i64 >) ) ] ,
458- ) ?;
459-
460- let results = make_results ( 1 ) ;
461- let subquery_plan = make_subquery_plan ( vec ! [ batch] ) ;
462- let sq = ScalarSubqueryLink {
463- plan : subquery_plan,
464- index : 0 ,
465- } ;
466-
467- let main_input = Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
468- test:: aggr_test_schema ( ) ,
469- ) ) ;
470- let exec = ScalarSubqueryExec :: new ( main_input, vec ! [ sq] , results. clone ( ) ) ;
471-
472- let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
473- let stream = exec. execute ( 0 , ctx) ?;
474- let _batches = crate :: common:: collect ( stream) . await ?;
475-
476- assert_eq ! ( results. get( 0 ) , Some ( ScalarValue :: Int64 ( None ) ) ) ;
477- Ok ( ( ) )
478- }
479-
480- #[ tokio:: test]
481- async fn test_multi_row_subquery_errors ( ) -> Result < ( ) > {
482- let schema = Arc :: new ( Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int32 , false ) ] ) ) ;
483- let batch = RecordBatch :: try_new (
484- Arc :: clone ( & schema) ,
485- vec ! [ Arc :: new( Int32Array :: from( vec![ 1 , 2 , 3 ] ) ) ] ,
486- ) ?;
487-
488- let results = make_results ( 1 ) ;
489- let subquery_plan = make_subquery_plan ( vec ! [ batch] ) ;
490- let sq = ScalarSubqueryLink {
491- plan : subquery_plan,
492- index : 0 ,
493- } ;
494-
495- let main_input = Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
496- test:: aggr_test_schema ( ) ,
497- ) ) ;
498- let exec = ScalarSubqueryExec :: new ( main_input, vec ! [ sq] , results. clone ( ) ) ;
499-
500- let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
501- let stream = exec. execute ( 0 , ctx) ?;
502- let result = crate :: common:: collect ( stream) . await ;
475+ async fn test_execute_scalar_subquery_row_count_semantics ( ) -> Result < ( ) > {
476+ for ( name, plan, expected) in [
477+ (
478+ "single_row" ,
479+ make_subquery_plan ( vec ! [ int32_batch( vec![ 42 ] ) ] ) ,
480+ ExpectedSubqueryResult :: Value ( ScalarValue :: Int32 ( Some ( 42 ) ) ) ,
481+ ) ,
482+ (
483+ "zero_rows" ,
484+ make_subquery_plan ( vec ! [ empty_int64_batch( ) ] ) ,
485+ ExpectedSubqueryResult :: Value ( ScalarValue :: Int64 ( None ) ) ,
486+ ) ,
487+ (
488+ "multiple_rows" ,
489+ make_subquery_plan ( vec ! [ int32_batch( vec![ 1 , 2 , 3 ] ) ] ) ,
490+ ExpectedSubqueryResult :: Error ( "more than one row" ) ,
491+ ) ,
492+ ] {
493+ let actual =
494+ execute_scalar_subquery ( plan, Arc :: new ( TaskContext :: default ( ) ) ) . await ;
495+ match expected {
496+ ExpectedSubqueryResult :: Value ( expected) => {
497+ assert_eq ! ( actual?, expected, "{name}" ) ;
498+ }
499+ ExpectedSubqueryResult :: Error ( expected) => {
500+ let err = actual. expect_err ( name) ;
501+ assert ! (
502+ err. to_string( ) . contains( expected) ,
503+ "{name}: expected error containing '{expected}', got {err}"
504+ ) ;
505+ }
506+ }
507+ }
503508
504- assert ! ( result. is_err( ) ) ;
505- let err_msg = result. unwrap_err ( ) . to_string ( ) ;
506- assert ! (
507- err_msg. contains( "more than one row" ) ,
508- "Expected 'more than one row' error, got: {err_msg}"
509- ) ;
510509 Ok ( ( ) )
511510 }
512511
@@ -517,16 +516,11 @@ mod tests {
517516 Arc :: new ( ErrorExec :: new ( ) ) ,
518517 Arc :: clone ( & execute_calls) ,
519518 ) ) ;
520- let results = make_results ( 1 ) ;
521- let sq = ScalarSubqueryLink {
522- plan : subquery_plan,
523- index : 0 ,
524- } ;
525-
526- let main_input = Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
527- test:: aggr_test_schema ( ) ,
528- ) ) ;
529- let exec = ScalarSubqueryExec :: new ( main_input, vec ! [ sq] , results) ;
519+ let exec = single_subquery_exec (
520+ placeholder_input ( ) ,
521+ subquery_plan,
522+ ScalarSubqueryResults :: new ( 1 ) ,
523+ ) ;
530524
531525 let ctx = Arc :: new ( TaskContext :: default ( ) ) ;
532526 let stream = exec. execute ( 0 , Arc :: clone ( & ctx) ) ?;
@@ -542,39 +536,14 @@ mod tests {
542536 #[ tokio:: test]
543537 async fn test_reset_state_clears_results_and_reexecutes_subqueries ( ) -> Result < ( ) > {
544538 let execute_calls = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
545- let results = make_results ( 1 ) ;
546- let schema =
547- Arc :: new ( Schema :: new ( vec ! [ Field :: new( "sq" , DataType :: Int32 , false ) ] ) ) ;
548- let batch = RecordBatch :: try_new (
549- Arc :: clone ( & schema) ,
550- vec ! [ Arc :: new( Int32Array :: from( vec![ 42 ] ) ) ] ,
551- ) ?;
539+ let results = ScalarSubqueryResults :: new ( 1 ) ;
552540 let subquery_plan = Arc :: new ( CountingExec :: new (
553- make_subquery_plan ( vec ! [ batch ] ) ,
541+ make_subquery_plan ( vec ! [ int32_batch ( vec! [ 42 ] ) ] ) ,
554542 Arc :: clone ( & execute_calls) ,
555543 ) ) ;
556- let sq = ScalarSubqueryLink {
557- plan : subquery_plan,
558- index : 0 ,
559- } ;
560-
561- let main_input = Arc :: new ( ProjectionExec :: try_new (
562- vec ! [ ProjectionExpr {
563- expr: Arc :: new( ScalarSubqueryExpr :: new(
564- DataType :: Int32 ,
565- false ,
566- 0 ,
567- results. clone( ) ,
568- ) ) ,
569- alias: "sq" . to_string( ) ,
570- } ] ,
571- Arc :: new ( crate :: placeholder_row:: PlaceholderRowExec :: new (
572- test:: aggr_test_schema ( ) ,
573- ) ) ,
574- ) ?) ;
575- let exec: Arc < dyn ExecutionPlan > = Arc :: new ( ScalarSubqueryExec :: new (
576- main_input,
577- vec ! [ sq] ,
544+ let exec: Arc < dyn ExecutionPlan > = Arc :: new ( single_subquery_exec (
545+ scalar_subquery_projection_input ( results. clone ( ) ) ?,
546+ subquery_plan,
578547 results. clone ( ) ,
579548 ) ) ;
580549
0 commit comments