Skip to content

Commit 0416b0d

Browse files
committed
Refactor unit test code
1 parent dc4ca31 commit 0416b0d

1 file changed

Lines changed: 95 additions & 126 deletions

File tree

datafusion/physical-plan/src/scalar_subquery.rs

Lines changed: 95 additions & 126 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,11 @@ mod tests {
335335
use arrow::record_batch::RecordBatch;
336336
use datafusion_physical_expr::scalar_subquery::ScalarSubqueryExpr;
337337

338+
enum ExpectedSubqueryResult {
339+
Value(ScalarValue),
340+
Error(&'static str),
341+
}
342+
338343
#[derive(Debug)]
339344
struct CountingExec {
340345
inner: Arc<dyn ExecutionPlan>,
@@ -406,8 +411,53 @@ mod tests {
406411
TestMemoryExec::try_new_exec(&[batches], schema, None).unwrap()
407412
}
408413

409-
fn make_results(n: usize) -> ScalarSubqueryResults {
410-
ScalarSubqueryResults::new(n)
414+
fn int32_batch(values: Vec<i32>) -> RecordBatch {
415+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
416+
RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(values))]).unwrap()
417+
}
418+
419+
fn empty_int64_batch() -> RecordBatch {
420+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
421+
RecordBatch::try_new(schema, vec![Arc::new(Int64Array::from(vec![] as Vec<i64>))])
422+
.unwrap()
423+
}
424+
425+
fn placeholder_input() -> Arc<dyn ExecutionPlan> {
426+
Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
427+
test::aggr_test_schema(),
428+
))
429+
}
430+
431+
fn single_subquery_exec(
432+
input: Arc<dyn ExecutionPlan>,
433+
subquery_plan: Arc<dyn ExecutionPlan>,
434+
results: ScalarSubqueryResults,
435+
) -> ScalarSubqueryExec {
436+
ScalarSubqueryExec::new(
437+
input,
438+
vec![ScalarSubqueryLink {
439+
plan: subquery_plan,
440+
index: 0,
441+
}],
442+
results,
443+
)
444+
}
445+
446+
fn scalar_subquery_projection_input(
447+
results: ScalarSubqueryResults,
448+
) -> Result<Arc<dyn ExecutionPlan>> {
449+
Ok(Arc::new(ProjectionExec::try_new(
450+
vec![ProjectionExpr {
451+
expr: Arc::new(ScalarSubqueryExpr::new(
452+
DataType::Int32,
453+
false,
454+
0,
455+
results,
456+
)),
457+
alias: "sq".to_string(),
458+
}],
459+
placeholder_input(),
460+
)?))
411461
}
412462

413463
fn extract_single_int32_value(batches: &[RecordBatch]) -> i32 {
@@ -422,91 +472,40 @@ mod tests {
422472
}
423473

424474
#[tokio::test]
425-
async fn test_single_row_subquery() -> Result<()> {
426-
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
427-
let batch = RecordBatch::try_new(
428-
Arc::clone(&schema),
429-
vec![Arc::new(Int32Array::from(vec![42]))],
430-
)?;
431-
432-
let results = make_results(1);
433-
let subquery_plan = make_subquery_plan(vec![batch]);
434-
let sq = ScalarSubqueryLink {
435-
plan: subquery_plan,
436-
index: 0,
437-
};
438-
439-
let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
440-
test::aggr_test_schema(),
441-
));
442-
let exec = ScalarSubqueryExec::new(main_input, vec![sq], results.clone());
443-
444-
let ctx = Arc::new(TaskContext::default());
445-
let stream = exec.execute(0, ctx)?;
446-
let _batches = crate::common::collect(stream).await?;
447-
448-
assert_eq!(results.get(0), Some(ScalarValue::Int32(Some(42))));
449-
Ok(())
450-
}
451-
452-
#[tokio::test]
453-
async fn test_zero_row_subquery_returns_null() -> Result<()> {
454-
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, true)]));
455-
let batch = RecordBatch::try_new(
456-
Arc::clone(&schema),
457-
vec![Arc::new(Int64Array::from(vec![] as Vec<i64>))],
458-
)?;
459-
460-
let results = make_results(1);
461-
let subquery_plan = make_subquery_plan(vec![batch]);
462-
let sq = ScalarSubqueryLink {
463-
plan: subquery_plan,
464-
index: 0,
465-
};
466-
467-
let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
468-
test::aggr_test_schema(),
469-
));
470-
let exec = ScalarSubqueryExec::new(main_input, vec![sq], results.clone());
471-
472-
let ctx = Arc::new(TaskContext::default());
473-
let stream = exec.execute(0, ctx)?;
474-
let _batches = crate::common::collect(stream).await?;
475-
476-
assert_eq!(results.get(0), Some(ScalarValue::Int64(None)));
477-
Ok(())
478-
}
479-
480-
#[tokio::test]
481-
async fn test_multi_row_subquery_errors() -> Result<()> {
482-
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
483-
let batch = RecordBatch::try_new(
484-
Arc::clone(&schema),
485-
vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
486-
)?;
487-
488-
let results = make_results(1);
489-
let subquery_plan = make_subquery_plan(vec![batch]);
490-
let sq = ScalarSubqueryLink {
491-
plan: subquery_plan,
492-
index: 0,
493-
};
494-
495-
let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
496-
test::aggr_test_schema(),
497-
));
498-
let exec = ScalarSubqueryExec::new(main_input, vec![sq], results.clone());
499-
500-
let ctx = Arc::new(TaskContext::default());
501-
let stream = exec.execute(0, ctx)?;
502-
let result = crate::common::collect(stream).await;
475+
async fn test_execute_scalar_subquery_row_count_semantics() -> Result<()> {
476+
for (name, plan, expected) in [
477+
(
478+
"single_row",
479+
make_subquery_plan(vec![int32_batch(vec![42])]),
480+
ExpectedSubqueryResult::Value(ScalarValue::Int32(Some(42))),
481+
),
482+
(
483+
"zero_rows",
484+
make_subquery_plan(vec![empty_int64_batch()]),
485+
ExpectedSubqueryResult::Value(ScalarValue::Int64(None)),
486+
),
487+
(
488+
"multiple_rows",
489+
make_subquery_plan(vec![int32_batch(vec![1, 2, 3])]),
490+
ExpectedSubqueryResult::Error("more than one row"),
491+
),
492+
] {
493+
let actual =
494+
execute_scalar_subquery(plan, Arc::new(TaskContext::default())).await;
495+
match expected {
496+
ExpectedSubqueryResult::Value(expected) => {
497+
assert_eq!(actual?, expected, "{name}");
498+
}
499+
ExpectedSubqueryResult::Error(expected) => {
500+
let err = actual.expect_err(name);
501+
assert!(
502+
err.to_string().contains(expected),
503+
"{name}: expected error containing '{expected}', got {err}"
504+
);
505+
}
506+
}
507+
}
503508

504-
assert!(result.is_err());
505-
let err_msg = result.unwrap_err().to_string();
506-
assert!(
507-
err_msg.contains("more than one row"),
508-
"Expected 'more than one row' error, got: {err_msg}"
509-
);
510509
Ok(())
511510
}
512511

