Skip to content

Commit 560cc5b

Browse files
Rebuild TopK filter state in with_fetch
1 parent 9260bac commit 560cc5b

1 file changed

Lines changed: 68 additions & 6 deletions

File tree

  • datafusion/physical-plan/src/sorts

datafusion/physical-plan/src/sorts/sort.rs

Lines changed: 68 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -940,7 +940,7 @@ impl SortExec {
940940
))
941941
}
942942

943-
/// Rebuild the shared TopK filter wrapper after output partitioning changes.
943+
/// Rebuild the shared TopK filter wrapper for the current output partitioning.
944944
///
945945
/// The dynamic filter expression is preserved, but wrapper state such as the
946946
/// shared threshold and remaining emitter count is reset for the new
@@ -984,14 +984,20 @@ impl SortExec {
984984
if fetch.is_some() && is_pipeline_friendly {
985985
cache = cache.with_boundedness(Boundedness::Bounded);
986986
}
987-
let filter = fetch.is_some().then(|| {
988-
// If we already have a filter, keep it. Otherwise, create a new one.
989-
self.filter.clone().unwrap_or_else(|| self.create_filter())
990-
});
991987
let mut new_sort = self.cloned();
992988
new_sort.fetch = fetch;
993989
new_sort.cache = cache.into();
994-
new_sort.filter = filter;
990+
if fetch.is_some() {
991+
if new_sort.filter.is_some() {
992+
// Keep the dynamic filter expression, but reset wrapper state
993+
// such as the shared threshold and expected emitter count.
994+
new_sort.rebuild_filter_for_current_partitioning();
995+
} else {
996+
new_sort.filter = Some(new_sort.create_filter());
997+
}
998+
} else {
999+
new_sort.filter = None;
1000+
}
9951001
new_sort
9961002
}
9971003

@@ -2866,6 +2872,62 @@ mod tests {
28662872
Ok(())
28672873
}
28682874

2875+
#[tokio::test]
2876+
async fn test_with_fetch_rebuilds_existing_topk_filter() -> Result<()> {
2877+
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
2878+
let partitions = vec![
2879+
vec![RecordBatch::try_new(
2880+
Arc::clone(&schema),
2881+
vec![Arc::new(Int32Array::from(vec![3, 1, 2]))],
2882+
)?],
2883+
vec![RecordBatch::try_new(
2884+
Arc::clone(&schema),
2885+
vec![Arc::new(Int32Array::from(vec![6, 4, 5]))],
2886+
)?],
2887+
];
2888+
let input = TestMemoryExec::try_new_exec(&partitions, Arc::clone(&schema), None)?;
2889+
let dynamic_filter = Arc::new(DynamicFilterPhysicalExpr::new(
2890+
vec![Arc::new(Column::new("a", 0))],
2891+
lit(true),
2892+
));
2893+
let dynamic_filter_id = dynamic_filter
2894+
.expression_id()
2895+
.expect("DynamicFilterPhysicalExpr always has an expression_id");
2896+
let sort = SortExec::new(
2897+
[PhysicalSortExpr::new_default(Arc::new(Column::new("a", 0)))].into(),
2898+
input,
2899+
)
2900+
.with_dynamic_filter_expr(dynamic_filter)?
2901+
.with_preserve_partitioning(true)
2902+
.with_fetch(Some(2));
2903+
2904+
let dynamic_filter = sort
2905+
.dynamic_filter_expr()
2906+
.expect("fetch sort should keep the dynamic filter");
2907+
assert_eq!(
2908+
dynamic_filter
2909+
.expression_id()
2910+
.expect("DynamicFilterPhysicalExpr always has an expression_id"),
2911+
dynamic_filter_id
2912+
);
2913+
2914+
let sort = Arc::new(sort);
2915+
let task_ctx = Arc::new(TaskContext::default());
2916+
2917+
emit_sort_partition(&sort, 0, Arc::clone(&task_ctx)).await?;
2918+
assert_filter_still_waiting(&dynamic_filter);
2919+
2920+
emit_sort_partition(&sort, 1, task_ctx).await?;
2921+
tokio::time::timeout(
2922+
std::time::Duration::from_secs(1),
2923+
dynamic_filter.wait_complete(),
2924+
)
2925+
.await
2926+
.expect("the final preserved SortExec partition should complete the filter");
2927+
2928+
Ok(())
2929+
}
2930+
28692931
#[test]
28702932
fn test_with_dynamic_filter_rejects_invalid_columns() -> Result<()> {
28712933
let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));

0 commit comments

Comments
 (0)