diff --git a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs index 47e3adb455117..0cb9bc78f14c0 100644 --- a/datafusion/core/tests/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/tests/physical_optimizer/enforce_sorting.rs @@ -57,6 +57,7 @@ use datafusion_physical_optimizer::enforce_sorting::replace_with_order_preservin use datafusion_physical_optimizer::enforce_sorting::sort_pushdown::{SortPushDown, assign_initial_requirements, pushdown_sorts}; use datafusion_physical_optimizer::enforce_distribution::EnforceDistribution; use datafusion_physical_optimizer::output_requirements::OutputRequirementExec; +use datafusion_physical_optimizer::sanity_checker::SanityCheckPlan; use datafusion_physical_optimizer::PhysicalOptimizerRule; use datafusion::prelude::*; use arrow::array::{Int32Array, RecordBatch}; @@ -417,6 +418,46 @@ async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_with_reparti Ok(()) } +#[tokio::test] +async fn output_requirement_adds_merge_after_partition_preserving_sort() -> Result<()> { + let schema = create_test_schema()?; + let input = union_exec(vec![memory_exec(&schema), memory_exec(&schema)]); + let requirement = [PhysicalSortRequirement::new( + col("nullable_col", &schema)?, + Some(SortOptions::new(false, true)), + )] + .into(); + let physical_plan: Arc = Arc::new(OutputRequirementExec::new( + input, + Some(OrderingRequirements::new(requirement)), + Distribution::SinglePartition, + Some(21), + )); + + let config = ConfigOptions::new(); + let optimized_plan = + EnforceSorting::new().optimize(Arc::clone(&physical_plan), &config)?; + SanityCheckPlan::new().optimize(optimized_plan, &config)?; + + let test = EnforceSortingTest::new(physical_plan); + assert_snapshot!(test.run(), @r" + Input Plan: + OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + UnionExec + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: partitions=1, partition_sizes=[0] + + Optimized Plan: + OutputRequirementExec: order_by=[(nullable_col@0, asc)], dist_by=SinglePartition + SortPreservingMergeExec: [nullable_col@0 ASC], fetch=21 + SortExec: TopK(fetch=21), expr=[nullable_col@0 ASC], preserve_partitioning=[true] + UnionExec + DataSourceExec: partitions=1, partition_sizes=[0] + DataSourceExec: partitions=1, partition_sizes=[0] + "); + Ok(()) +} + async fn union_with_mix_of_presorted_and_explicitly_resorted_inputs_impl( repartition_sorts: bool, ) -> Result { diff --git a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs index 247ebb2785dd3..6b05e733c6da4 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/mod.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/mod.rs @@ -49,8 +49,8 @@ use crate::enforce_sorting::sort_pushdown::{ }; use crate::output_requirements::OutputRequirementExec; use crate::utils::{ - add_sort_above, add_sort_above_with_check, is_coalesce_partitions, is_limit, - is_repartition, is_sort, is_sort_preserving_merge, is_window, + add_sort_above_with_check, add_sort_above_with_distribution, is_coalesce_partitions, + is_limit, is_repartition, is_sort, is_sort_preserving_merge, is_window, }; use datafusion_common::Result; @@ -489,6 +489,7 @@ pub fn ensure_sorting( }; let plan = &requirements.plan; + let required_distributions = plan.required_input_distribution(); let mut updated_children = vec![]; for (idx, (required_ordering, mut child)) in plan .required_input_ordering() @@ -506,13 +507,14 @@ pub fn ensure_sorting( if physical_ordering.is_some() { child = update_child_to_remove_unnecessary_sort(idx, child, plan)?; } - child = add_sort_above( + child = add_sort_above_with_distribution( child, req, plan.as_any() .downcast_ref::() .map(|output| output.fetch()) .unwrap_or(None), + &required_distributions[idx], ); child = update_sort_ctx_children_data(child, true)?; } @@ -609,13 +611,17 @@ fn analyze_immediate_sort_removal( fn adjust_window_sort_removal( mut window_tree: PlanWithCorrespondingSort, ) -> Result { + let required_distribution = window_tree + .plan + .required_input_distribution() + .swap_remove(0); + let requires_single_partition = + matches!(required_distribution, Distribution::SinglePartition); + // Window operators have a single child we need to adjust: let child_node = remove_corresponding_sort_from_sub_plan( window_tree.children.swap_remove(0), - matches!( - window_tree.plan.required_input_distribution()[0], - Distribution::SinglePartition - ), + requires_single_partition, )?; window_tree.children.push(child_node); @@ -647,7 +653,12 @@ fn adjust_window_sort_removal( // Satisfy the ordering requirement so that the window can run: let mut child_node = window_tree.children.swap_remove(0); if let Some(reqs) = reqs { - child_node = add_sort_above(child_node, reqs.into_single(), None); + child_node = add_sort_above_with_distribution( + child_node, + reqs.into_single(), + None, + &required_distribution, + ); } let child_plan = Arc::clone(&child_node.plan); window_tree.children.push(child_node); diff --git a/datafusion/physical-optimizer/src/utils.rs b/datafusion/physical-optimizer/src/utils.rs index 13a1745216e83..808ccc4a710b5 100644 --- a/datafusion/physical-optimizer/src/utils.rs +++ b/datafusion/physical-optimizer/src/utils.rs @@ -18,7 +18,7 @@ use std::sync::Arc; use datafusion_common::Result; -use datafusion_physical_expr::{LexOrdering, LexRequirement}; +use datafusion_physical_expr::{Distribution, LexOrdering, LexRequirement}; use datafusion_physical_plan::coalesce_partitions::CoalescePartitionsExec; use datafusion_physical_plan::limit::{GlobalLimitExec, LocalLimitExec}; use datafusion_physical_plan::repartition::RepartitionExec; @@ -39,6 +39,33 @@ pub fn add_sort_above( node: PlanContext, sort_requirements: LexRequirement, fetch: Option, +) -> PlanContext { + add_sort_above_impl(node, sort_requirements, fetch, false) +} + +/// This utility function adds a `SortExec` above an operator according to the +/// given ordering requirements. If the parent distribution requires a single +/// input partition, it adds a `SortPreservingMergeExec` above the +/// partition-preserving sort. +pub fn add_sort_above_with_distribution( + node: PlanContext, + sort_requirements: LexRequirement, + fetch: Option, + required_distribution: &Distribution, +) -> PlanContext { + add_sort_above_impl( + node, + sort_requirements, + fetch, + matches!(required_distribution, Distribution::SinglePartition), + ) +} + +fn add_sort_above_impl( + node: PlanContext, + sort_requirements: LexRequirement, + fetch: Option, + requires_single_partition: bool, ) -> PlanContext { let mut sort_reqs: Vec<_> = sort_requirements.into(); sort_reqs.retain(|sort_expr| { @@ -51,11 +78,28 @@ pub fn add_sort_above( let Some(ordering) = LexOrdering::new(sort_exprs) else { return node; }; - let mut new_sort = SortExec::new(ordering, Arc::clone(&node.plan)).with_fetch(fetch); - if node.plan.output_partitioning().partition_count() > 1 { + let input_has_multiple_partitions = + node.plan.output_partitioning().partition_count() > 1; + + let mut new_sort = + SortExec::new(ordering.clone(), Arc::clone(&node.plan)).with_fetch(fetch); + if input_has_multiple_partitions { new_sort = new_sort.with_preserve_partitioning(true); } - PlanContext::new(Arc::new(new_sort), T::default(), vec![node]) + + let sort_node = PlanContext::new(Arc::new(new_sort), T::default(), vec![node]); + if !(requires_single_partition && input_has_multiple_partitions) { + return sort_node; + } + + PlanContext::new( + Arc::new( + SortPreservingMergeExec::new(ordering, Arc::clone(&sort_node.plan)) + .with_fetch(fetch), + ), + T::default(), + vec![sort_node], + ) } /// This utility function adds a `SortExec` above an operator according to the