Skip to content

Commit a2d0363

Browse files
committed
vcftools: address remaining review comments
1 parent b306928 commit a2d0363

3 files changed

Lines changed: 102 additions & 32 deletions

File tree

datafusion/bio-function-vcftools/src/logical/optimizer_rule.rs

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,13 @@ struct TraversalResult<'a> {
216216
///
217217
/// This handles transformation CTEs by traversing multiple Projection layers and
218218
/// composing expressions.
219-
fn traverse_to_unnest(plan: &LogicalPlan) -> Option<TraversalResult<'_>> {
219+
///
220+
/// # Returns
221+
///
222+
/// - `Ok(Some(result))` if the pattern matches and traversal succeeded
223+
/// - `Ok(None)` if the pattern doesn't match (no Unnest found)
224+
/// - `Err(e)` if traversal failed due to expression resolution errors
225+
fn traverse_to_unnest(plan: &LogicalPlan) -> Result<Option<TraversalResult<'_>>> {
220226
// Skip SubqueryAlias wrappers
221227
let plan = skip_wrappers(plan);
222228

@@ -266,14 +272,16 @@ fn traverse_to_unnest(plan: &LogicalPlan) -> Option<TraversalResult<'_>> {
266272
});
267273
}
268274

269-
Some(TraversalResult {
275+
Ok(Some(TraversalResult {
270276
unnest,
271277
column_definitions,
272-
})
278+
}))
273279
}
274280
LogicalPlan::Projection(projection) => {
275281
// Recurse into child
276-
let child_result = traverse_to_unnest(projection.input.as_ref())?;
282+
let Some(child_result) = traverse_to_unnest(projection.input.as_ref())? else {
283+
return Ok(None);
284+
};
277285

278286
// Build new column definitions by resolving projection expressions
279287
let mut new_definitions = HashMap::new();
@@ -289,28 +297,33 @@ fn traverse_to_unnest(plan: &LogicalPlan) -> Option<TraversalResult<'_>> {
289297
};
290298

291299
// Resolve the expression against child definitions
292-
let resolved = resolve_expr(&inner_expr, &child_result.column_definitions);
300+
let resolved = resolve_expr(&inner_expr, &child_result.column_definitions)?;
293301
new_definitions.insert(alias, resolved);
294302
}
295303

296-
Some(TraversalResult {
304+
Ok(Some(TraversalResult {
297305
unnest: child_result.unnest,
298306
column_definitions: new_definitions,
299-
})
307+
}))
300308
}
301309
_ => {
302310
trace!(
303311
"traverse_to_unnest: expected Projection or Unnest, got {}",
304312
plan_type_name(plan)
305313
);
306-
None
314+
Ok(None)
307315
}
308316
}
309317
}
310318

