diff --git a/datafusion/physical-optimizer/src/topk_aggregation.rs b/datafusion/physical-optimizer/src/topk_aggregation.rs index 581edd86cd0aa..9e909805f8e73 100644 --- a/datafusion/physical-optimizer/src/topk_aggregation.rs +++ b/datafusion/physical-optimizer/src/topk_aggregation.rs @@ -26,7 +26,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::aggregates::LimitOptions; -use datafusion_physical_plan::aggregates::{AggregateExec, topk_types_supported}; +use datafusion_physical_plan::aggregates::{ + AggregateExec, AggregateMode, topk_types_supported, +}; use datafusion_physical_plan::execution_plan::CardinalityEffect; use datafusion_physical_plan::projection::ProjectionExec; use datafusion_physical_plan::sorts::sort::SortExec; @@ -48,15 +50,37 @@ impl TopKAggregation { order_desc: bool, limit: usize, ) -> Option> { + // Only apply TopK optimization to Single/SinglePartitioned/Partial aggregates + // that CAN use the GroupedTopKAggregateStream. + // FinalPartitioned with RepartitionExec input is skipped, as the stream + // can't currently handle two-stage aggregation. + match aggr.mode() { + AggregateMode::Single + | AggregateMode::SinglePartitioned + | AggregateMode::Partial => {} + _ => return None, + } + // Current only support single group key let (group_key, group_key_alias) = aggr.group_expr().expr().iter().exactly_one().ok()?; let kt = group_key.data_type(&aggr.input().schema()).ok()?; - let vt = if let Some((field, _)) = aggr.get_minmax_desc() { + + // Try to find a MIN/MAX aggregate that matches the ORDER BY clause by field name. + // The sort direction will be handled by the GroupedTopKAggregateStream using limit_options, + // so we don't require the aggregate's natural order to match the REQUEST order. + let minmax_result = aggr.aggr_expr().iter().find_map(|agg_expr| { + agg_expr + .get_minmax_desc() + .filter(|(field, _desc)| order_by == field.name()) + }); + + let vt = if let Some((field, _)) = minmax_result.as_ref() { field.data_type().clone() } else { kt.clone() }; + if !topk_types_supported(&kt, &vt) { return None; } @@ -64,23 +88,17 @@ impl TopKAggregation { return None; } - // Check if this is ordering by an aggregate function (MIN/MAX) - if let Some((field, desc)) = aggr.get_minmax_desc() { - // ensure the sort direction matches aggregate function - if desc != order_desc { - return None; - } - // ensure the sort is on the same field as the aggregate output - if order_by != field.name() { - return None; - } + // Check if this is ordering by an aggregate function (MIN/MAX) or the group key + if minmax_result.is_some() { + // Found a matching MIN/MAX aggregate for the ORDER BY clause by field name. + // The GroupedTopKAggregateStream will handle the sort direction via limit_options } else if aggr.aggr_expr().is_empty() { // This is a GROUP BY without aggregates, check if ordering is on the group key itself if order_by != group_key_alias { return None; } } else { - // Has aggregates but not MIN/MAX, or doesn't DISTINCT + // Has aggregates but none of them are MIN/MAX matching the ORDER BY return None; }