Skip to content

Commit 13c5b65

Browse files
authored
perf(optimizer): count comparisons to any (#897)
Optimize count(...) comparisons to use the any/none builtins, which support early exit on the first matching element:
- count(arr, pred) > 0 → any(arr, pred)
- count(arr, pred) >= 1 → any(arr, pred)

Signed-off-by: Ville Vesilehto <ville@vesilehto.fi>
1 parent 0fcde75 commit 13c5b65

3 files changed

Lines changed: 203 additions & 0 deletions

File tree

optimizer/count_any.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/expr-lang/expr/ast"
5+
)
6+
7+
// countAny optimizes count comparisons to use any for early termination.
8+
// Patterns:
9+
// - count(arr, pred) > 0 → any(arr, pred)
10+
// - count(arr, pred) >= 1 → any(arr, pred)
11+
type countAny struct{}
12+
13+
func (*countAny) Visit(node *Node) {
14+
binary, ok := (*node).(*BinaryNode)
15+
if !ok {
16+
return
17+
}
18+
19+
count, ok := binary.Left.(*BuiltinNode)
20+
if !ok || count.Name != "count" || len(count.Arguments) != 2 {
21+
return
22+
}
23+
24+
integer, ok := binary.Right.(*IntegerNode)
25+
if !ok {
26+
return
27+
}
28+
29+
if (binary.Operator == ">" && integer.Value == 0) ||
30+
(binary.Operator == ">=" && integer.Value == 1) {
31+
patchCopyType(node, &BuiltinNode{
32+
Name: "any",
33+
Arguments: count.Arguments,
34+
})
35+
}
36+
}