311319
/// Resolve an expression by substituting column references with their definitions.
312320
/// Uses DataFusion's `transform` to recursively traverse all expression variants.
313-
fn resolve_expr(expr: &Expr, definitions: &HashMap<String, Expr>) -> Expr {
321+
///
322+
/// # Errors
323+
///
324+
/// Returns an error if the expression tree traversal fails (e.g., due to an
325+
/// unexpected expression variant or internal DataFusion error).
326+
fn resolve_expr(expr: &Expr, definitions: &HashMap<String, Expr>) -> Result<Expr> {
314327
expr.clone()
315328
.transform(|e| {
316329
if let Expr::Column(col) = &e {
@@ -321,7 +334,6 @@ fn resolve_expr(expr: &Expr, definitions: &HashMap<String, Expr>) -> Expr {
321334
Ok(Transformed::no(e))
322335
})
323336
.map(|t| t.data)
324-
.unwrap_or_else(|_| expr.clone())
325337
}
326338

327339
/// Attempt to detect and optimize the pattern.
@@ -355,9 +367,16 @@ fn try_optimize(plan: &LogicalPlan) -> Option<Result<LogicalPlan>> {
355367
};
356368

357369
// Traverse to find Unnest while collecting column definitions
358-
let Some(traversal) = traverse_to_unnest(aggregate.input.as_ref()) else {
359-
trace!("traverse_to_unnest returned None");
360-
return None;
370+
let traversal = match traverse_to_unnest(aggregate.input.as_ref()) {
371+
Ok(Some(t)) => t,
372+
Ok(None) => {
373+
trace!("traverse_to_unnest returned None");
374+
return None;
375+
}
376+
Err(e) => {
377+
// Pattern matched but traversal failed - propagate error
378+
return Some(Err(e));
379+
}
361380
};
362381
let unnest_plan = traversal.unnest;
363382
let column_definitions = traversal.column_definitions;
@@ -410,7 +429,10 @@ fn try_optimize(plan: &LogicalPlan) -> Option<Result<LogicalPlan>> {
410429
if let Expr::AggregateFunction(AggregateFunction { params, .. }) = expr {
411430
if let Some(arg) = params.args.first() {
412431
// Resolve the argument through the column definitions
413-
let resolved = resolve_expr(arg, &column_definitions);
432+
let resolved = match resolve_expr(arg, &column_definitions) {
433+
Ok(r) => r,
434+
Err(e) => return Some(Err(e)),
435+
};
414436
trace!("Resolved array_agg argument: {arg:?} -> {resolved:?}");
415437
transform_exprs.push(resolved);
416438
} else {

datafusion/bio-function-vcftools/src/physical/fused_array_transform_exec.rs

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,9 @@ impl FusedArrayTransformStream {
323323
for col_name in &self.array_columns {
324324
let idx = input_schema.index_of(col_name)?;
325325
let col = batch.column(idx);
326-
let output_array = self.apply_identity_transform(col, col_name)?;
326+
327+
let filtered_col = datafusion::arrow::compute::filter(col.as_ref(), &bool_mask)?;
328+
let output_array = self.apply_identity_transform(&filtered_col, col_name)?;
327329
output_columns.push(output_array);
328330
}
329331
} else {
@@ -627,28 +629,35 @@ mod tests {
627629
use datafusion::arrow::datatypes::{Field, Schema};
628630
use datafusion::physical_plan::test::TestMemoryExec;
629631

630-
fn create_test_batch() -> RecordBatch {
632+
fn create_test_batch(
633+
row0_lista: Vec<f64>,
634+
row0_listb: Vec<f64>,
635+
row1_lista: Vec<f64>,
636+
row1_listb: Vec<f64>
637+
) -> RecordBatch {
631638
let mut list_builder_a = ListBuilder::new(Float64Builder::new());
632639
let mut list_builder_b = ListBuilder::new(Float64Builder::new());
633640

634-
// Row 0: [1.0, 2.0, 3.0], [10.0, 20.0, 30.0]
635-
list_builder_a.values().append_value(1.0);
636-
list_builder_a.values().append_value(2.0);
637-
list_builder_a.values().append_value(3.0);
641+
// Row 0
642+
for val in row0_lista {
643+
list_builder_a.values().append_value(val);
644+
}
638645
list_builder_a.append(true);
639646

640-
list_builder_b.values().append_value(10.0);
641-
list_builder_b.values().append_value(20.0);
642-
list_builder_b.values().append_value(30.0);
647+
for val in row0_listb {
648+
list_builder_b.values().append_value(val);
649+
}
643650
list_builder_b.append(true);
644651

645-
// Row 1: [4.0, 5.0], [40.0, 50.0]
646-
list_builder_a.values().append_value(4.0);
647-
list_builder_a.values().append_value(5.0);
652+
// Row 1
653+
for val in row1_lista {
654+
list_builder_a.values().append_value(val);
655+
}
648656
list_builder_a.append(true);
649657

650-
list_builder_b.values().append_value(40.0);
651-
list_builder_b.values().append_value(50.0);
658+
for val in row1_listb {
659+
list_builder_b.values().append_value(val);
660+
}
652661
list_builder_b.append(true);
653662

654663
let arr_a = list_builder_a.finish();
@@ -677,9 +686,18 @@ mod tests {
677686
.unwrap()
678687
}
679688

689+
macro_rules! create_test_batch {
690+
($row0_lista: expr, $row0_listb: expr, $row1_lista: expr, $row1_listb: expr) => {
691+
create_test_batch($row0_lista, $row0_listb, $row1_lista, $row1_listb)
692+
};
693+
() => {
694+
create_test_batch(vec![1.0, 2.0, 3.0], vec![10.0, 20.0, 30.0], vec![4.0, 5.0], vec![40.0, 50.0])
695+
};
696+
}
697+
680698
#[tokio::test]
681699
async fn test_identity_transform() {
682-
let batch = create_test_batch();
700+
let batch = create_test_batch!();
683701
let schema = batch.schema();
684702

685703
let mem_exec = TestMemoryExec::try_new(&[vec![batch.clone()]], schema, None).unwrap();
@@ -695,11 +713,41 @@ mod tests {
695713

696714
// Schema should have 2 fields: metadata + values_a_out
697715
assert_eq!(fused.schema().fields().len(), 2);
716+
717+
let ctx = Arc::new(TaskContext::default());
718+
let mut stream = fused.execute(0, ctx).unwrap();
719+
let result_batch = stream.next().await.unwrap().unwrap();
720+
assert_eq!(result_batch.num_rows(), 2);
721+
}
722+
723+
#[tokio::test]
724+
async fn test_identity_transform_with_empty_array() {
725+
let batch = create_test_batch!(vec![], vec![], vec![4.0, 5.0], vec![40.0, 50.0]);
726+
let schema = batch.schema();
727+
728+
let mem_exec = TestMemoryExec::try_new(&[vec![batch.clone()]], schema, None).unwrap();
729+
730+
let fused = FusedArrayTransformExec::try_new(
731+
Arc::new(mem_exec),
732+
vec!["values_a".to_string(), "values_b".to_string()],
733+
vec!["metadata".to_string()],
734+
vec!["values_a_out".to_string(), "values_b_out".to_string()],
735+
vec![],
736+
)
737+
.unwrap();
738+
739+
// Schema should have 3 fields: metadata + values_a_out + values_b_out
740+
assert_eq!(fused.schema().fields().len(), 3);
741+
// row0 should be filtered out due to empty array, so only row1 remains
742+
let ctx = Arc::new(TaskContext::default());
743+
let mut stream = fused.execute(0, ctx).unwrap();
744+
let result_batch = stream.next().await.unwrap().unwrap();
745+
assert_eq!(result_batch.num_rows(), 1);
698746
}
699747

700748
#[tokio::test]
701749
async fn test_execution() {
702-
let batch = create_test_batch();
750+
let batch = create_test_batch!();
703751
let schema = batch.schema();
704752

705753
let mem_exec = TestMemoryExec::try_new(&[vec![batch.clone()]], schema, None).unwrap();

datafusion/bio-function-vcftools/tests/integration_test.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -959,7 +959,7 @@ async fn test_mismatched_array_lengths() {
959959
.to_string();
960960
assert!(
961961
plan_str_optimized.contains("FusedArrayTransform"),
962-
"FusedArrayTransform optimization was NOT applied for empty arrays case! Physical plan:\n{plan_str_optimized}"
962+
"FusedArrayTransform optimization was NOT applied for mismatched lengths case! Physical plan:\n{plan_str_optimized}"
963963
);
964964
let df_optimized2 = ctx_optimized.sql(sql).await.unwrap();
965965
let optimized_results = df_optimized2.collect().await.unwrap();
@@ -973,7 +973,7 @@ async fn test_mismatched_array_lengths() {
973973
.to_string();
974974
assert_eq!(
975975
baseline_str, optimized_str,
976-
"Baseline and optimized results differ for empty arrays case"
976+
"Baseline and optimized results differ for mismatched lengths case"
977977
);
978978
}
979979

0 commit comments

Comments
 (0)