diff --git a/logicalplan/distribute.go b/logicalplan/distribute.go index 446ce5d9..cfe27191 100644 --- a/logicalplan/distribute.go +++ b/logicalplan/distribute.go @@ -177,79 +177,125 @@ func (m DistributedExecutionOptimizer) Optimize(plan Node, opts *query.Options) } } - var warns = annotations.New() + warns := annotations.New() + + parents := computeParents(&plan) + distributionPoints := m.computeDistributionPoints(&plan, parents, engineLabels, warns) - // TODO(fpetkovski): Consider changing TraverseBottomUp to pass in a list of parents in the transform function. - parents := make(map[*Node]*Node) - TraverseBottomUp(nil, &plan, func(parent, current *Node) (stop bool) { - parents[current] = parent - return false - }) TraverseBottomUp(nil, &plan, func(parent, current *Node) (stop bool) { - // Handle avg() specially - it's not distributive but can be distributed as sum/count. + if _, distributeNow := distributionPoints[current]; !distributeNow { + return false + } + if isAvgAggregation(current) { *current = m.distributeAvg(*current, engines, m.subqueryOpts(parents, current, opts), labelRanges) return true } - // If the current operation is not distributive, stop the traversal. - if !isDistributive(current, m.SkipBinaryPushdown, engineLabels, warns) { + if isAbsent(current) { + *current = m.distributeAbsent(*current, engines, calculateStartOffset(current, opts.LookbackDelta), m.subqueryOpts(parents, current, opts)) return true } - // Handle absent functions specially - if isAbsent(current) { - *current = m.distributeAbsent(*current, engines, calculateStartOffset(current, opts.LookbackDelta), m.subqueryOpts(parents, current, opts)) + if isAggregation(current) { + *current = m.distributeAggregation((*current).(*Aggregation), engines, m.subqueryOpts(parents, current, opts), labelRanges) return true } - // If the current node is an aggregation, check if we should distribute here - // or continue traversing up. - if aggr, ok := (*current).(*Aggregation); ok { - // If this aggregation preserves partition labels and there's a - // distributive aggregation ancestor, continue up to let it handle distribution. - // This enables patterns like: - // - topk(10, sum by (P, instance) (X)) - // - sum(metric_a * group by (P) (metric_b)) - // - max(sum by (P, instance) (X)) - // where P is a partition label - we can push the entire expression - // to remote engines. - // - // We need to check ancestors (not just immediate parent) because the - // aggregation might be nested inside a binary expression that is itself - // inside another aggregation: sum(A * group by (P) (B)) + *current = m.distributeQuery(current, engines, m.subqueryOpts(parents, current, opts), labelRanges) + return true + }) + return plan, *warns +} + +func (m DistributedExecutionOptimizer) distributeAggregation(aggr *Aggregation, engines []api.RemoteEngine, opts *query.Options, labelRanges labelSetRanges) Node { + localAggregation := aggr.Op + if aggr.Op == parser.COUNT { + localAggregation = parser.SUM + } + remoteAggregation := newRemoteAggregation(aggr, engines) + subQueries := m.distributeQuery(&remoteAggregation, engines, opts, labelRanges) + return &Aggregation{ + Op: localAggregation, + Expr: subQueries, + Param: aggr.Param, + Grouping: aggr.Grouping, + Without: aggr.Without, + } +} + +func computeParents(plan *Node) map[*Node]*Node { + parents := make(map[*Node]*Node) + TraverseBottomUp(nil, plan, func(parent, current *Node) (stop bool) { + parents[current] = parent + return false + }) + return parents +} + +func (m DistributedExecutionOptimizer) computeDistributionPoints(plan *Node, parents map[*Node]*Node, engineLabels map[string]struct{}, warns *annotations.Annotations) map[*Node]struct{} { + marks := make(map[*Node]struct{}) + + // First pass: mark all "forced" distribution points (aggregations that don't + // preserve partition labels, avg aggregations, absent functions). + Traverse(plan, func(current *Node) { + if isAvgAggregation(current) { + marks[current] = struct{}{} + return + } + if isAbsent(current) { + if m.isDistributive(current, engineLabels, warns) { + marks[current] = struct{}{} + } + return + } + if isAggregation(current) { + if !m.isDistributive(current, engineLabels, warns) { + return + } if preservesPartitionLabels(*current, engineLabels) { - if hasDistributiveAncestor(parents, current, m.SkipBinaryPushdown, engineLabels, warns) { - return false + if m.hasDistributiveAncestor(parents, current, engineLabels, warns) { + return } } - localAggregation := aggr.Op - if aggr.Op == parser.COUNT { - localAggregation = parser.SUM - } + marks[current] = struct{}{} + } + }) - remoteAggregation := newRemoteAggregation(aggr, engines) - subQueries := m.distributeQuery(&remoteAggregation, engines, m.subqueryOpts(parents, current, opts), labelRanges) - *current = &Aggregation{ - Op: localAggregation, - Expr: subQueries, - Param: aggr.Param, - Grouping: aggr.Grouping, - Without: aggr.Without, + // Second pass: for nodes whose siblings have marks, mark them too so both + // sides of a binary expression get distributed. + Traverse(plan, func(current *Node) { + if _, ok := marks[current]; ok { + return + } + if subtreeHasMark(current, marks) { + return + } + if !m.isDistributive(current, engineLabels, warns) { + return + } + parent := parents[current] + if parent != nil && (m.isDistributive(parent, engineLabels, warns) || isAvgAggregation(parent)) { + if !subtreeHasMark(parent, marks) { + return } - return true } + marks[current] = struct{}{} + }) - // If the parent operation is distributive or is an avg (which we handle specially), - // continue the traversal. - if isDistributive(parent, m.SkipBinaryPushdown, engineLabels, warns) || isAvgAggregation(parent) { - return false - } + return marks +} - *current = m.distributeQuery(current, engines, m.subqueryOpts(parents, current, opts), labelRanges) - return true - }) - return plan, *warns +func subtreeHasMark(node *Node, marks map[*Node]struct{}) bool { + for _, child := range (*node).Children() { + if _, ok := marks[child]; ok { + return true + } + if subtreeHasMark(child, marks) { + return true + } + } + return false } func (m DistributedExecutionOptimizer) subqueryOpts(parents map[*Node]*Node, current *Node, opts *query.Options) *query.Options { @@ -633,7 +679,7 @@ func preservesPartitionLabels(expr Node, partitionLabels map[string]struct{}) bo } } -func isDistributive(expr *Node, skipBinaryPushdown bool, engineLabels map[string]struct{}, warns *annotations.Annotations) bool { +func (m DistributedExecutionOptimizer) isDistributive(expr *Node, engineLabels map[string]struct{}, warns *annotations.Annotations) bool { if expr == nil { return false } @@ -645,10 +691,10 @@ func isDistributive(expr *Node, skipBinaryPushdown bool, engineLabels map[string if isBinaryExpressionWithOneScalarSide(e) { return true } - return !skipBinaryPushdown && + return !m.SkipBinaryPushdown && isBinaryExpressionWithDistributableMatching(e, engineLabels) && - isDistributive(&e.LHS, skipBinaryPushdown, engineLabels, warns) && - isDistributive(&e.RHS, skipBinaryPushdown, engineLabels, warns) + m.isDistributive(&e.LHS, engineLabels, warns) && + m.isDistributive(&e.RHS, engineLabels, warns) case *Aggregation: // Certain aggregations are currently not supported. if _, ok := distributiveAggregations[e.Op]; !ok { @@ -832,9 +878,9 @@ func matchesExternalLabels(ms []*labels.Matcher, externalLabels labels.Labels) b // parent chain from the current node that can handle distribution. // We must have an unbroken chain of distributive nodes to the ancestor for it to // be able to handle distribution on our behalf. -func hasDistributiveAncestor(parents map[*Node]*Node, current *Node, skipBinaryPushdown bool, engineLabels map[string]struct{}, warns *annotations.Annotations) bool { +func (m DistributedExecutionOptimizer) hasDistributiveAncestor(parents map[*Node]*Node, current *Node, engineLabels map[string]struct{}, warns *annotations.Annotations) bool { for p := parents[current]; p != nil; p = parents[p] { - if !isDistributive(p, skipBinaryPushdown, engineLabels, warns) { + if !m.isDistributive(p, engineLabels, warns) { // We hit a non-distributive node, so we can't push through it. // No ancestor can help us distribute. return false diff --git a/logicalplan/distribute_test.go b/logicalplan/distribute_test.go index f193f110..0068a63c 100644 --- a/logicalplan/distribute_test.go +++ b/logicalplan/distribute_test.go @@ -522,6 +522,13 @@ count by (cluster) ( expected: `max(dedup(remote(metric_a), remote(metric_a)) + sum by (region, pod) (dedup(remote(sum by (pod, region) (metric_b)), remote(sum by (pod, region) (metric_b)))))`, skipBinopPushdown: true, }, + { + // When the RHS of unless has an aggregation that drops the partition label, + // both sides should still be distributed independently. + name: "unless with aggregation that drops partition label distributes both sides", + expr: `group by (region, instance) (metric_a unless on (region, instance) max by (instance) (metric_b))`, + expected: `group by (region, instance) (dedup(remote(metric_a), remote(metric_a)) unless on (region, instance) max by (instance) (dedup(remote(max by (instance, region) (metric_b)), remote(max by (instance, region) (metric_b)))))`, + }, { // group_left/group_right with partition label cannot be distributed because // match cardinality changes when each partition only sees one value for that label. diff --git a/logicalplan/logical_nodes.go b/logicalplan/logical_nodes.go index 19be0e06..181d80ac 100644 --- a/logicalplan/logical_nodes.go +++ b/logicalplan/logical_nodes.go @@ -572,12 +572,18 @@ func shallowCloneSlice[T any](s []T) []T { return clone } -func isAvgAggregation(expr *Node) bool { +func isAggregation(expr *Node) bool { if expr == nil { return false } - if aggr, ok := (*expr).(*Aggregation); ok { - return aggr.Op == parser.AVG + _, ok := (*expr).(*Aggregation) + return ok +} + +func isAvgAggregation(expr *Node) bool { + if expr == nil { + return false } - return false + aggr, ok := (*expr).(*Aggregation) + return ok && aggr.Op == parser.AVG }