optimizer/count_any_test.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package optimizer_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
. "github.com/expr-lang/expr/ast"
8+
"github.com/expr-lang/expr/internal/testify/assert"
9+
"github.com/expr-lang/expr/internal/testify/require"
10+
"github.com/expr-lang/expr/optimizer"
11+
"github.com/expr-lang/expr/parser"
12+
"github.com/expr-lang/expr/vm"
13+
)
14+
15+
func TestOptimize_count_any(t *testing.T) {
16+
tree, err := parser.Parse(`count(items, .active) > 0`)
17+
require.NoError(t, err)
18+
19+
err = optimizer.Optimize(&tree.Node, nil)
20+
require.NoError(t, err)
21+
22+
expected := &BuiltinNode{
23+
Name: "any",
24+
Arguments: []Node{
25+
&IdentifierNode{Value: "items"},
26+
&PredicateNode{
27+
Node: &MemberNode{
28+
Node: &PointerNode{},
29+
Property: &StringNode{Value: "active"},
30+
},
31+
},
32+
},
33+
}
34+
35+
assert.Equal(t, Dump(expected), Dump(tree.Node))
36+
}
37+
38+
func TestOptimize_count_any_gte_one(t *testing.T) {
39+
tree, err := parser.Parse(`count(items, .valid) >= 1`)
40+
require.NoError(t, err)
41+
42+
err = optimizer.Optimize(&tree.Node, nil)
43+
require.NoError(t, err)
44+
45+
expected := &BuiltinNode{
46+
Name: "any",
47+
Arguments: []Node{
48+
&IdentifierNode{Value: "items"},
49+
&PredicateNode{
50+
Node: &MemberNode{
51+
Node: &PointerNode{},
52+
Property: &StringNode{Value: "valid"},
53+
},
54+
},
55+
},
56+
}
57+
58+
assert.Equal(t, Dump(expected), Dump(tree.Node))
59+
}
60+
61+
func TestOptimize_count_any_correctness(t *testing.T) {
62+
tests := []struct {
63+
expr string
64+
want bool
65+
}{
66+
// count > 0 → any
67+
{`count(1..100, # == 1) > 0`, true},
68+
{`count(1..100, # == 50) > 0`, true},
69+
{`count(1..100, # == 100) > 0`, true},
70+
{`count(1..100, # == 0) > 0`, false},
71+
72+
// count >= 1 → any
73+
{`count(1..100, # % 10 == 0) >= 1`, true},
74+
{`count(1..100, # > 100) >= 1`, false},
75+
}
76+
77+
for _, tt := range tests {
78+
t.Run(tt.expr, func(t *testing.T) {
79+
program, err := expr.Compile(tt.expr)
80+
require.NoError(t, err)
81+
82+
output, err := expr.Run(program, nil)
83+
require.NoError(t, err)
84+
assert.Equal(t, tt.want, output)
85+
})
86+
}
87+
}
88+
89+
func TestOptimize_count_no_optimization(t *testing.T) {
90+
// These should NOT be optimized
91+
tests := []string{
92+
`count(items, .active) > 1`, // not > 0
93+
`count(items, .active) >= 2`, // not >= 1
94+
`count(items, .active) == 0`, // not optimized (none has overhead)
95+
`count(items, .active) == 1`, // not == 0
96+
`count(items, .active) < 1`, // not optimized (none has overhead)
97+
`count(items, .active) <= 0`, // not optimized (none has overhead)
98+
`count(items, .active) != 0`, // different operator
99+
}
100+
101+
for _, code := range tests {
102+
t.Run(code, func(t *testing.T) {
103+
tree, err := parser.Parse(code)
104+
require.NoError(t, err)
105+
106+
err = optimizer.Optimize(&tree.Node, nil)
107+
require.NoError(t, err)
108+
109+
// Should still be a BinaryNode (not optimized to any)
110+
_, ok := tree.Node.(*BinaryNode)
111+
assert.True(t, ok, "expected BinaryNode, got %T", tree.Node)
112+
})
113+
}
114+
}
115+
116+
// Benchmarks for count > 0 → any
117+
func BenchmarkCountGtZero(b *testing.B) {
118+
program, _ := expr.Compile(`count(1..1000, # == 1) > 0`)
119+
var out any
120+
b.ResetTimer()
121+
for n := 0; n < b.N; n++ {
122+
out, _ = vm.Run(program, nil)
123+
}
124+
_ = out
125+
}
126+
127+
func BenchmarkCountGtZeroLargeEarlyMatch(b *testing.B) {
128+
program, _ := expr.Compile(`count(1..10000, # == 1) > 0`)
129+
var out any
130+
b.ResetTimer()
131+
for n := 0; n < b.N; n++ {
132+
out, _ = vm.Run(program, nil)
133+
}
134+
_ = out
135+
}
136+
137+
func BenchmarkCountGtZeroNoMatch(b *testing.B) {
138+
program, _ := expr.Compile(`count(1..1000, # == 0) > 0`)
139+
var out any
140+
b.ResetTimer()
141+
for n := 0; n < b.N; n++ {
142+
out, _ = vm.Run(program, nil)
143+
}
144+
_ = out
145+
}
146+
147+
// Benchmarks for count >= 1 → any
148+
func BenchmarkCountGteOneEarlyMatch(b *testing.B) {
149+
program, _ := expr.Compile(`count(1..1000, # == 1) >= 1`)
150+
var out any
151+
b.ResetTimer()
152+
for n := 0; n < b.N; n++ {
153+
out, _ = vm.Run(program, nil)
154+
}
155+
_ = out
156+
}
157+
158+
func BenchmarkCountGteOneNoMatch(b *testing.B) {
159+
program, _ := expr.Compile(`count(1..1000, # == 0) >= 1`)
160+
var out any
161+
b.ResetTimer()
162+
for n := 0; n < b.N; n++ {
163+
out, _ = vm.Run(program, nil)
164+
}
165+
_ = out
166+
}

optimizer/optimizer.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ func Optimize(node *Node, config *conf.Config) error {
4343
Walk(node, &sumRange{})
4444
Walk(node, &sumArray{})
4545
Walk(node, &sumMap{})
46+
Walk(node, &countAny{})
4647
return nil
4748
}
4849

0 commit comments

Comments
 (0)