Skip to content

Commit 83c0f92

Browse files
authored
Merge branch 'master' into copilot/fix-bugs-in-code
2 parents 3798c08 + 3461fbb commit 83c0f92

5 files changed

Lines changed: 172 additions & 39 deletions

File tree

expr.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,14 @@ func EnableBuiltin(name string) Option {
195195

196196
// WithContext passes context to all functions calls with a context.Context argument.
197197
func WithContext(name string) Option {
198-
return Patch(patcher.WithContext{
199-
Name: name,
200-
})
198+
return func(c *conf.Config) {
199+
c.Visitors = append(c.Visitors, patcher.WithContext{
200+
Name: name,
201+
Functions: c.Functions,
202+
Env: &c.Env,
203+
NtCache: &c.NtCache,
204+
})
205+
}
201206
}
202207

203208
// Timezone sets default timezone for date() and now() builtin functions.

patcher/with_context.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,16 @@ import (
44
"reflect"
55

66
"github.com/expr-lang/expr/ast"
7+
"github.com/expr-lang/expr/checker/nature"
8+
"github.com/expr-lang/expr/conf"
79
)
810

911
// WithContext adds WithContext.Name argument to all functions calls with a context.Context argument.
1012
type WithContext struct {
11-
Name string
13+
Name string
14+
Functions conf.FunctionsTable // Optional: used to look up function types when callee type is unknown.
15+
Env *nature.Nature // Optional: used to look up method types when callee type is unknown.
16+
NtCache *nature.Cache // Optional: cache for nature lookups.
1217
}
1318

1419
// Visit adds WithContext.Name argument to all functions calls with a context.Context argument.
@@ -19,6 +24,24 @@ func (w WithContext) Visit(node *ast.Node) {
1924
if fn == nil {
2025
return
2126
}
27+
// If callee type is interface{} (unknown), look up the function type from
28+
// the Functions table or Env. This handles cases where the checker returns early
29+
// without visiting nested call arguments (e.g., Date2() in Now2().After(Date2()))
30+
// because the outer call's type is unknown due to missing context arguments.
31+
if fn.Kind() == reflect.Interface {
32+
if ident, ok := call.Callee.(*ast.IdentifierNode); ok {
33+
if w.Functions != nil {
34+
if f, ok := w.Functions[ident.Value]; ok {
35+
fn = f.Type()
36+
}
37+
}
38+
if fn.Kind() == reflect.Interface && w.Env != nil {
39+
if m, ok := w.Env.MethodByName(w.NtCache, ident.Value); ok {
40+
fn = m.Type
41+
}
42+
}
43+
}
44+
}
2245
if fn.Kind() != reflect.Func {
2346
return
2447
}

test/issues/823/issue_test.go

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package issue_test
22

33
import (
44
"context"
5-
"fmt"
65
"testing"
76
"time"
87

@@ -14,26 +13,72 @@ type env struct {
1413
Ctx context.Context `expr:"ctx"`
1514
}
1615

16+
// TestIssue823 verifies that WithContext injects context into nested custom
17+
// function calls. The bug was that date2() nested as an argument to After()
18+
// didn't receive the context because its callee type was unknown.
1719
func TestIssue823(t *testing.T) {
20+
now2Called := false
21+
date2Called := false
22+
1823
p, err := expr.Compile(
1924
"now2().After(date2())",
2025
expr.Env(env{}),
2126
expr.WithContext("ctx"),
2227
expr.Function(
2328
"now2",
24-
func(params ...any) (any, error) { return time.Now(), nil },
29+
func(params ...any) (any, error) {
30+
require.Len(t, params, 1, "now2 should receive context")
31+
_, ok := params[0].(context.Context)
32+
require.True(t, ok, "now2 first param should be context.Context")
33+
now2Called = true
34+
return time.Now(), nil
35+
},
2536
new(func(context.Context) time.Time),
2637
),
2738
expr.Function(
2839
"date2",
29-
func(params ...any) (any, error) { return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), nil },
40+
func(params ...any) (any, error) {
41+
require.Len(t, params, 1, "date2 should receive context")
42+
_, ok := params[0].(context.Context)
43+
require.True(t, ok, "date2 first param should be context.Context")
44+
date2Called = true
45+
return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), nil
46+
},
3047
new(func(context.Context) time.Time),
3148
),
3249
)
33-
fmt.Printf("Compile result err: %v\n", err)
3450
require.NoError(t, err)
3551

3652
r, err := expr.Run(p, &env{Ctx: context.Background()})
3753
require.NoError(t, err)
3854
require.True(t, r.(bool))
55+
require.True(t, now2Called, "now2 should have been called")
56+
require.True(t, date2Called, "date2 should have been called")
57+
}
58+
59+
// envWithMethods tests that Env methods with context.Context work correctly
60+
// when nested in method chains (similar to TestIssue823 but with Env methods).
61+
type envWithMethods struct {
62+
Ctx context.Context `expr:"ctx"`
63+
}
64+
65+
func (e *envWithMethods) Now2(ctx context.Context) time.Time {
66+
return time.Now()
67+
}
68+
69+
func (e *envWithMethods) Date2(ctx context.Context) time.Time {
70+
return time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
71+
}
72+
73+
func TestIssue823_EnvMethods(t *testing.T) {
74+
p, err := expr.Compile(
75+
"Now2().After(Date2())",
76+
expr.Env(&envWithMethods{}),
77+
expr.WithContext("ctx"),
78+
)
79+
require.NoError(t, err)
80+
81+
r, err := expr.Run(p, &envWithMethods{Ctx: context.Background()})
82+
require.NoError(t, err)
83+
require.True(t, r.(bool))
3984
}

vm/utils.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,28 @@ type Scope struct {
2020
Len int
2121
Count int
2222
Acc any
23+
// Fast paths
24+
Ints []int
25+
Floats []float64
26+
Strings []string
27+
Anys []any
28+
}
29+
30+
// Item returns the current element from the scope using fast paths when available.
31+
func (s *Scope) Item() any {
32+
if s.Ints != nil {
33+
return s.Ints[s.Index]
34+
}
35+
if s.Floats != nil {
36+
return s.Floats[s.Index]
37+
}
38+
if s.Strings != nil {
39+
return s.Strings[s.Index]
40+
}
41+
if s.Anys != nil {
42+
return s.Anys[s.Index]
43+
}
44+
return s.Array.Index(s.Index).Interface()
2345
}
2446

2547
type groupBy = map[any][]any

vm/vm.go

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ type VM struct {
4646
debug bool
4747
step chan struct{}
4848
curr chan int
49+
scopePool []Scope // Pre-allocated pool of Scope values; grows as needed but never shrinks
50+
scopePoolIdx int // Current index into scopePool for allocation
51+
currScope *Scope // Cached pointer to the current scope (optimization)
4952
}
5053

5154
func (vm *VM) Run(program *Program, env any) (_ any, err error) {
@@ -76,6 +79,8 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
7679
clearSlice(vm.Scopes)
7780
vm.Scopes = vm.Scopes[0:0]
7881
}
82+
vm.scopePoolIdx = 0 // Reset pool index for reuse
83+
vm.currScope = nil
7984
if len(vm.Variables) < program.variables {
8085
vm.Variables = make([]any, program.variables)
8186
}
@@ -221,8 +226,7 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
221226
if arg < 0 {
222227
panic("negative jump offset is invalid")
223228
}
224-
scope := vm.scope()
225-
if scope.Index >= scope.Len {
229+
if vm.currScope.Index >= vm.currScope.Len {
226230
vm.ip += arg
227231
}
228232

@@ -511,40 +515,34 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
511515
vm.push(deref.Interface(a))
512516

513517
case OpIncrementIndex:
514-
vm.scope().Index++
518+
vm.currScope.Index++
515519

516520
case OpDecrementIndex:
517-
scope := vm.scope()
518-
scope.Index--
521+
vm.currScope.Index--
519522

520523
case OpIncrementCount:
521-
scope := vm.scope()
522-
scope.Count++
524+
vm.currScope.Count++
523525

524526
case OpGetIndex:
525-
vm.push(vm.scope().Index)
527+
vm.push(vm.currScope.Index)
526528

527529
case OpGetCount:
528-
scope := vm.scope()
529-
vm.push(scope.Count)
530+
vm.push(vm.currScope.Count)
530531

531532
case OpGetLen:
532-
scope := vm.scope()
533-
vm.push(scope.Len)
533+
vm.push(vm.currScope.Len)
534534

535535
case OpGetAcc:
536-
vm.push(vm.scope().Acc)
536+
vm.push(vm.currScope.Acc)
537537

538538
case OpSetAcc:
539-
vm.scope().Acc = vm.pop()
539+
vm.currScope.Acc = vm.pop()
540540

541541
case OpSetIndex:
542-
scope := vm.scope()
543-
scope.Index = vm.pop().(int)
542+
vm.currScope.Index = vm.pop().(int)
544543

545544
case OpPointer:
546-
scope := vm.scope()
547-
vm.push(scope.Array.Index(scope.Index).Interface())
545+
vm.push(vm.currScope.Item())
548546

549547
case OpThrow:
550548
panic(vm.pop().(error))
@@ -554,7 +552,7 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
554552
case 1:
555553
vm.push(make(groupBy))
556554
case 2:
557-
scope := vm.scope()
555+
scope := vm.currScope
558556
var desc bool
559557
order, ok := vm.pop().(string)
560558
if !ok {
@@ -578,21 +576,19 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
578576
}
579577

580578
case OpGroupBy:
581-
scope := vm.scope()
579+
scope := vm.currScope
582580
key := vm.pop()
583-
item := scope.Array.Index(scope.Index).Interface()
584-
scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], item)
581+
scope.Acc.(groupBy)[key] = append(scope.Acc.(groupBy)[key], scope.Item())
585582

