Skip to content

Commit 2ae49e5

Browse files
logiclaplan: distribute aggregations that preserve partition labels (#699)
Signed-off-by: Michael Hoffmann <mhoffmann@cloudflare.com>
1 parent 3542487 commit 2ae49e5

3 files changed

Lines changed: 205 additions & 103 deletions

File tree

logicalplan/distribute.go

Lines changed: 127 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -143,20 +143,6 @@ func (r Noop) ReturnType() parser.ValueType { return parser.ValueTypeVector }
143143

144144
func (r Noop) Type() NodeType { return NoopNode }
145145

146-
// distributiveAggregations are all PromQL aggregations which support
147-
// distributed execution.
148-
var distributiveAggregations = map[parser.ItemType]struct{}{
149-
parser.SUM: {},
150-
parser.MIN: {},
151-
parser.MAX: {},
152-
parser.GROUP: {},
153-
parser.COUNT: {},
154-
parser.BOTTOMK: {},
155-
parser.TOPK: {},
156-
parser.LIMITK: {},
157-
parser.LIMIT_RATIO: {},
158-
}
159-
160146
// DistributedExecutionOptimizer produces a logical plan suitable for
161147
// distributed Query execution.
162148
type DistributedExecutionOptimizer struct {
@@ -185,81 +171,138 @@ func (m DistributedExecutionOptimizer) Optimize(plan Node, opts *query.Options)
185171
}
186172
}
187173

188-
var warns = annotations.New()
174+
warns := annotations.New()
175+
176+
parents := computeParents(&plan)
177+
distributionPoints := m.computeDistributionPoints(&plan, parents, engineLabels, warns)
189178

