Skip to content

Commit d2cea00

Browse files
authored
Merge branch 'master' into feat/bytes-node
2 parents e0c9169 + de0da09 commit d2cea00

9 files changed

Lines changed: 22000 additions & 5 deletions

File tree

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
<h1><a href="https://expr-lang.org"><img src="https://expr-lang.org/img/logo.png" alt="Zx logo" height="48"align="right"></a> Expr</h1>
22

3-
> [!IMPORTANT]
4-
> The repository [github.com/antonmedv/expr](https://github.com/antonmedv/expr) moved to [github.com/**expr-lang**/expr](https://github.com/expr-lang/expr).
5-
63
[![test](https://github.com/expr-lang/expr/actions/workflows/test.yml/badge.svg)](https://github.com/expr-lang/expr/actions/workflows/test.yml)
74
[![Go Report Card](https://goreportcard.com/badge/github.com/expr-lang/expr)](https://goreportcard.com/report/github.com/expr-lang/expr)
85
[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/expr.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:expr)

optimizer/count_any.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/expr-lang/expr/ast"
5+
)
6+
7+
// countAny optimizes count comparisons to use any for early termination.
8+
// Patterns:
9+
// - count(arr, pred) > 0 → any(arr, pred)
10+
// - count(arr, pred) >= 1 → any(arr, pred)
11+
type countAny struct{}
12+
13+
func (*countAny) Visit(node *Node) {
14+
binary, ok := (*node).(*BinaryNode)
15+
if !ok {
16+
return
17+
}
18+
19+
count, ok := binary.Left.(*BuiltinNode)
20+
if !ok || count.Name != "count" || len(count.Arguments) != 2 {
21+
return
22+
}
23+
24+
integer, ok := binary.Right.(*IntegerNode)
25+
if !ok {
26+
return
27+
}
28+
29+
if (binary.Operator == ">" && integer.Value == 0) ||
30+
(binary.Operator == ">=" && integer.Value == 1) {
31+
patchCopyType(node, &BuiltinNode{
32+
Name: "any",
33+
Arguments: count.Arguments,
34+
})
35+
}
36+
}

optimizer/count_any_test.go

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package optimizer_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/expr-lang/expr"
7+
. "github.com/expr-lang/expr/ast"
8+
"github.com/expr-lang/expr/internal/testify/assert"
9+
"github.com/expr-lang/expr/internal/testify/require"
10+
"github.com/expr-lang/expr/optimizer"
11+
"github.com/expr-lang/expr/parser"
12+
"github.com/expr-lang/expr/vm"
13+
)
14+
15+
func TestOptimize_count_any(t *testing.T) {
16+
tree, err := parser.Parse(`count(items, .active) > 0`)
17+
require.NoError(t, err)
18+
19+
err = optimizer.Optimize(&tree.Node, nil)
20+
require.NoError(t, err)
21+
22+
expected := &BuiltinNode{
23+
Name: "any",
24+
Arguments: []Node{
25+
&IdentifierNode{Value: "items"},
26+
&PredicateNode{
27+
Node: &MemberNode{
28+
Node: &PointerNode{},
29+
Property: &StringNode{Value: "active"},
30+
},
31+
},
32+
},
33+
}
34+
35+
assert.Equal(t, Dump(expected), Dump(tree.Node))
36+
}
37+
38+
func TestOptimize_count_any_gte_one(t *testing.T) {
39+
tree, err := parser.Parse(`count(items, .valid) >= 1`)
40+
require.NoError(t, err)
41+
42+
err = optimizer.Optimize(&tree.Node, nil)
43+
require.NoError(t, err)
44+
45+
expected := &BuiltinNode{
46+
Name: "any",
47+
Arguments: []Node{
48+
&IdentifierNode{Value: "items"},
49+
&PredicateNode{
50+
Node: &MemberNode{
51+
Node: &PointerNode{},
52+
Property: &StringNode{Value: "valid"},
53+
},
54+
},
55+
},
56+
}
57+
58+
assert.Equal(t, Dump(expected), Dump(tree.Node))
59+
}
60+
61+
func TestOptimize_count_any_correctness(t *testing.T) {
62+
tests := []struct {
63+
expr string
64+
want bool
65+
}{
66+
// count > 0 → any
67+
{`count(1..100, # == 1) > 0`, true},
68+
{`count(1..100, # == 50) > 0`, true},
69+
{`count(1..100, # == 100) > 0`, true},
70+
{`count(1..100, # == 0) > 0`, false},
71+
72+
// count >= 1 → any
73+
{`count(1..100, # % 10 == 0) >= 1`, true},
74+
{`count(1..100, # > 100) >= 1`, false},
75+
}
76+
77+
for _, tt := range tests {
78+
t.Run(tt.expr, func(t *testing.T) {
79+
program, err := expr.Compile(tt.expr)
80+
require.NoError(t, err)
81+
82+
output, err := expr.Run(program, nil)
83+
require.NoError(t, err)
84+
assert.Equal(t, tt.want, output)
85+
})
86+
}
87+
}
88+
89+
func TestOptimize_count_no_optimization(t *testing.T) {
90+
// These should NOT be optimized
91+
tests := []string{
92+
`count(items, .active) > 1`, // not > 0
93+
`count(items, .active) >= 2`, // not >= 1
94+
`count(items, .active) == 0`, // not optimized (none has overhead)
95+
`count(items, .active) == 1`, // not == 0
96+
`count(items, .active) < 1`, // not optimized (none has overhead)
97+
`count(items, .active) <= 0`, // not optimized (none has overhead)
98+
`count(items, .active) != 0`, // different operator
99+
}
100+
101+
for _, code := range tests {
102+
t.Run(code, func(t *testing.T) {
103+
tree, err := parser.Parse(code)
104+
require.NoError(t, err)
105+
106+
err = optimizer.Optimize(&tree.Node, nil)
107+
require.NoError(t, err)
108+
109+
// Should still be a BinaryNode (not optimized to any)
110+
_, ok := tree.Node.(*BinaryNode)
111+
assert.True(t, ok, "expected BinaryNode, got %T", tree.Node)
112+
})
113+
}
114+
}
115+
116+
// Benchmarks for count > 0 → any
117+
func BenchmarkCountGtZero(b *testing.B) {
118+
program, _ := expr.Compile(`count(1..1000, # == 1) > 0`)
119+
var out any
120+
b.ResetTimer()
121+
for n := 0; n < b.N; n++ {
122+
out, _ = vm.Run(program, nil)
123+
}
124+
_ = out
125+
}
126+
127+
func BenchmarkCountGtZeroLargeEarlyMatch(b *testing.B) {
128+
program, _ := expr.Compile(`count(1..10000, # == 1) > 0`)
129+
var out any
130+
b.ResetTimer()
131+
for n := 0; n < b.N; n++ {
132+
out, _ = vm.Run(program, nil)
133+
}
134+
_ = out
135+
}
136+
137+
func BenchmarkCountGtZeroNoMatch(b *testing.B) {
138+
program, _ := expr.Compile(`count(1..1000, # == 0) > 0`)
139+
var out any
140+
b.ResetTimer()
141+
for n := 0; n < b.N; n++ {
142+
out, _ = vm.Run(program, nil)
143+
}
144+
_ = out
145+
}
146+
147+
// Benchmarks for count >= 1 → any
148+
func BenchmarkCountGteOneEarlyMatch(b *testing.B) {
149+
program, _ := expr.Compile(`count(1..1000, # == 1) >= 1`)
150+
var out any
151+
b.ResetTimer()
152+
for n := 0; n < b.N; n++ {
153+
out, _ = vm.Run(program, nil)
154+
}
155+
_ = out
156+
}
157+
158+
func BenchmarkCountGteOneNoMatch(b *testing.B) {
159+
program, _ := expr.Compile(`count(1..1000, # == 0) >= 1`)
160+
var out any
161+
b.ResetTimer()
162+
for n := 0; n < b.N; n++ {
163+
out, _ = vm.Run(program, nil)
164+
}
165+
_ = out
166+
}

optimizer/optimizer.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,10 @@ func Optimize(node *Node, config *conf.Config) error {
4040
Walk(node, &filterLast{})
4141
Walk(node, &filterFirst{})
4242
Walk(node, &predicateCombination{})
43+
Walk(node, &sumRange{})
4344
Walk(node, &sumArray{})
4445
Walk(node, &sumMap{})
46+
Walk(node, &countAny{})
4547
return nil
4648
}
4749

optimizer/sum_range.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package optimizer
2+
3+
import (
4+
. "github.com/expr-lang/expr/ast"
5+
)
6+
7+
type sumRange struct{}
8+
9+
func (*sumRange) Visit(node *Node) {
10+
// Pattern 1: sum(m..n) or sum(m..n, predicate) where m and n are constant integers
11+
if sumBuiltin, ok := (*node).(*BuiltinNode); ok &&
12+
sumBuiltin.Name == "sum" &&
13+
(len(sumBuiltin.Arguments) == 1 || len(sumBuiltin.Arguments) == 2) {
14+
if rangeOp, ok := sumBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." {
15+
if from, ok := rangeOp.Left.(*IntegerNode); ok {
16+
if to, ok := rangeOp.Right.(*IntegerNode); ok {
17+
m := from.Value
18+
n := to.Value
19+
if n >= m {
20+
count := n - m + 1
21+
// Use the arithmetic series formula: (n - m + 1) * (m + n) / 2
22+
sum := count * (m + n) / 2
23+
24+
if len(sumBuiltin.Arguments) == 1 {
25+
// sum(m..n)
26+
patchWithType(node, &IntegerNode{Value: sum})
27+
} else if len(sumBuiltin.Arguments) == 2 {
28+
// sum(m..n, predicate)
29+
if result, ok := applySumPredicate(sum, count, sumBuiltin.Arguments[1]); ok {
30+
patchWithType(node, &IntegerNode{Value: result})
31+
}
32+
}
33+
}
34+
}
35+
}
36+
}
37+
}
38+
39+
// Pattern 2: reduce(m..n, # + #acc) where m and n are constant integers
40+
if reduceBuiltin, ok := (*node).(*BuiltinNode); ok &&
41+
reduceBuiltin.Name == "reduce" &&
42+
(len(reduceBuiltin.Arguments) == 2 || len(reduceBuiltin.Arguments) == 3) {
43+
if rangeOp, ok := reduceBuiltin.Arguments[0].(*BinaryNode); ok && rangeOp.Operator == ".." {
44+
if from, ok := rangeOp.Left.(*IntegerNode); ok {
45+
if to, ok := rangeOp.Right.(*IntegerNode); ok {
46+
if isPointerPlusAcc(reduceBuiltin.Arguments[1]) {
47+
m := from.Value
48+
n := to.Value
49+
if n >= m {
50+
// Use the arithmetic series formula: (n - m + 1) * (m + n) / 2
51+
sum := (n - m + 1) * (m + n) / 2
52+
53+
// Check for optional initialValue (3rd argument)
54+
if len(reduceBuiltin.Arguments) == 3 {
55+
if initialValue, ok := reduceBuiltin.Arguments[2].(*IntegerNode); ok {
56+
result := initialValue.Value + sum
57+
patchWithType(node, &IntegerNode{Value: result})
58+
}
59+
} else {
60+
patchWithType(node, &IntegerNode{Value: sum})
61+
}
62+
}
63+
}
64+
}
65+
}
66+
}
67+
}
68+
}
69+
70+
// isPointerPlusAcc checks if the node represents `# + #acc` pattern
71+
func isPointerPlusAcc(node Node) bool {
72+
predicate, ok := node.(*PredicateNode)
73+
if !ok {
74+
return false
75+
}
76+
77+
binary, ok := predicate.Node.(*BinaryNode)
78+
if !ok {
79+
return false
80+
}
81+
82+
if binary.Operator != "+" {
83+
return false
84+
}
85+
86+
// Check for # + #acc (pointer + accumulator)
87+
leftPointer, leftIsPointer := binary.Left.(*PointerNode)
88+
rightPointer, rightIsPointer := binary.Right.(*PointerNode)
89+
90+
if leftIsPointer && rightIsPointer {
91+
// # + #acc: Left is pointer (Name=""), Right is acc (Name="acc")
92+
if leftPointer.Name == "" && rightPointer.Name == "acc" {
93+
return true
94+
}
95+
// #acc + #: Left is acc (Name="acc"), Right is pointer (Name="")
96+
if leftPointer.Name == "acc" && rightPointer.Name == "" {
97+
return true
98+
}
99+
}
100+
101+
return false
102+
}
103+
104+
// applySumPredicate tries to compute the result of sum(m..n, predicate) at compile time.
105+
// Returns (result, true) if optimization is possible, (0, false) otherwise.
106+
// Supported predicates:
107+
// - # (identity): result = sum
108+
// - # * k (multiply by constant): result = k * sum
109+
// - k * # (multiply by constant): result = k * sum
110+
// - # + k (add constant): result = sum + count * k
111+
// - k + # (add constant): result = sum + count * k
112+
// - # - k (subtract constant): result = sum - count * k
113+
func applySumPredicate(sum, count int, predicateArg Node) (int, bool) {
114+
predicate, ok := predicateArg.(*PredicateNode)
115+
if !ok {
116+
return 0, false
117+
}
118+
119+
// Case 1: # (identity) - just return the sum
120+
if pointer, ok := predicate.Node.(*PointerNode); ok && pointer.Name == "" {
121+
return sum, true
122+
}
123+
124+
// Case 2: Binary operations with pointer and constant
125+
binary, ok := predicate.Node.(*BinaryNode)
126+
if !ok {
127+
return 0, false
128+
}
129+
130+
pointer, constant, pointerOnLeft := extractPointerAndConstantWithPosition(binary)
131+
if pointer == nil || constant == nil {
132+
return 0, false
133+
}
134+
135+
switch binary.Operator {
136+
case "*":
137+
// # * k or k * # => k * sum
138+
return constant.Value * sum, true
139+
case "+":
140+
// # + k or k + # => sum + count * k
141+
return sum + count*constant.Value, true
142+
case "-":
143+
if pointerOnLeft {
144+
// # - k => sum - count * k
145+
return sum - count*constant.Value, true
146+
}
147+
// k - # => count * k - sum
148+
return count*constant.Value - sum, true
149+
}
150+
151+
return 0, false
152+
}
153+
154+
// extractPointerAndConstantWithPosition extracts pointer (#) and integer constant from a binary node.
155+
// Returns (pointer, constant, pointerOnLeft) or (nil, nil, false) if not matching the expected pattern.
156+
func extractPointerAndConstantWithPosition(binary *BinaryNode) (*PointerNode, *IntegerNode, bool) {
157+
// Try left=pointer, right=constant
158+
if pointer, ok := binary.Left.(*PointerNode); ok && pointer.Name == "" {
159+
if constant, ok := binary.Right.(*IntegerNode); ok {
160+
return pointer, constant, true
161+
}
162+
}
163+
164+
// Try left=constant, right=pointer
165+
if constant, ok := binary.Left.(*IntegerNode); ok {
166+
if pointer, ok := binary.Right.(*PointerNode); ok && pointer.Name == "" {
167+
return pointer, constant, false
168+
}
169+
}
170+
171+
return nil, nil, false
172+
}

0 commit comments

Comments
 (0)