586583
case OpSortBy:
587-
scope := vm.scope()
584+
scope := vm.currScope
588585
value := vm.pop()
589-
item := scope.Array.Index(scope.Index).Interface()
590586
sortable := scope.Acc.(*runtime.SortBy)
591-
sortable.Array = append(sortable.Array, item)
587+
sortable.Array = append(sortable.Array, scope.Item())
592588
sortable.Values = append(sortable.Values, value)
593589

594590
case OpSort:
595-
scope := vm.scope()
591+
scope := vm.currScope
596592
sortable := scope.Acc.(*runtime.SortBy)
597593
sort.Sort(sortable)
598594
vm.memGrow(uint(scope.Len))
@@ -608,11 +604,26 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
608604

609605
case OpBegin:
610606
a := vm.pop()
611-
array := reflect.ValueOf(a)
612-
vm.Scopes = append(vm.Scopes, &Scope{
613-
Array: array,
614-
Len: array.Len(),
615-
})
607+
s := vm.allocScope()
608+
switch v := a.(type) {
609+
case []int:
610+
s.Ints = v
611+
s.Len = len(v)
612+
case []float64:
613+
s.Floats = v
614+
s.Len = len(v)
615+
case []string:
616+
s.Strings = v
617+
s.Len = len(v)
618+
case []any:
619+
s.Anys = v
620+
s.Len = len(v)
621+
default:
622+
s.Array = reflect.ValueOf(a)
623+
s.Len = s.Array.Len()
624+
}
625+
vm.Scopes = append(vm.Scopes, s)
626+
vm.currScope = s
616627

