Skip to content

Commit 4ff281d

Browse files
authored
perf(optimizer): add count threshold comparisons (#898)
Optimize the following patterns: - count(arr, pred) > N - count(arr, pred) >= N - count(arr, pred) < N - count(arr, pred) <= N Add a threshold check inside the count loop. When the count reaches the threshold, the loop exits early instead of scanning the entire array. This is implemented via a new Threshold field on BuiltinNode that the optimizer sets when detecting these patterns. The compiler then emits bytecode that checks the count against the threshold after each increment and jumps out of the loop when reached. Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent ebbe1ba commit 4ff281d

File tree

5 files changed

+351
-0
lines changed

5 files changed

+351
-0
lines changed

ast/node.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ type BuiltinNode struct {
187187
Arguments []Node // Arguments of the builtin function.
188188
Throws bool // If true then accessing a field or array index can throw an error. Used by optimizer.
189189
Map Node // Used by optimizer to fold filter() and map() builtins.
190+
Threshold *int // Used by optimizer for count() early termination.
190191
}
191192

192193
// PredicateNode represents a predicate.

compiler/compiler.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,7 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
937937
c.compile(node.Arguments[0])
938938
c.derefInNeeded(node.Arguments[0])
939939
c.emit(OpBegin)
940+
var loopBreak int
940941
c.emitLoop(func() {
941942
if len(node.Arguments) == 2 {
942943
c.compile(node.Arguments[1])
@@ -945,9 +946,25 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
945946
}
946947
c.emitCond(func() {
947948
c.emit(OpIncrementCount)
949+
// Early termination if threshold is set
950+
if node.Threshold != nil {
951+
c.emit(OpGetCount)
952+
c.emit(OpInt, *node.Threshold)
953+
c.emit(OpMoreOrEqual)
954+
loopBreak = c.emit(OpJumpIfTrue, placeholder)
955+
c.emit(OpPop)
956+
}
948957
})
949958
})
950959
c.emit(OpGetCount)
960+
if node.Threshold != nil {
961+
end := c.emit(OpJump, placeholder)
962+
c.patchJump(loopBreak)
963+
// Early exit path: pop the bool comparison result, push count
964+
c.emit(OpPop)
965+
c.emit(OpGetCount)
966+
c.patchJump(end)
967+
}
951968
c.emit(OpEnd)
952969
return
953970

optimizer/count_threshold.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/expr-lang/expr/ast"
5+
)
6+
7+
// countThreshold optimizes count comparisons by setting a threshold for early termination.
8+
// The threshold allows the count loop to exit early once enough matches are found.
9+
// Patterns:
10+
// - count(arr, pred) > N → threshold = N + 1 (exit proves > N is true)
11+
// - count(arr, pred) >= N → threshold = N (exit proves >= N is true)
12+
// - count(arr, pred) < N → threshold = N (exit proves < N is false)
13+
// - count(arr, pred) <= N → threshold = N + 1 (exit proves <= N is false)
14+
type countThreshold struct{}
15+
16+
func (*countThreshold) Visit(node *Node) {
17+
binary, ok := (*node).(*BinaryNode)
18+
if !ok {
19+
return
20+
}
21+
22+
count, ok := binary.Left.(*BuiltinNode)
23+
if !ok || count.Name != "count" || len(count.Arguments) != 2 {
24+
return
25+
}
26+
27+
integer, ok := binary.Right.(*IntegerNode)
28+
if !ok || integer.Value < 0 {
29+
return
30+
}
31+
32+
var threshold int
33+
switch binary.Operator {
34+
case ">":
35+
threshold = integer.Value + 1
36+
case ">=":
37+
threshold = integer.Value
38+
case "<":
39+
threshold = integer.Value
40+
case "<=":
41+
threshold = integer.Value + 1
42+
default:
43+
return
44+
}
45+
46+
// Skip if threshold is 0 or 1 (handled by count_any optimizer)
47+
if threshold <= 1 {
48+
return
49+
}
50+
51+
// Set threshold on the count node for early termination
52+
// The original comparison remains unchanged
53+
count.Threshold = &threshold
54+
}