@@ -517,16 +516,11 @@ mod tests {
517516
Arc::new(ErrorExec::new()),
518517
Arc::clone(&execute_calls),
519518
));
520-
let results = make_results(1);
521-
let sq = ScalarSubqueryLink {
522-
plan: subquery_plan,
523-
index: 0,
524-
};
525-
526-
let main_input = Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
527-
test::aggr_test_schema(),
528-
));
529-
let exec = ScalarSubqueryExec::new(main_input, vec![sq], results);
519+
let exec = single_subquery_exec(
520+
placeholder_input(),
521+
subquery_plan,
522+
ScalarSubqueryResults::new(1),
523+
);
530524

531525
let ctx = Arc::new(TaskContext::default());
532526
let stream = exec.execute(0, Arc::clone(&ctx))?;
@@ -542,39 +536,14 @@ mod tests {
542536
#[tokio::test]
543537
async fn test_reset_state_clears_results_and_reexecutes_subqueries() -> Result<()> {
544538
let execute_calls = Arc::new(AtomicUsize::new(0));
545-
let results = make_results(1);
546-
let schema =
547-
Arc::new(Schema::new(vec![Field::new("sq", DataType::Int32, false)]));
548-
let batch = RecordBatch::try_new(
549-
Arc::clone(&schema),
550-
vec![Arc::new(Int32Array::from(vec![42]))],
551-
)?;
539+
let results = ScalarSubqueryResults::new(1);
552540
let subquery_plan = Arc::new(CountingExec::new(
553-
make_subquery_plan(vec![batch]),
541+
make_subquery_plan(vec![int32_batch(vec![42])]),
554542
Arc::clone(&execute_calls),
555543
));
556-
let sq = ScalarSubqueryLink {
557-
plan: subquery_plan,
558-
index: 0,
559-
};
560-
561-
let main_input = Arc::new(ProjectionExec::try_new(
562-
vec![ProjectionExpr {
563-
expr: Arc::new(ScalarSubqueryExpr::new(
564-
DataType::Int32,
565-
false,
566-
0,
567-
results.clone(),
568-
)),
569-
alias: "sq".to_string(),
570-
}],
571-
Arc::new(crate::placeholder_row::PlaceholderRowExec::new(
572-
test::aggr_test_schema(),
573-
)),
574-
)?);
575-
let exec: Arc<dyn ExecutionPlan> = Arc::new(ScalarSubqueryExec::new(
576-
main_input,
577-
vec![sq],
544+
let exec: Arc<dyn ExecutionPlan> = Arc::new(single_subquery_exec(
545+
scalar_subquery_projection_input(results.clone())?,
546+
subquery_plan,
578547
results.clone(),
579548
));
580549

0 commit comments

Comments
 (0)