617628
case OpAnd:
618629
a := vm.pop()
@@ -626,6 +637,11 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
626637

627638
case OpEnd:
628639
vm.Scopes = vm.Scopes[:len(vm.Scopes)-1]
640+
if len(vm.Scopes) > 0 {
641+
vm.currScope = vm.Scopes[len(vm.Scopes)-1]
642+
} else {
643+
vm.currScope = nil
644+
}
629645

630646
default:
631647
panic(fmt.Sprintf("unknown bytecode %#x", op))
@@ -679,6 +695,28 @@ func (vm *VM) scope() *Scope {
679695
return vm.Scopes[len(vm.Scopes)-1]
680696
}
681697

698+
// allocScope returns a pointer to a Scope from the pool, growing the pool if needed.
699+
// Callers must set Len and exactly one of: Ints, Floats, Strings, Anys, or Array.
700+
func (vm *VM) allocScope() *Scope {
701+
if vm.scopePoolIdx >= len(vm.scopePool) {
702+
vm.scopePool = append(vm.scopePool, Scope{})
703+
}
704+
s := &vm.scopePool[vm.scopePoolIdx]
705+
vm.scopePoolIdx++
706+
// Reset iteration state
707+
s.Index = 0
708+
s.Count = 0
709+
s.Acc = nil
710+
// Clear typed slice pointers to avoid stale fast-path matches
711+
s.Ints = nil
712+
s.Floats = nil
713+
s.Strings = nil
714+
s.Anys = nil
715+
// Clear Array to release reference for GC (only matters for fallback path)
716+
s.Array = reflect.Value{}
717+
return s
718+
}
719+
682720
// getArgsForFunc lazily initializes the buffer the first time it is called for
683721
// a given program (thus, it also needs "program" to run). It will
684722
// take "needed" elements from the buffer and populate them with vm.pop() in

0 commit comments

Comments
 (0)