optimizer/count_threshold_test.go

Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
package optimizer_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
. "github.com/expr-lang/expr/ast"
8+
"github.com/expr-lang/expr/internal/testify/assert"
9+
"github.com/expr-lang/expr/internal/testify/require"
10+
"github.com/expr-lang/expr/optimizer"
11+
"github.com/expr-lang/expr/parser"
12+
"github.com/expr-lang/expr/vm"
13+
)
14+
15+
func TestOptimize_count_threshold_gt(t *testing.T) {
16+
tree, err := parser.Parse(`count(items, .active) > 100`)
17+
require.NoError(t, err)
18+
19+
err = optimizer.Optimize(&tree.Node, nil)
20+
require.NoError(t, err)
21+
22+
// Operator should remain >, but count should have threshold set
23+
binary, ok := tree.Node.(*BinaryNode)
24+
require.True(t, ok, "expected BinaryNode, got %T", tree.Node)
25+
assert.Equal(t, ">", binary.Operator)
26+
27+
count, ok := binary.Left.(*BuiltinNode)
28+
require.True(t, ok, "expected BuiltinNode, got %T", binary.Left)
29+
assert.Equal(t, "count", count.Name)
30+
require.NotNil(t, count.Threshold)
31+
assert.Equal(t, 101, *count.Threshold) // threshold = N + 1 for > operator
32+
}
33+
34+
func TestOptimize_count_threshold_gte(t *testing.T) {
35+
tree, err := parser.Parse(`count(items, .active) >= 50`)
36+
require.NoError(t, err)
37+
38+
err = optimizer.Optimize(&tree.Node, nil)
39+
require.NoError(t, err)
40+
41+
// Operator should remain >=, but count should have threshold set
42+
binary, ok := tree.Node.(*BinaryNode)
43+
require.True(t, ok, "expected BinaryNode, got %T", tree.Node)
44+
assert.Equal(t, ">=", binary.Operator)
45+
46+
count, ok := binary.Left.(*BuiltinNode)
47+
require.True(t, ok, "expected BuiltinNode, got %T", binary.Left)
48+
assert.Equal(t, "count", count.Name)
49+
require.NotNil(t, count.Threshold)
50+
assert.Equal(t, 50, *count.Threshold) // threshold = N for >= operator
51+
}
52+
53+
func TestOptimize_count_threshold_lt(t *testing.T) {
54+
tree, err := parser.Parse(`count(items, .active) < 100`)
55+
require.NoError(t, err)
56+
57+
err = optimizer.Optimize(&tree.Node, nil)
58+
require.NoError(t, err)
59+
60+
// Operator should remain <, but count should have threshold set
61+
binary, ok := tree.Node.(*BinaryNode)
62+
require.True(t, ok, "expected BinaryNode, got %T", tree.Node)
63+
assert.Equal(t, "<", binary.Operator)
64+
65+
count, ok := binary.Left.(*BuiltinNode)
66+
require.True(t, ok, "expected BuiltinNode, got %T", binary.Left)
67+
assert.Equal(t, "count", count.Name)
68+
require.NotNil(t, count.Threshold)
69+
assert.Equal(t, 100, *count.Threshold) // threshold = N for < operator
70+
}
71+
72+
func TestOptimize_count_threshold_lte(t *testing.T) {
73+
tree, err := parser.Parse(`count(items, .active) <= 50`)
74+
require.NoError(t, err)
75+
76+
err = optimizer.Optimize(&tree.Node, nil)
77+
require.NoError(t, err)
78+
79+
// Operator should remain <=, but count should have threshold set
80+
binary, ok := tree.Node.(*BinaryNode)
81+
require.True(t, ok, "expected BinaryNode, got %T", tree.Node)
82+
assert.Equal(t, "<=", binary.Operator)
83+
84+
count, ok := binary.Left.(*BuiltinNode)
85+
require.True(t, ok, "expected BuiltinNode, got %T", binary.Left)
86+
assert.Equal(t, "count", count.Name)
87+
require.NotNil(t, count.Threshold)
88+
assert.Equal(t, 51, *count.Threshold) // threshold = N + 1 for <= operator
89+
}
90+
91+
func TestOptimize_count_threshold_correctness(t *testing.T) {
92+
tests := []struct {
93+
expr string
94+
want bool
95+
}{
96+
// count > N (threshold = N + 1)
97+
{`count(1..1000, # <= 100) > 50`, true}, // 100 matches > 50
98+
{`count(1..1000, # <= 100) > 100`, false}, // 100 matches not > 100
99+
{`count(1..1000, # <= 100) > 99`, true}, // 100 matches > 99
100+
{`count(1..100, # > 0) > 50`, true}, // 100 matches > 50
101+
{`count(1..100, # > 0) > 100`, false}, // 100 matches not > 100
102+
103+
// count >= N (threshold = N)
104+
{`count(1..1000, # <= 100) >= 100`, true}, // 100 matches >= 100
105+
{`count(1..1000, # <= 100) >= 101`, false}, // 100 matches not >= 101
106+
{`count(1..100, # > 0) >= 50`, true}, // 100 matches >= 50
107+
{`count(1..100, # > 0) >= 100`, true}, // 100 matches >= 100
108+
109+
// count < N (threshold = N)
110+
{`count(1..1000, # <= 100) < 101`, true}, // 100 matches < 101
111+
{`count(1..1000, # <= 100) < 100`, false}, // 100 matches not < 100
112+
{`count(1..1000, # <= 100) < 50`, false}, // 100 matches not < 50
113+
{`count(1..100, # > 0) < 101`, true}, // 100 matches < 101
114+
{`count(1..100, # > 0) < 100`, false}, // 100 matches not < 100
115+
116+
// count <= N (threshold = N + 1)
117+
{`count(1..1000, # <= 100) <= 100`, true}, // 100 matches <= 100
118+
{`count(1..1000, # <= 100) <= 99`, false}, // 100 matches not <= 99
119+
{`count(1..1000, # <= 100) <= 50`, false}, // 100 matches not <= 50
120+
{`count(1..100, # > 0) <= 100`, true}, // 100 matches <= 100
121+
{`count(1..100, # > 0) <= 99`, false}, // 100 matches not <= 99
122+
}
123+
124+
for _, tt := range tests {
125+
t.Run(tt.expr, func(t *testing.T) {
126+
program, err := expr.Compile(tt.expr)
127+
require.NoError(t, err)
128+
129+
output, err := expr.Run(program, nil)
130+
require.NoError(t, err)
131+
assert.Equal(t, tt.want, output)
132+
})
133+
}
134+
}
135+
136+
func TestOptimize_count_threshold_no_optimization(t *testing.T) {
137+
// These should NOT get a threshold (handled by count_any or not optimizable)
138+
tests := []struct {
139+
code string
140+
threshold bool
141+
}{
142+
{`count(items, .active) > 0`, false}, // handled by count_any
143+
{`count(items, .active) >= 1`, false}, // handled by count_any
144+
{`count(items, .active) < 1`, false}, // threshold = 1, skipped
145+
{`count(items, .active) <= 0`, false}, // threshold = 1, skipped
146+
{`count(items, .active) == 10`, false}, // not supported
147+
}
148+
149+
for _, tt := range tests {
150+
t.Run(tt.code, func(t *testing.T) {
151+
tree, err := parser.Parse(tt.code)
152+
require.NoError(t, err)
153+
154+
err = optimizer.Optimize(&tree.Node, nil)
155+
require.NoError(t, err)
156+
157+
// Check if count has threshold set
158+
var count *BuiltinNode
159+
if binary, ok := tree.Node.(*BinaryNode); ok {
160+
count, _ = binary.Left.(*BuiltinNode)
161+
} else if builtin, ok := tree.Node.(*BuiltinNode); ok {
162+
count = builtin
163+
}
164+
165+
if count != nil && count.Name == "count" {
166+
if tt.threshold {
167+
assert.NotNil(t, count.Threshold, "expected threshold to be set")
168+
} else {
169+
assert.Nil(t, count.Threshold, "expected threshold to be nil")
170+
}
171+
}
172+
})
173+
}
174+
}
175+
176+
// Benchmark: count > 100 with early match (element 101 matches early)
177+
func BenchmarkCountThresholdEarlyMatch(b *testing.B) {
178+
// Array of 10000 elements, all match predicate, threshold is 101
179+
// Should exit after ~101 iterations
180+
program, _ := expr.Compile(`count(1..10000, # > 0) > 100`)
181+
var out any
182+
b.ResetTimer()
183+
for n := 0; n < b.N; n++ {
184+
out, _ = vm.Run(program, nil)
185+
}
186+
_ = out
187+
}
188+
189+
// Benchmark: count >= 50 with early match
190+
func BenchmarkCountThresholdGteEarlyMatch(b *testing.B) {
191+
// All elements match, threshold is 50
192+
// Should exit after ~50 iterations
193+
program, _ := expr.Compile(`count(1..10000, # > 0) >= 50`)
194+
var out any
195+
b.ResetTimer()
196+
for n := 0; n < b.N; n++ {
197+
out, _ = vm.Run(program, nil)
198+
}
199+
_ = out
200+
}
201+
202+
// Benchmark: count > 100 with no early exit (not enough matches)
203+
func BenchmarkCountThresholdNoEarlyExit(b *testing.B) {
204+
// Only 100 elements match (# <= 100), threshold is 101
205+
// Must scan entire array
206+
program, _ := expr.Compile(`count(1..10000, # <= 100) > 100`)
207+
var out any
208+
b.ResetTimer()
209+
for n := 0; n < b.N; n++ {
210+
out, _ = vm.Run(program, nil)
211+
}
212+
_ = out
213+
}
214+
215+
// Benchmark: Large threshold with early match
216+
func BenchmarkCountThresholdLargeEarlyMatch(b *testing.B) {
217+
// All 10000 match, threshold is 1000
218+
// Should exit after ~1000 iterations
219+
program, _ := expr.Compile(`count(1..10000, # > 0) > 999`)
220+
var out any
221+
b.ResetTimer()
222+
for n := 0; n < b.N; n++ {
223+
out, _ = vm.Run(program, nil)
224+
}
225+
_ = out
226+
}
227+
228+
// Benchmark: count < N with early exit (result is false)
229+
func BenchmarkCountThresholdLtEarlyExit(b *testing.B) {
230+
// All 10000 match, threshold is 100
231+
// Should exit after ~100 iterations with result = false
232+
program, _ := expr.Compile(`count(1..10000, # > 0) < 100`)
233+
var out any
234+
b.ResetTimer()
235+
for n := 0; n < b.N; n++ {
236+
out, _ = vm.Run(program, nil)
237+
}
238+
_ = out
239+
}
240+
241+
// Benchmark: count <= N with early exit (result is false)
242+
func BenchmarkCountThresholdLteEarlyExit(b *testing.B) {
243+
// All 10000 match, threshold is 51
244+
// Should exit after ~51 iterations with result = false
245+
program, _ := expr.Compile(`count(1..10000, # > 0) <= 50`)
246+
var out any
247+
b.ResetTimer()
248+
for n := 0; n < b.N; n++ {
249+
out, _ = vm.Run(program, nil)
250+
}
251+
_ = out
252+
}
253+
254+
// Benchmark: count < N without early exit (result is true)
255+
func BenchmarkCountThresholdLtNoEarlyExit(b *testing.B) {
256+
// Only 100 elements match (# <= 100), threshold is 200
257+
// Must scan entire array, result = true
258+
program, _ := expr.Compile(`count(1..10000, # <= 100) < 200`)
259+
var out any
260+
b.ResetTimer()
261+
for n := 0; n < b.N; n++ {
262+
out, _ = vm.Run(program, nil)
263+
}
264+
_ = out
265+
}
266+
267+
// Benchmark: count <= N without early exit (result is true)
268+
func BenchmarkCountThresholdLteNoEarlyExit(b *testing.B) {
269+
// Only 100 elements match (# <= 100), threshold is 101
270+
// Must scan entire array, result = true
271+
program, _ := expr.Compile(`count(1..10000, # <= 100) <= 100`)
272+
var out any
273+
b.ResetTimer()
274+
for n := 0; n < b.N; n++ {
275+
out, _ = vm.Run(program, nil)
276+
}
277+
_ = out
278+
}

optimizer/optimizer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ func Optimize(node *Node, config *conf.Config) error {
4444
Walk(node, &sumArray{})
4545
Walk(node, &sumMap{})
4646
Walk(node, &countAny{})
47+
Walk(node, &countThreshold{})
4748
return nil
4849
}
4950

0 commit comments

Comments
 (0)