190-
// TODO(fpetkovski): Consider changing TraverseBottomUp to pass in a list of parents in the transform function.
191-
parents := make(map[*Node]*Node)
192-
TraverseBottomUp(nil, &plan, func(parent, current *Node) (stop bool) {
193-
parents[current] = parent
194-
return false
195-
})
196179
TraverseBottomUp(nil, &plan, func(parent, current *Node) (stop bool) {
197-
// Handle avg() specially - it's not distributive but can be distributed as sum/count.
198-
if isAvgAggregation(current) {
199-
*current = m.distributeAvg(*current, engines, m.subqueryOpts(parents, current, opts), labelRanges)
200-
return true
180+
if _, distributeNow := distributionPoints[current]; !distributeNow {
181+
return false
201182
}
202183

203-
// If the current operation is not distributive, stop the traversal.
204-
if !isDistributive(current, m.SkipBinaryPushdown, engineLabels, warns) {
184+
if isAvgAggregation(current) && !preservesPartitionLabels(*current, engineLabels) {
185+
// avg without partition labels: rewrite as sum/count.
186+
*current = m.distributeAvg(*current, engines, m.subqueryOpts(parents, current, opts), labelRanges)
205187
return true
206188
}
207189

208-
// Handle absent functions specially
209190
if isAbsent(current) {
210191
*current = m.distributeAbsent(*current, engines, calculateStartOffset(current, opts.LookbackDelta), m.subqueryOpts(parents, current, opts))
211192
return true
212193
}
213194

214-
// If the current node is an aggregation, check if we should distribute here
215-
// or continue traversing up.
216-
if aggr, ok := (*current).(*Aggregation); ok {
217-
// If this aggregation preserves partition labels and there's a
218-
// distributive aggregation ancestor, continue up to let it handle distribution.
219-
// This enables patterns like:
220-
// - topk(10, sum by (P, instance) (X))
221-
// - sum(metric_a * group by (P) (metric_b))
222-
// - max(sum by (P, instance) (X))
223-
// where P is a partition label - we can push the entire expression
224-
// to remote engines.
225-
//
226-
// We need to check ancestors (not just immediate parent) because the
227-
// aggregation might be nested inside a binary expression that is itself
228-
// inside another aggregation: sum(A * group by (P) (B))
195+
if isAggregation(current) {
229196
if preservesPartitionLabels(*current, engineLabels) {
230-
if hasDistributiveAncestor(parents, current, m.SkipBinaryPushdown, engineLabels, warns) {
231-
return false
232-
}
233-
}
234-
localAggregation := aggr.Op
235-
if aggr.Op == parser.COUNT {
236-
localAggregation = parser.SUM
237-
}
238-
239-
remoteAggregation := newRemoteAggregation(aggr, engines)
240-
subQueries := m.distributeQuery(&remoteAggregation, engines, m.subqueryOpts(parents, current, opts), labelRanges)
241-
*current = &Aggregation{
242-
Op: localAggregation,
243-
Expr: subQueries,
244-
Param: aggr.Param,
245-
Grouping: aggr.Grouping,
246-
Without: aggr.Without,
197+
// Partition-preserving aggregation: push as-is since each engine
198+
// computes over disjoint partition values.
199+
*current = m.distributeQuery(current, engines, m.subqueryOpts(parents, current, opts), labelRanges)
200+
} else {
201+
// Distributive aggregation that drops partition labels: use a
202+
// two-level split with local_agg(remote_agg(X)).
203+
*current = m.distributeAggregation((*current).(*Aggregation), engines, m.subqueryOpts(parents, current, opts), labelRanges)
247204
}
248205
return true
249206
}
250207

251-
// If the parent operation is distributive or is an avg (which we handle specially),
252-
// continue the traversal.
253-
if isDistributive(parent, m.SkipBinaryPushdown, engineLabels, warns) || isAvgAggregation(parent) {
254-
return false
255-
}
256-
257208
*current = m.distributeQuery(current, engines, m.subqueryOpts(parents, current, opts), labelRanges)
258209
return true
259210
})
260211
return plan, *warns
261212
}
262213

214+
func (m DistributedExecutionOptimizer) distributeAggregation(aggr *Aggregation, engines []api.RemoteEngine, opts *query.Options, labelRanges labelSetRanges) Node {
215+
localAggregation := aggr.Op
216+
if aggr.Op == parser.COUNT {
217+
localAggregation = parser.SUM
218+
}
219+
remoteAggregation := newRemoteAggregation(aggr, engines)
220+
subQueries := m.distributeQuery(&remoteAggregation, engines, opts, labelRanges)
221+
return &Aggregation{
222+
Op: localAggregation,
223+
Expr: subQueries,
224+
Param: aggr.Param,
225+
Grouping: aggr.Grouping,
226+
Without: aggr.Without,
227+
}
228+
}
229+
230+
func computeParents(plan *Node) map[*Node]*Node {
231+
parents := make(map[*Node]*Node)
232+
TraverseBottomUp(nil, plan, func(parent, current *Node) (stop bool) {
233+
parents[current] = parent
234+
return false
235+
})
236+
return parents
237+
}
238+
239+
func (m DistributedExecutionOptimizer) computeDistributionPoints(plan *Node, parents map[*Node]*Node, engineLabels map[string]struct{}, warns *annotations.Annotations) map[*Node]struct{} {
240+
marks := make(map[*Node]struct{})
241+
242+
// First pass: mark distribution points (aggregations, absent functions).
243+
Traverse(plan, func(current *Node) {
244+
if isAbsent(current) {
245+
if m.isDistributive(current, engineLabels, warns) {
246+
marks[current] = struct{}{}
247+
}
248+
return
249+
}
250+
if isAggregation(current) {
251+
// Non-distributive aggregations that don't preserve partition labels
252+
// cannot be distributed, except for avg which gets rewritten as sum/count.
253+
if !m.isDistributive(current, engineLabels, warns) {
254+
if isAvgAggregation(current) {
255+
marks[current] = struct{}{}
256+
}
257+
return
258+
}
259+
// Distributive aggregations (standard or partition-preserving):
260+
// defer to ancestor if possible.
261+
if preservesPartitionLabels(*current, engineLabels) {
262+
if m.hasDistributiveAncestor(parents, current, engineLabels, warns) {
263+
return
264+
}
265+
}
266+
marks[current] = struct{}{}
267+
}
268+
})
269+
270+
// Second pass: for nodes whose siblings have marks, mark them too so both
271+
// sides of a binary expression get distributed.
272+
Traverse(plan, func(current *Node) {
273+
if _, ok := marks[current]; ok {
274+
return
275+
}
276+
if subtreeHasMark(current, marks) {
277+
return
278+
}
279+
if !m.isDistributive(current, engineLabels, warns) {
280+
return
281+
}
282+
parent := parents[current]
283+
if parent != nil && (m.isDistributive(parent, engineLabels, warns) || isAvgAggregation(parent)) {
284+
if !subtreeHasMark(parent, marks) {
285+
return
286+
}
287+
}
288+
marks[current] = struct{}{}
289+
})
290+
291+
return marks
292+
}
293+
294+
func subtreeHasMark(node *Node, marks map[*Node]struct{}) bool {
295+
for _, child := range (*node).Children() {
296+
if _, ok := marks[child]; ok {
297+
return true
298+
}
299+
if subtreeHasMark(child, marks) {
300+
return true
301+
}
302+
}
303+
return false
304+
}
305+
263306
func (m DistributedExecutionOptimizer) subqueryOpts(parents map[*Node]*Node, current *Node, opts *query.Options) *query.Options {
264307
subqueryParents := make([]*Subquery, 0, len(parents))
265308
for p := parents[current]; p != nil; p = parents[p] {
@@ -641,7 +684,7 @@ func preservesPartitionLabels(expr Node, partitionLabels map[string]struct{}) bo
641684
}
642685
}
643686

644-
func isDistributive(expr *Node, skipBinaryPushdown bool, engineLabels map[string]struct{}, warns *annotations.Annotations) bool {
687+
func (m DistributedExecutionOptimizer) isDistributive(expr *Node, engineLabels map[string]struct{}, warns *annotations.Annotations) bool {
645688
if expr == nil {
646689
return false
647690
}
@@ -653,13 +696,23 @@ func isDistributive(expr *Node, skipBinaryPushdown bool, engineLabels map[string
653696
if isBinaryExpressionWithOneScalarSide(e) {
654697
return true
655698
}
656-
return !skipBinaryPushdown &&
699+
return !m.SkipBinaryPushdown &&
657700
isBinaryExpressionWithDistributableMatching(e, engineLabels) &&
658-
isDistributive(&e.LHS, skipBinaryPushdown, engineLabels, warns) &&
659-
isDistributive(&e.RHS, skipBinaryPushdown, engineLabels, warns)
701+
m.isDistributive(&e.LHS, engineLabels, warns) &&
702+
m.isDistributive(&e.RHS, engineLabels, warns)
660703
case *Aggregation:
661-
// Certain aggregations are currently not supported.
662-
if _, ok := distributiveAggregations[e.Op]; !ok {
704+
switch e.Op {
705+
// Mathematically distributive: can be split into local_agg(remote_agg(X))
706+
// regardless of partition labels.
707+
case parser.SUM, parser.MIN, parser.MAX, parser.GROUP, parser.COUNT,
708+
parser.TOPK, parser.BOTTOMK, parser.LIMITK, parser.LIMIT_RATIO:
709+
// Non-distributive: can only be pushed as-is when they preserve
710+
// partition labels (each engine computes over disjoint data).
711+
case parser.AVG, parser.QUANTILE, parser.STDDEV, parser.STDVAR, parser.COUNT_VALUES:
712+
if !preservesPartitionLabels(e, engineLabels) {
713+
return false
714+
}
715+
default:
663716
return false
664717
}
665718
case *FunctionCall:
@@ -840,9 +893,9 @@ func matchesExternalLabels(ms []*labels.Matcher, externalLabels labels.Labels) b
840893
// parent chain from the current node that can handle distribution.
841894
// We must have an unbroken chain of distributive nodes to the ancestor for it to
842895
// be able to handle distribution on our behalf.
843-
func hasDistributiveAncestor(parents map[*Node]*Node, current *Node, skipBinaryPushdown bool, engineLabels map[string]struct{}, warns *annotations.Annotations) bool {
896+
func (m DistributedExecutionOptimizer) hasDistributiveAncestor(parents map[*Node]*Node, current *Node, engineLabels map[string]struct{}, warns *annotations.Annotations) bool {
844897
for p := parents[current]; p != nil; p = parents[p] {
845-
if !isDistributive(p, skipBinaryPushdown, engineLabels, warns) {
898+
if !m.isDistributive(p, engineLabels, warns) {
846899
// We hit a non-distributive node, so we can't push through it.
847900
// No ancestor can help us distribute.
848901
return false

logicalplan/distribute_test.go

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,13 @@ sum(
8383
remote(count by (pod, region) (metric_a)))))`,
8484
},
8585
{
86-
name: "avg in binary expression with outer sum distributes avg independently",
86+
name: "avg in binary expression with outer sum pushes entire expression",
8787
expr: `sum(metric_a * avg by (region) (metric_b))`,
8888
expected: `
8989
sum(
90-
dedup(remote(metric_a), remote(metric_a))
91-
*
92-
sum by (region) (dedup(
93-
remote(sum by (region) (metric_b)),
94-
remote(sum by (region) (metric_b))))
95-
/ on (region)
96-
sum by (region) (dedup(
97-
remote(count by (region) (metric_b)),
98-
remote(count by (region) (metric_b)))))`,
90+
dedup(
91+
remote(sum by (region) (metric_a * avg by (region) (metric_b))),
92+
remote(sum by (region) (metric_a * avg by (region) (metric_b)))))`,
9993
},
10094
{
10195
name: "selector",
@@ -171,21 +165,12 @@ sum by (pod) (
171165
)`,
172166
},
173167
{
174-
name: "avg with without-grouping",
168+
name: "avg with without-grouping preserving partition labels",
175169
expr: `avg without (pod) (http_requests_total)`,
176170
expected: `
177-
sum without (pod) (
178-
dedup(
179-
remote(sum without (pod) (http_requests_total)),
180-
remote(sum without (pod) (http_requests_total))
181-
)
182-
) / ignoring (pod)
183-
sum without (pod) (
184-
dedup(
185-
remote(count without (pod) (http_requests_total)),
186-
remote(count without (pod) (http_requests_total))
187-
)
188-
)`,
171+
dedup(
172+
remote(avg without (pod) (http_requests_total)),
173+
remote(avg without (pod) (http_requests_total)))`,
189174
},
190175
{
191176
name: "avg with prior aggregation",
@@ -218,6 +203,57 @@ sum by (pod) (
218203
)
219204
)`,
220205
},
206+
{
207+
name: "avg by partition label pushes as-is",
208+
expr: `avg by (region) (http_requests_total)`,
209+
expected: `
210+
dedup(
211+
remote(avg by (region) (http_requests_total)),
212+
remote(avg by (region) (http_requests_total)))`,
213+
},
214+
{
215+
name: "avg by partition label defers to distributive ancestor",
216+
expr: `max(avg by (region) (http_requests_total))`,
217+
expected: `
218+
max(
219+
dedup(
220+
remote(max by (region) (avg by (region) (http_requests_total))),
221+
remote(max by (region) (avg by (region) (http_requests_total)))))`,
222+
},
223+
{
224+
name: "avg over subquery with inner aggregations pushes entire expression",
225+
expr: `avg by (region) (quantile_over_time(0.9, (sum by (region) (rate(metric_a[2m])) / sum by (region) (metric_b))[1h:1m]))`,
226+
expected: `
227+
dedup(
228+
remote(avg by (region) (quantile_over_time(0.9, (sum by (region) (rate(metric_a[2m])) / sum by (region) (metric_b))[1h:1m]))),
229+
remote(avg by (region) (quantile_over_time(0.9, (sum by (region) (rate(metric_a[2m])) / sum by (region) (metric_b))[1h:1m]))))`,
230+
},
231+
{
232+
name: "quantile by partition label pushes as-is",
233+
expr: `quantile by (region) (0.9, http_requests_total)`,
234+
expected: `
235+
dedup(
236+
remote(quantile by (region) (0.9, http_requests_total)),
237+
remote(quantile by (region) (0.9, http_requests_total)))`,
238+
},
239+
{
240+
name: "quantile by non-partition label is not distributed",
241+
expr: `quantile by (pod) (0.9, http_requests_total)`,
242+
expected: `quantile by (pod) (0.9, dedup(remote(http_requests_total), remote(http_requests_total)))`,
243+
},
244+
{
245+
name: "stddev by partition label pushes as-is",
246+
expr: `stddev by (region) (http_requests_total)`,
247+
expected: `
248+
dedup(
249+
remote(stddev by (region) (http_requests_total)),
250+
remote(stddev by (region) (http_requests_total)))`,
251+
},
252+
{
253+
name: "stddev by non-partition label is not distributed",
254+
expr: `stddev by (pod) (http_requests_total)`,
255+
expected: `stddev by (pod) (dedup(remote(http_requests_total), remote(http_requests_total)))`,
256+
},
221257
{
222258
name: "two-level aggregation",
223259
expr: `max by (pod) (sum by (pod) (http_requests_total))`,
@@ -518,15 +554,22 @@ count by (cluster) (
518554
{
519555
name: "skip binary pushdown with nested aggregation",
520556
expr: `sum(metric_a * group by (region) (metric_b))`,
521-
expected: `sum(dedup(remote(metric_a), remote(metric_a)) * group by (region) (dedup(remote(group by (region) (metric_b)), remote(group by (region) (metric_b)))))`,
557+
expected: `sum(dedup(remote(metric_a), remote(metric_a)) * dedup(remote(group by (region) (metric_b)), remote(group by (region) (metric_b))))`,
522558
skipBinopPushdown: true,
523559
},
524560
{
525561
name: "skip binary pushdown with outer aggregation",
526562
expr: `max(metric_a + sum by (region, pod) (metric_b))`,
527-
expected: `max(dedup(remote(metric_a), remote(metric_a)) + sum by (region, pod) (dedup(remote(sum by (pod, region) (metric_b)), remote(sum by (pod, region) (metric_b)))))`,
563+
expected: `max(dedup(remote(metric_a), remote(metric_a)) + dedup(remote(sum by (region, pod) (metric_b)), remote(sum by (region, pod) (metric_b))))`,
528564
skipBinopPushdown: true,
529565
},
566+
{
567+
// When the RHS of unless has an aggregation that drops the partition label,
568+
// both sides should still be distributed independently.
569+
name: "unless with aggregation that drops partition label distributes both sides",
570+
expr: `group by (region, instance) (metric_a unless on (region, instance) max by (instance) (metric_b))`,
571+
expected: `group by (region, instance) (dedup(remote(metric_a), remote(metric_a)) unless on (region, instance) max by (instance) (dedup(remote(max by (instance, region) (metric_b)), remote(max by (instance, region) (metric_b)))))`,
572+
},
530573
{
531574
// group_left/group_right with partition label cannot be distributed because
532575
// match cardinality changes when each partition only sees one value for that label.

0 commit comments

Comments
 (0)