Skip to content

Commit 10b0c98

Browse files
authored
Merge branch 'master' into fix/issue-823
2 parents 46a50a2 + cf589a4 commit 10b0c98

8 files changed

Lines changed: 216 additions & 17 deletions

File tree

builtin/builtin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package builtin
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"reflect"
89
"sort"
@@ -16,6 +17,10 @@ import (
1617
var (
1718
Index map[string]int
1819
Names []string
20+
21+
// MaxDepth limits the recursion depth for nested structures.
22+
MaxDepth = 10000
23+
ErrorMaxDepth = errors.New("recursion depth exceeded")
1924
)
2025

2126
func init() {
@@ -377,7 +382,7 @@ var Builtins = []*Function{
377382
{
378383
Name: "max",
379384
Func: func(args ...any) (any, error) {
380-
return minMax("max", runtime.Less, args...)
385+
return minMax("max", runtime.Less, 0, args...)
381386
},
382387
Validate: func(args []reflect.Type) (reflect.Type, error) {
383388
return validateAggregateFunc("max", args)
@@ -386,7 +391,7 @@ var Builtins = []*Function{
386391
{
387392
Name: "min",
388393
Func: func(args ...any) (any, error) {
389-
return minMax("min", runtime.More, args...)
394+
return minMax("min", runtime.More, 0, args...)
390395
},
391396
Validate: func(args []reflect.Type) (reflect.Type, error) {
392397
return validateAggregateFunc("min", args)
@@ -395,7 +400,7 @@ var Builtins = []*Function{
395400
{
396401
Name: "mean",
397402
Func: func(args ...any) (any, error) {
398-
count, sum, err := mean(args...)
403+
count, sum, err := mean(0, args...)
399404
if err != nil {
400405
return nil, err
401406
}
@@ -411,7 +416,7 @@ var Builtins = []*Function{
411416
{
412417
Name: "median",
413418
Func: func(args ...any) (any, error) {
414-
values, err := median(args...)
419+
values, err := median(0, args...)
415420
if err != nil {
416421
return nil, err
417422
}
@@ -940,7 +945,10 @@ var Builtins = []*Function{
940945
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
941946
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
942947
}
943-
ret := flatten(v)
948+
ret, err := flatten(v, 0)
949+
if err != nil {
950+
return nil, 0, err
951+
}
944952
size = uint(len(ret))
945953
return ret, size, nil
946954
},

builtin/builtin_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
722722
})
723723
}
724724
}
725+
726+
func TestBuiltin_flatten_recursion(t *testing.T) {
727+
var s []any
728+
s = append(s, &s) // s contains a pointer to itself
729+
730+
env := map[string]any{
731+
"arr": s,
732+
}
733+
734+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
735+
require.NoError(t, err)
736+
737+
_, err = expr.Run(program, env)
738+
require.Error(t, err)
739+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
740+
}
741+
742+
func TestBuiltin_flatten_recursion_slice(t *testing.T) {
743+
s := make([]any, 1)
744+
s[0] = s
745+
746+
env := map[string]any{
747+
"arr": s,
748+
}
749+
750+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
751+
require.NoError(t, err)
752+
753+
_, err = expr.Run(program, env)
754+
require.Error(t, err)
755+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
756+
}
757+
758+
func TestBuiltin_numerical_recursion(t *testing.T) {
759+
s := make([]any, 1)
760+
s[0] = s
761+
762+
env := map[string]any{
763+
"arr": s,
764+
}
765+
766+
tests := []string{
767+
"max(arr)",
768+
"min(arr)",
769+
"mean(arr)",
770+
"median(arr)",
771+
}
772+
773+
for _, input := range tests {
774+
t.Run(input, func(t *testing.T) {
775+
program, err := expr.Compile(input, expr.Env(env))
776+
require.NoError(t, err)
777+
778+
_, err = expr.Run(program, env)
779+
require.Error(t, err)
780+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
781+
})
782+
}
783+
}
784+
785+
func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
786+
originalMaxDepth := builtin.MaxDepth
787+
defer func() {
788+
builtin.MaxDepth = originalMaxDepth
789+
}()
790+
791+
// Set a small depth limit
792+
builtin.MaxDepth = 2
793+
794+
// Create a deeply nested array (depth 5)
795+
// [1, [2, [3, [4, [5]]]]]
796+
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}
797+
798+
env := map[string]any{
799+
"arr": arr,
800+
}
801+
802+
t.Run("flatten exceeds max depth", func(t *testing.T) {
803+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
804+
require.NoError(t, err)
805+
806+
_, err = expr.Run(program, env)
807+
require.Error(t, err)
808+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
809+
})
810+
811+
t.Run("flatten within max depth", func(t *testing.T) {
812+
// Depth 2: [1, [2]]
813+
shallowArr := []any{1, []any{2}}
814+
envShallow := map[string]any{"arr": shallowArr}
815+
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
816+
require.NoError(t, err)
817+
818+
_, err = expr.Run(program, envShallow)
819+
require.NoError(t, err)
820+
})
821+
}

builtin/lib.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ func String(arg any) any {
253253
return fmt.Sprintf("%v", arg)
254254
}
255255

256-
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
256+
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
257+
if depth > MaxDepth {
258+
return nil, ErrorMaxDepth
259+
}
257260
var val any
258261
for _, arg := range args {
259262
rv := reflect.ValueOf(arg)
260263
switch rv.Kind() {
261264
case reflect.Array, reflect.Slice:
262265
size := rv.Len()
263266
for i := 0; i < size; i++ {
264-
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
267+
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
265268
if err != nil {
266269
return nil, err
267270
}
@@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
294297
return val, nil
295298
}
296299

297-
func mean(args ...any) (int, float64, error) {
300+
func mean(depth int, args ...any) (int, float64, error) {
301+
if depth > MaxDepth {
302+
return 0, 0, ErrorMaxDepth
303+
}
298304
var total float64
299305
var count int
300306

@@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
304310
case reflect.Array, reflect.Slice:
305311
size := rv.Len()
306312
for i := 0; i < size; i++ {
307-
elemCount, elemSum, err := mean(rv.Index(i).Interface())
313+
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
308314
if err != nil {
309315
return 0, 0, err
310316
}
@@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
327333
return count, total, nil
328334
}
329335

330-
func median(args ...any) ([]float64, error) {
336+
func median(depth int, args ...any) ([]float64, error) {
337+
if depth > MaxDepth {
338+
return nil, ErrorMaxDepth
339+
}
331340
var values []float64
332341

333342
for _, arg := range args {
@@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
336345
case reflect.Array, reflect.Slice:
337346
size := rv.Len()
338347
for i := 0; i < size; i++ {
339-
elems, err := median(rv.Index(i).Interface())
348+
elems, err := median(depth+1, rv.Index(i).Interface())
340349
if err != nil {
341350
return nil, err
342351
}
@@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
355364
return values, nil
356365
}
357366

358-
func flatten(arg reflect.Value) []any {
367+
func flatten(arg reflect.Value, depth int) ([]any, error) {
368+
if depth > MaxDepth {
369+
return nil, ErrorMaxDepth
370+
}
359371
ret := []any{}
360372
for i := 0; i < arg.Len(); i++ {
361373
v := deref.Value(arg.Index(i))
362374
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
363-
x := flatten(v)
375+
x, err := flatten(v, depth+1)
376+
if err != nil {
377+
return nil, err
378+
}
364379
ret = append(ret, x...)
365380
} else {
366381
ret = append(ret, v.Interface())
367382
}
368383
}
369-
return ret
384+
return ret, nil
370385
}
371386

372387
func get(params ...any) (out any, err error) {

compiler/compiler.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,9 @@ func (c *compiler) BuiltinNode(node *ast.BuiltinNode) {
11031103
if f.Fast != nil {
11041104
c.emit(OpCallBuiltin1, id)
11051105
} else if f.Safe != nil {
1106-
c.emit(OpPush, c.addConstant(f.Safe))
1106+
id := c.addConstant(f.Safe)
1107+
c.emit(OpPush, id)
1108+
c.debugInfo[fmt.Sprintf("const_%d", id)] = node.Name
11071109
c.emit(OpCallSafe, len(node.Arguments))
11081110
} else if f.Func != nil {
11091111
c.emitFunction(f, len(node.Arguments))

test/issues/567/issue_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package expr_test
2+
3+
import (
4+
"bytes"
5+
"strings"
6+
"testing"
7+
8+
"github.com/expr-lang/expr"
9+
"github.com/expr-lang/expr/internal/testify/require"
10+
)
11+
12+
func TestIssue567(t *testing.T) {
13+
program, err := expr.Compile("concat(1..2, 3..4)")
14+
require.NoError(t, err)
15+
16+
var buf bytes.Buffer
17+
program.DisassembleWriter(&buf)
18+
output := buf.String()
19+
20+
// Check if "concat" is mentioned in the output
21+
require.True(t, strings.Contains(output, "concat"), "expected 'concat' in disassembly output")
22+
23+
// It should appear as a pushed constant
24+
require.True(t, strings.Contains(output, "OpPush\t<4>\tconcat"), "expected 'OpPush <4> concat' in disassembly output")
25+
}

test/issues/817/issue_test.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package issue_test
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/expr-lang/expr"
8+
"github.com/expr-lang/expr/internal/testify/require"
9+
)
10+
11+
func TestIssue817_1(t *testing.T) {
12+
out, err := expr.Eval(
13+
`sprintf("result: %v %v", 1, nil)`,
14+
map[string]any{
15+
"sprintf": fmt.Sprintf,
16+
},
17+
)
18+
require.NoError(t, err)
19+
require.Equal(t, "result: 1 <nil>", out)
20+
}
21+
22+
func TestIssue817_2(t *testing.T) {
23+
out, err := expr.Eval(
24+
`thing(nil)`,
25+
map[string]any{
26+
"thing": func(arg ...any) string {
27+
return fmt.Sprintf("result: (%T) %v", arg[0], arg[0])
28+
},
29+
},
30+
)
31+
require.NoError(t, err)
32+
require.Equal(t, "result: (<nil>) <nil>", out)
33+
}

vm/program.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ func (program *Program) DisassembleWriter(w io.Writer) {
112112
} else {
113113
c = "out of range"
114114
}
115+
if name, ok := program.debugInfo[fmt.Sprintf("const_%d", arg)]; ok {
116+
c = name
117+
}
115118
if r, ok := c.(*regexp.Regexp); ok {
116119
c = r.String()
117120
}

vm/vm.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -330,13 +330,29 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
330330
vm.push(runtime.Slice(node, from, to))
331331

332332
case OpCall:
333-
fn := reflect.ValueOf(vm.pop())
333+
v := vm.pop()
334+
if v == nil {
335+
panic("invalid operation: cannot call nil")
336+
}
337+
fn := reflect.ValueOf(v)
338+
if fn.Kind() != reflect.Func {
339+
panic(fmt.Sprintf("invalid operation: cannot call non-function of type %T", v))
340+
}
341+
fnType := fn.Type()
334342
size := arg
335343
in := make([]reflect.Value, size)
344+
isVariadic := fnType.IsVariadic()
345+
numIn := fnType.NumIn()
336346
for i := int(size) - 1; i >= 0; i-- {
337347
param := vm.pop()
338348
if param == nil {
339-
in[i] = reflect.Zero(fn.Type().In(i))
349+
var inType reflect.Type
350+
if isVariadic && i >= numIn-1 {
351+
inType = fnType.In(numIn - 1).Elem()
352+
} else {
353+
inType = fnType.In(i)
354+
}
355+
in[i] = reflect.Zero(inType)
340356
} else {
341357
in[i] = reflect.ValueOf(param)
342358
}

0 commit comments

Comments
 (0)