Skip to content

Commit 0d1553f

Browse files
authored
Merge branch 'master' into fetch-perf
2 parents bc5b2ef + daf1790 commit 0d1553f

22 files changed

Lines changed: 799 additions & 39 deletions

builtin/builtin.go

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package builtin
33
import (
44
"encoding/base64"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"reflect"
89
"sort"
@@ -16,6 +17,10 @@ import (
1617
var (
1718
Index map[string]int
1819
Names []string
20+
21+
// MaxDepth limits the recursion depth for nested structures.
22+
MaxDepth = 10000
23+
ErrorMaxDepth = errors.New("recursion depth exceeded")
1924
)
2025

2126
func init() {
@@ -377,7 +382,7 @@ var Builtins = []*Function{
377382
{
378383
Name: "max",
379384
Func: func(args ...any) (any, error) {
380-
return minMax("max", runtime.Less, args...)
385+
return minMax("max", runtime.Less, 0, args...)
381386
},
382387
Validate: func(args []reflect.Type) (reflect.Type, error) {
383388
return validateAggregateFunc("max", args)
@@ -386,7 +391,7 @@ var Builtins = []*Function{
386391
{
387392
Name: "min",
388393
Func: func(args ...any) (any, error) {
389-
return minMax("min", runtime.More, args...)
394+
return minMax("min", runtime.More, 0, args...)
390395
},
391396
Validate: func(args []reflect.Type) (reflect.Type, error) {
392397
return validateAggregateFunc("min", args)
@@ -395,7 +400,7 @@ var Builtins = []*Function{
395400
{
396401
Name: "mean",
397402
Func: func(args ...any) (any, error) {
398-
count, sum, err := mean(args...)
403+
count, sum, err := mean(0, args...)
399404
if err != nil {
400405
return nil, err
401406
}
@@ -411,7 +416,7 @@ var Builtins = []*Function{
411416
{
412417
Name: "median",
413418
Func: func(args ...any) (any, error) {
414-
values, err := median(args...)
419+
values, err := median(0, args...)
415420
if err != nil {
416421
return nil, err
417422
}
@@ -940,7 +945,10 @@ var Builtins = []*Function{
940945
if v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
941946
return nil, size, fmt.Errorf("cannot flatten %s", v.Kind())
942947
}
943-
ret := flatten(v)
948+
ret, err := flatten(v, 0)
949+
if err != nil {
950+
return nil, 0, err
951+
}
944952
size = uint(len(ret))
945953
return ret, size, nil
946954
},

builtin/builtin_test.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,3 +722,100 @@ func TestBuiltin_with_deref(t *testing.T) {
722722
})
723723
}
724724
}
725+
726+
func TestBuiltin_flatten_recursion(t *testing.T) {
727+
var s []any
728+
s = append(s, &s) // s contains a pointer to itself
729+
730+
env := map[string]any{
731+
"arr": s,
732+
}
733+
734+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
735+
require.NoError(t, err)
736+
737+
_, err = expr.Run(program, env)
738+
require.Error(t, err)
739+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
740+
}
741+
742+
func TestBuiltin_flatten_recursion_slice(t *testing.T) {
743+
s := make([]any, 1)
744+
s[0] = s
745+
746+
env := map[string]any{
747+
"arr": s,
748+
}
749+
750+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
751+
require.NoError(t, err)
752+
753+
_, err = expr.Run(program, env)
754+
require.Error(t, err)
755+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
756+
}
757+
758+
func TestBuiltin_numerical_recursion(t *testing.T) {
759+
s := make([]any, 1)
760+
s[0] = s
761+
762+
env := map[string]any{
763+
"arr": s,
764+
}
765+
766+
tests := []string{
767+
"max(arr)",
768+
"min(arr)",
769+
"mean(arr)",
770+
"median(arr)",
771+
}
772+
773+
for _, input := range tests {
774+
t.Run(input, func(t *testing.T) {
775+
program, err := expr.Compile(input, expr.Env(env))
776+
require.NoError(t, err)
777+
778+
_, err = expr.Run(program, env)
779+
require.Error(t, err)
780+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
781+
})
782+
}
783+
}
784+
785+
func TestBuiltin_recursion_custom_max_depth(t *testing.T) {
786+
originalMaxDepth := builtin.MaxDepth
787+
defer func() {
788+
builtin.MaxDepth = originalMaxDepth
789+
}()
790+
791+
// Set a small depth limit
792+
builtin.MaxDepth = 2
793+
794+
// Create a deeply nested array (depth 5)
795+
// [1, [2, [3, [4, [5]]]]]
796+
arr := []any{1, []any{2, []any{3, []any{4, []any{5}}}}}
797+
798+
env := map[string]any{
799+
"arr": arr,
800+
}
801+
802+
t.Run("flatten exceeds max depth", func(t *testing.T) {
803+
program, err := expr.Compile("flatten(arr)", expr.Env(env))
804+
require.NoError(t, err)
805+
806+
_, err = expr.Run(program, env)
807+
require.Error(t, err)
808+
assert.Contains(t, err.Error(), builtin.ErrorMaxDepth.Error())
809+
})
810+
811+
t.Run("flatten within max depth", func(t *testing.T) {
812+
// Depth 2: [1, [2]]
813+
shallowArr := []any{1, []any{2}}
814+
envShallow := map[string]any{"arr": shallowArr}
815+
program, err := expr.Compile("flatten(arr)", expr.Env(envShallow))
816+
require.NoError(t, err)
817+
818+
_, err = expr.Run(program, envShallow)
819+
require.NoError(t, err)
820+
})
821+
}

builtin/lib.go

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -253,15 +253,18 @@ func String(arg any) any {
253253
return fmt.Sprintf("%v", arg)
254254
}
255255

256-
func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
256+
func minMax(name string, fn func(any, any) bool, depth int, args ...any) (any, error) {
257+
if depth > MaxDepth {
258+
return nil, ErrorMaxDepth
259+
}
257260
var val any
258261
for _, arg := range args {
259262
rv := reflect.ValueOf(arg)
260263
switch rv.Kind() {
261264
case reflect.Array, reflect.Slice:
262265
size := rv.Len()
263266
for i := 0; i < size; i++ {
264-
elemVal, err := minMax(name, fn, rv.Index(i).Interface())
267+
elemVal, err := minMax(name, fn, depth+1, rv.Index(i).Interface())
265268
if err != nil {
266269
return nil, err
267270
}
@@ -294,7 +297,10 @@ func minMax(name string, fn func(any, any) bool, args ...any) (any, error) {
294297
return val, nil
295298
}
296299

297-
func mean(args ...any) (int, float64, error) {
300+
func mean(depth int, args ...any) (int, float64, error) {
301+
if depth > MaxDepth {
302+
return 0, 0, ErrorMaxDepth
303+
}
298304
var total float64
299305
var count int
300306

@@ -304,7 +310,7 @@ func mean(args ...any) (int, float64, error) {
304310
case reflect.Array, reflect.Slice:
305311
size := rv.Len()
306312
for i := 0; i < size; i++ {
307-
elemCount, elemSum, err := mean(rv.Index(i).Interface())
313+
elemCount, elemSum, err := mean(depth+1, rv.Index(i).Interface())
308314
if err != nil {
309315
return 0, 0, err
310316
}
@@ -327,7 +333,10 @@ func mean(args ...any) (int, float64, error) {
327333
return count, total, nil
328334
}
329335

330-
func median(args ...any) ([]float64, error) {
336+
func median(depth int, args ...any) ([]float64, error) {
337+
if depth > MaxDepth {
338+
return nil, ErrorMaxDepth
339+
}
331340
var values []float64
332341

333342
for _, arg := range args {
@@ -336,7 +345,7 @@ func median(args ...any) ([]float64, error) {
336345
case reflect.Array, reflect.Slice:
337346
size := rv.Len()
338347
for i := 0; i < size; i++ {
339-
elems, err := median(rv.Index(i).Interface())
348+
elems, err := median(depth+1, rv.Index(i).Interface())
340349
if err != nil {
341350
return nil, err
342351
}
@@ -355,18 +364,24 @@ func median(args ...any) ([]float64, error) {
355364
return values, nil
356365
}
357366

358-
func flatten(arg reflect.Value) []any {
367+
func flatten(arg reflect.Value, depth int) ([]any, error) {
368+
if depth > MaxDepth {
369+
return nil, ErrorMaxDepth
370+
}
359371
ret := []any{}
360372
for i := 0; i < arg.Len(); i++ {
361373
v := deref.Value(arg.Index(i))
362374
if v.Kind() == reflect.Array || v.Kind() == reflect.Slice {
363-
x := flatten(v)
375+
x, err := flatten(v, depth+1)
376+
if err != nil {
377+
return nil, err
378+
}
364379
ret = append(ret, x...)
365380
} else {
366381
ret = append(ret, v.Interface())
367382
}
368383
}
369-
return ret
384+
return ret, nil
370385
}
371386

372387
func get(params ...any) (out any, err error) {

checker/checker.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ func (v *Checker) binaryNode(node *ast.BinaryNode) Nature {
462462
return v.error(node, err.Error())
463463
}
464464
}
465-
if l.IsString() && r.IsString() {
465+
if (l.IsString() || l.IsByteSlice()) && r.IsString() {
466466
return v.config.NtCache.FromType(boolType)
467467
}
468468
if l.MaybeCompatible(&v.config.NtCache, r, StringCheck) {
@@ -549,6 +549,13 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature {
549549

550550
switch base.Kind {
551551
case reflect.Map:
552+
// If the map key is a pointer, we should not dereference the property.
553+
if !prop.AssignableTo(base.Key(&v.config.NtCache)) {
554+
propDeref := prop.Deref(&v.config.NtCache)
555+
if propDeref.AssignableTo(base.Key(&v.config.NtCache)) {
556+
prop = propDeref
557+
}
558+
}
552559
if !prop.AssignableTo(base.Key(&v.config.NtCache)) && !prop.IsUnknown(&v.config.NtCache) {
553560
return v.error(node.Property, "cannot use %s to get an element from %s", prop.String(), base.String())
554561
}
@@ -562,6 +569,7 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature {
562569
return base.Elem(&v.config.NtCache)
563570

564571
case reflect.Array, reflect.Slice:
572+
prop = prop.Deref(&v.config.NtCache)
565573
if !prop.IsInteger && !prop.IsUnknown(&v.config.NtCache) {
566574
return v.error(node.Property, "array elements can only be selected using an integer (got %s)", prop.String())
567575
}
@@ -607,13 +615,15 @@ func (v *Checker) sliceNode(node *ast.SliceNode) Nature {
607615

608616
if node.From != nil {
609617
from := v.visit(node.From)
618+
from = from.Deref(&v.config.NtCache)
610619
if !from.IsInteger && !from.IsUnknown(&v.config.NtCache) {
611620
return v.error(node.From, "non-integer slice index %v", from.String())
612621
}
613622
}
614623

615624
if node.To != nil {
616625
to := v.visit(node.To)
626+
to = to.Deref(&v.config.NtCache)
617627
if !to.IsInteger && !to.IsUnknown(&v.config.NtCache) {
618628
return v.error(node.To, "non-integer slice index %v", to.String())
619629
}
@@ -942,6 +952,7 @@ func (v *Checker) checkBuiltinGet(node *ast.BuiltinNode) Nature {
942952

943953
base := v.visit(node.Arguments[0])
944954
prop := v.visit(node.Arguments[1])
955+
prop = prop.Deref(&v.config.NtCache)
945956

946957
if id, ok := node.Arguments[0].(*ast.IdentifierNode); ok && id.Value == "$env" {
947958
if s, ok := node.Arguments[1].(*ast.StringNode); ok {
@@ -1260,6 +1271,7 @@ func (v *Checker) sequenceNode(node *ast.SequenceNode) Nature {
12601271

12611272
func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature {
12621273
c := v.visit(node.Cond)
1274+
c = c.Deref(&v.config.NtCache)
12631275
if !c.IsBool() && !c.IsUnknown(&v.config.NtCache) {
12641276
return v.error(node.Cond, "non-bool expression (type %v) used as condition", c.String())
12651277
}
@@ -1277,6 +1289,13 @@ func (v *Checker) conditionalNode(node *ast.ConditionalNode) Nature {
12771289
return v.config.NtCache.NatureOf(nil)
12781290
}
12791291
if t1.AssignableTo(t2) {
1292+
if t1.IsArray() && t2.IsArray() {
1293+
e1 := t1.Elem(&v.config.NtCache)
1294+
e2 := t2.Elem(&v.config.NtCache)
1295+
if !e1.AssignableTo(e2) || !e2.AssignableTo(e1) {
1296+
return v.config.NtCache.FromType(arrayType)
1297+
}
1298+
}
12801299
return t1
12811300
}
12821301
return Nature{}

checker/checker_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ func TestCheck(t *testing.T) {
134134
{"Bool ?? Bool"},
135135
{"let foo = 1; foo == 1"},
136136
{"(Embed).EmbedPointerEmbedInt > 0"},
137+
{"(true ? [1] : [[1]])[0][0] == 1"},
137138
}
138139

139140
c := new(checker.Checker)

checker/nature/nature.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ import (
1010
)
1111

1212
var (
13-
intType = reflect.TypeOf(0)
14-
floatType = reflect.TypeOf(float64(0))
15-
arrayType = reflect.TypeOf([]any{})
16-
timeType = reflect.TypeOf(time.Time{})
17-
durationType = reflect.TypeOf(time.Duration(0))
13+
intType = reflect.TypeOf(0)
14+
floatType = reflect.TypeOf(float64(0))
15+
arrayType = reflect.TypeOf([]any{})
16+
byteSliceType = reflect.TypeOf([]byte{})
17+
timeType = reflect.TypeOf(time.Time{})
18+
durationType = reflect.TypeOf(time.Duration(0))
1819

1920
builtinInt = map[reflect.Type]struct{}{
2021
reflect.TypeOf(int(0)): {},
@@ -502,6 +503,10 @@ func (n *Nature) IsString() bool {
502503
return n.Kind == reflect.String
503504
}
504505

506+
func (n *Nature) IsByteSlice() bool {
507+
return n.Type == byteSliceType
508+
}
509+
505510
func (n *Nature) IsArray() bool {
506511
return n.Kind == reflect.Slice || n.Kind == reflect.Array
507512
}

0 commit comments

Comments
 (0)