Skip to content

Commit 37311e4

Browse files
Alysson Ribeirosonalys
authored andcommitted
feat: Add nil-safety to runtime.Fetch:
- Return nil instead of panic on runtime.Fetch - Ensure OpBegin works with type nil for reflect.Value.Len - Improve perf on early return for nil from in builtin.lib.go.get - Add expr.NilSafe() configuration
1 parent 3a46b19 commit 37311e4

9 files changed

Lines changed: 75 additions & 24 deletions

File tree

builtin/builtin.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,9 @@ var Builtins = []*Function{
601601
return
602602
}
603603
}()
604-
return runtime.Fetch(args[0], 0), nil
604+
605+
value, _ := runtime.Fetch(args[0], 0)
606+
return value, nil
605607
},
606608
Validate: func(args []reflect.Type) (reflect.Type, error) {
607609
if len(args) != 1 {
@@ -624,7 +626,9 @@ var Builtins = []*Function{
624626
return
625627
}
626628
}()
627-
return runtime.Fetch(args[0], -1), nil
629+
630+
value, _ := runtime.Fetch(args[0], -1)
631+
return value, nil
628632
},
629633
Validate: func(args []reflect.Type) (reflect.Type, error) {
630634
if len(args) != 1 {

builtin/lib.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -548,13 +548,13 @@ func get(params ...any) (out any, err error) {
548548
return nil, fmt.Errorf("invalid number of arguments (expected 2, got %d)", len(params))
549549
}
550550
from := params[0]
551-
i := params[1]
552-
v := reflect.ValueOf(from)
553-
554551
if from == nil {
555552
return nil, nil
556553
}
557554

555+
i := params[1]
556+
v := reflect.ValueOf(from)
557+
558558
if v.Kind() == reflect.Invalid {
559559
panic(fmt.Sprintf("cannot fetch %v from %T", i, from))
560560
}

compiler/compiler.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,11 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
4545

4646
c.compile(tree.Node)
4747

48+
nilSafe := false
49+
4850
if c.config != nil {
51+
nilSafe = c.config.NilSafe
52+
4953
switch c.config.Expect {
5054
case reflect.Int:
5155
c.emit(OpCast, 0)
@@ -77,6 +81,7 @@ func Compile(tree *parser.Tree, config *conf.Config) (program *Program, err erro
7781
c.functions,
7882
c.debugInfo,
7983
span,
84+
nilSafe,
8085
)
8186
return
8287
}

conf/config.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ type Config struct {
4141
// When enabled, the lexer treats `if`/`else` as identifiers and the parser
4242
// will not parse `if` statements.
4343
DisableIfOperator bool
44+
// NilSafe enables nil-safe navigation for all expressions,
45+
// allowing access to fields and methods on nil values without panicking.
46+
NilSafe bool
4447
}
4548

4649
// CreateNew creates new config with default values.
@@ -77,7 +80,14 @@ func (c *Config) ConstExpr(name string) {
7780
if c.EnvObject == nil {
7881
panic("no environment is specified for ConstExpr()")
7982
}
80-
fn := reflect.ValueOf(runtime.Fetch(c.EnvObject, name))
83+
84+
field, ok := runtime.Fetch(c.EnvObject, name)
85+
if !ok {
86+
panic(fmt.Errorf("cannot fetch %q in the environment", name))
87+
}
88+
89+
fn := reflect.ValueOf(field)
90+
8191
if fn.Kind() != reflect.Func {
8292
panic(fmt.Errorf("const expression %q must be a function", name))
8393
}

expr.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,12 @@ func MaxNodes(n uint) Option {
225225
}
226226
}
227227

228+
func NilSafe() Option {
229+
return func(c *conf.Config) {
230+
c.NilSafe = true
231+
}
232+
}
233+
228234
// Compile parses and compiles given input expression to bytecode program.
229235
func Compile(input string, ops ...Option) (*vm.Program, error) {
230236
config := conf.CreateNew()

vm/program.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ type Program struct {
2828
functions []Function
2929
debugInfo map[string]string
3030
span *Span
31+
32+
nilSafe bool
3133
}
3234

3335
// NewProgram returns a new Program. It's used by the compiler.
@@ -42,6 +44,7 @@ func NewProgram(
4244
functions []Function,
4345
debugInfo map[string]string,
4446
span *Span,
47+
nilSafe bool,
4548
) *Program {
4649
return &Program{
4750
source: source,
@@ -54,6 +57,7 @@ func NewProgram(
5457
functions: functions,
5558
debugInfo: debugInfo,
5659
span: span,
60+
nilSafe: nilSafe,
5761
}
5862
}
5963

vm/runtime/runtime.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ type fieldCacheKey struct {
1818
f string
1919
}
2020

21-
func Fetch(from, i any) any {
21+
func Fetch(from, i any) (any, bool) {
22+
if from == nil {
23+
return nil, false
24+
}
25+
2226
v := reflect.ValueOf(from)
2327
if v.Kind() == reflect.Invalid {
2428
panic(fmt.Sprintf("cannot fetch %v from %T", i, from))
@@ -29,7 +33,7 @@ func Fetch(from, i any) any {
2933
if methodName, ok := i.(string); ok {
3034
method := v.MethodByName(methodName)
3135
if method.IsValid() {
32-
return method.Interface()
36+
return method.Interface(), true
3337
}
3438
}
3539
}
@@ -52,7 +56,7 @@ func Fetch(from, i any) any {
5256
}
5357
value := v.Index(index)
5458
if value.IsValid() {
55-
return value.Interface()
59+
return value.Interface(), true
5660
}
5761

5862
case reflect.Map:
@@ -63,10 +67,10 @@ func Fetch(from, i any) any {
6367
value = v.MapIndex(reflect.ValueOf(i))
6468
}
6569
if value.IsValid() {
66-
return value.Interface()
70+
return value.Interface(), true
6771
} else {
6872
elem := reflect.TypeOf(from).Elem()
69-
return reflect.Zero(elem).Interface()
73+
return reflect.Zero(elem).Interface(), true
7074
}
7175

7276
case reflect.Struct:
@@ -77,7 +81,7 @@ func Fetch(from, i any) any {
7781
f: fieldName,
7882
}
7983
if cv, ok := fieldCache.Load(key); ok {
80-
return v.FieldByIndex(cv.([]int)).Interface()
84+
return v.FieldByIndex(cv.([]int)).Interface(), true
8185
}
8286
field, ok := t.FieldByNameFunc(func(name string) bool {
8387
field, _ := t.FieldByName(name)
@@ -94,11 +98,12 @@ func Fetch(from, i any) any {
9498
value := v.FieldByIndex(field.Index)
9599
if value.IsValid() {
96100
fieldCache.Store(key, field.Index)
97-
return value.Interface()
101+
return value.Interface(), true
98102
}
99103
}
100104
}
101-
panic(fmt.Sprintf("cannot fetch %v from %T", i, from))
105+
106+
return nil, false
102107
}
103108

104109
type Field struct {

vm/vm.go

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
122122
vm.push(vm.Variables[arg])
123123

124124
case OpLoadConst:
125-
vm.push(runtime.Fetch(env, program.Constants[arg]))
125+
value, ok := runtime.Fetch(env, program.Constants[arg])
126+
if !ok && !program.nilSafe {
127+
panic(fmt.Sprintf("cannot fetch %v in the environment", program.Constants[arg]))
128+
}
129+
130+
vm.push(value)
126131

127132
case OpLoadField:
128133
vm.push(runtime.FetchField(env, program.Constants[arg].(*runtime.Field)))
@@ -139,7 +144,12 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
139144
case OpFetch:
140145
b := vm.pop()
141146
a := vm.pop()
142-
vm.push(runtime.Fetch(a, b))
147+
148+
value, ok := runtime.Fetch(a, b)
149+
if !ok && !program.nilSafe {
150+
panic(fmt.Sprintf("cannot fetch %v from %T", b, a))
151+
}
152+
vm.push(value)
143153

144154
case OpFetchField:
145155
a := vm.pop()
@@ -609,6 +619,9 @@ func (vm *VM) Run(program *Program, env any) (_ any, err error) {
609619
a := vm.pop()
610620
s := vm.allocScope()
611621
switch v := a.(type) {
622+
case nil:
623+
s.Len = 0
624+
s.Anys = nil
612625
case []int:
613626
s.Ints = v
614627
s.Len = len(v)

vm/vm_test.go

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -694,8 +694,9 @@ func TestVM_DirectCallOpcodes(t *testing.T) {
694694
tt.bytecode,
695695
tt.args,
696696
tt.funcs,
697-
nil, // debugInfo
698-
nil, // span
697+
nil, // debugInfo
698+
nil, // span
699+
false, // nilSafe
699700
)
700701
vm := &vm.VM{}
701702
got, err := vm.Run(program, nil)
@@ -819,9 +820,10 @@ func TestVM_IndexAndCountOperations(t *testing.T) {
819820
tt.consts,
820821
tt.bytecode,
821822
tt.args,
822-
nil, // functions
823-
nil, // debugInfo
824-
nil, // span
823+
nil, // functions
824+
nil, // debugInfo
825+
nil, // span
826+
false, // nilSafe
825827
)
826828
vm := &vm.VM{}
827829
got, err := vm.Run(program, nil)
@@ -1288,9 +1290,10 @@ func TestVM_DirectBasicOpcodes(t *testing.T) {
12881290
tt.consts,
12891291
tt.bytecode,
12901292
tt.args,
1291-
nil, // functions
1292-
nil, // debugInfo
1293-
nil, // span
1293+
nil, // functions
1294+
nil, // debugInfo
1295+
nil, // span
1296+
false, // nilSafe
12941297
)
12951298
vm := &vm.VM{}
12961299
got, err := vm.Run(program, tt.env)
@@ -1460,6 +1463,7 @@ func TestVM_OpJump_NegativeOffset(t *testing.T) {
14601463
nil,
14611464
nil,
14621465
nil,
1466+
false, // nilSafe
14631467
)
14641468

14651469
_, err := vm.Run(program, nil)

0 commit comments

Comments
 (0)