Skip to content

Commit 0fcde75

Browse files
authored
perf(optimizer): add sum range optimization (#896)
Optimize sum(m..n) and reduce(m..n, # + #acc) with constant integer bounds to compile-time constants using the arithmetic series formula. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent 3111bbe commit 0fcde75

File tree

3 files changed

+462
-0
lines changed

3 files changed

+462
-0
lines changed

optimizer/optimizer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ func Optimize(node *Node, config *conf.Config) error {
4040
Walk(node, &filterLast{})
4141
Walk(node, &filterFirst{})
4242
Walk(node, &predicateCombination{})
43+
Walk(node, &sumRange{})
4344
Walk(node, &sumArray{})
4445
Walk(node, &sumMap{})
4546
return nil

optimizer/sum_range.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/expr-lang/expr/ast"
5+
)
6+
7+
type sumRange struct{}
8+
9+
func (*sumRange) Visit(node *Node) {
10+
// Pattern 1: sum(m..n) or sum(m..n, predicate) where m and n are constant integers
11+
if sumBuiltin, ok := (*node).(*BuiltinNode); ok &&
12+
sumBuiltin.Name == "sum" &&
13+
(len(sumBuiltin.Arguments) == 1 || len(sumBuiltin.Arguments) == 2) {
14+
if rangeOp, ok := sumBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." {
15+
if from, ok := rangeOp.Left.(*IntegerNode); ok {
16+
if to, ok := rangeOp.Right.(*IntegerNode); ok {
17+
m := from.Value
18+
n := to.Value
19+
if n >= m {
20+
count := n - m + 1
21+
// Use the arithmetic series formula: (n - m + 1) * (m + n) / 2
22+
sum := count * (m + n) / 2
23+
24+
if len(sumBuiltin.Arguments) == 1 {
25+
// sum(m..n)
26+
patchWithType(node, &IntegerNode{Value: sum})
27+
} else if len(sumBuiltin.Arguments) == 2 {
28+
// sum(m..n, predicate)
29+
if result, ok := applySumPredicate(sum, count, sumBuiltin.Arguments[1]); ok {
30+
patchWithType(node, &IntegerNode{Value: result})
31+
}
32+
}
33+
}
34+
}
35+
}
36+
}
37+
}
38+
39+
// Pattern 2: reduce(m..n, # + #acc) where m and n are constant integers
40+
if reduceBuiltin, ok := (*node).(*BuiltinNode); ok &&
41+
reduceBuiltin.Name == "reduce" &&
42+
(len(reduceBuiltin.Arguments) == 2 || len(reduceBuiltin.Arguments) == 3) {
43+
if rangeOp, ok := reduceBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." {
44+
if from, ok := rangeOp.Left.(*IntegerNode); ok {
45+
if to, ok := rangeOp.Right.(*IntegerNode); ok {
46+
if isPointerPlusAcc(reduceBuiltin.Arguments[1]) {
47+
m := from.Value
48+
n := to.Value
49+
if n >= m {
50+
// Use the arithmetic series formula: (n - m + 1) * (m + n) / 2
51+
sum := (n - m + 1) * (m + n) / 2
52+
53+
// Check for optional initialValue (3rd argument)
54+
if len(reduceBuiltin.Arguments) == 3 {
55+
if initialValue, ok := reduceBuiltin.Arguments[2].(*IntegerNode); ok {
56+
result := initialValue.Value + sum
57+
patchWithType(node, &IntegerNode{Value: result})
58+
}
59+
} else {
60+
patchWithType(node, &IntegerNode{Value: sum})
61+
}
62+
}
63+
}
64+
}
65+
}
66+
}
67+
}
68+
}
69+
70+
// isPointerPlusAcc checks if the node represents `# + #acc` pattern
71+
func isPointerPlusAcc(node Node) bool {
72+
predicate, ok := node.(*PredicateNode)
73+
if !ok {
74+
return false
75+
}
76+
77+
binary, ok := predicate.Node.(*BinaryNode)
78+
if !ok {
79+
return false
80+
}
81+
82+
if binary.Operator != "+" {
83+
return false
84+
}
85+
86+
// Check for # + #acc (pointer + accumulator)
87+
leftPointer, leftIsPointer := binary.Left.(*PointerNode)
88+
rightPointer, rightIsPointer := binary.Right.(*PointerNode)
89+
90+
if leftIsPointer && rightIsPointer {
91+
// # + #acc: Left is pointer (Name=""), Right is acc (Name="acc")
92+
if leftPointer.Name == "" && rightPointer.Name == "acc" {
93+
return true
94+
}
95+
// #acc + #: Left is acc (Name="acc"), Right is pointer (Name="")
96+
if leftPointer.Name == "acc" && rightPointer.Name == "" {
97+
return true
98+
}
99+
}
100+
101+
return false
102+
}
103+
104+
// applySumPredicate tries to compute the result of sum(m..n, predicate) at compile time.
105+
// Returns (result, true) if optimization is possible, (0, false) otherwise.
106+
// Supported predicates:
107+
// - # (identity): result = sum
108+
// - # * k (multiply by constant): result = k * sum
109+
// - k * # (multiply by constant): result = k * sum
110+
// - # + k (add constant): result = sum + count * k
111+
// - k + # (add constant): result = sum + count * k
112+
// - # - k (subtract constant): result = sum - count * k
113+
func applySumPredicate(sum, count int, predicateArg Node) (int, bool) {
114+
predicate, ok := predicateArg.(*PredicateNode)
115+
if !ok {
116+
return 0, false
117+
}
118+
119+
// Case 1: # (identity) - just return the sum
120+
if pointer, ok := predicate.Node.(*PointerNode); ok && pointer.Name == "" {
121+
return sum, true
122+
}
123+
124+
// Case 2: Binary operations with pointer and constant
125+
binary, ok := predicate.Node.(*BinaryNode)
126+
if !ok {
127+
return 0, false
128+
}
129+
130+
pointer, constant, pointerOnLeft := extractPointerAndConstantWithPosition(binary)
131+
if pointer == nil || constant == nil {
132+
return 0, false
133+
}
134+
135+
switch binary.Operator {
136+
case "*":
137+
// # * k or k * # => k * sum
138+
return constant.Value * sum, true
139+
case "+":
140+
// # + k or k + # => sum + count * k
141+
return sum + count*constant.Value, true
142+
case "-":
143+
if pointerOnLeft {
144+
// # - k => sum - count * k
145+
return sum - count*constant.Value, true
146+
}
147+
// k - # => count * k - sum
148+
return count*constant.Value - sum, true
149+
}
150+
151+
return 0, false
152+
}
153+
154+
// extractPointerAndConstantWithPosition extracts pointer (#) and integer constant from a binary node.
155+
// Returns (pointer, constant, pointerOnLeft) or (nil, nil, false) if not matching the expected pattern.
156+
func extractPointerAndConstantWithPosition(binary *BinaryNode) (*PointerNode, *IntegerNode, bool) {
157+
// Try left=pointer, right=constant
158+
if pointer, ok := binary.Left.(*PointerNode); ok && pointer.Name == "" {
159+
if constant, ok := binary.Right.(*IntegerNode); ok {
160+
return pointer, constant, true
161+
}
162+
}
163+
164+
// Try left=constant, right=pointer
165+
if constant, ok := binary.Left.(*IntegerNode); ok {
166+
if pointer, ok := binary.Right.(*PointerNode); ok && pointer.Name == "" {
167+
return pointer, constant, false
168+
}
169+
}
170+
171+
return nil, nil, false
172+
}

0 commit comments

Comments
 (0)