Skip to content

Commit ae2c8ce

Browse files
jakebaileyaykevl
andcommitted
reflect: implement method-set based AssignableTo and Implements
Based on the design from tinygo-org#4376 by aykevl. Fixes tinygo-org#4277, fixes tinygo-org#3580. Co-authored-by: Ayke van Laethem <aykevanlaethem@gmail.com>
1 parent c58d9c9 commit ae2c8ce

16 files changed

Lines changed: 484 additions & 451 deletions

File tree

compiler/interface.go

Lines changed: 98 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"go/token"
1212
"go/types"
13+
"sort"
1314
"strconv"
1415
"strings"
1516

@@ -183,6 +184,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
183184
typeFieldTypes := []*types.Var{
184185
types.NewVar(token.NoPos, nil, "kind", types.Typ[types.Int8]),
185186
}
187+
// Compute the method set value for types that support methods.
188+
var methods []*types.Func
189+
for i := 0; i < ms.Len(); i++ {
190+
methods = append(methods, ms.At(i).Obj().(*types.Func))
191+
}
192+
methodSetType := types.NewStruct([]*types.Var{
193+
types.NewVar(token.NoPos, nil, "length", types.Typ[types.Uintptr]),
194+
types.NewVar(token.NoPos, nil, "methods", types.NewArray(types.Typ[types.UnsafePointer], int64(len(methods)))),
195+
}, nil)
196+
methodSetValue := c.getMethodSetValue(methods)
186197
switch typ := typ.(type) {
187198
case *types.Basic:
188199
typeFieldTypes = append(typeFieldTypes,
@@ -199,6 +210,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
199210
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
200211
types.NewVar(token.NoPos, nil, "underlying", types.Typ[types.UnsafePointer]),
201212
types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]),
213+
types.NewVar(token.NoPos, nil, "methods", methodSetType),
202214
types.NewVar(token.NoPos, nil, "name", types.NewArray(types.Typ[types.Int8], int64(len(pkgname)+1+len(name)+1))),
203215
)
204216
case *types.Chan:
@@ -217,6 +229,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
217229
typeFieldTypes = append(typeFieldTypes,
218230
types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]),
219231
types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]),
232+
types.NewVar(token.NoPos, nil, "methods", methodSetType),
220233
)
221234
case *types.Array:
222235
typeFieldTypes = append(typeFieldTypes,
@@ -241,12 +254,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
241254
types.NewVar(token.NoPos, nil, "size", types.Typ[types.Uint32]),
242255
types.NewVar(token.NoPos, nil, "numFields", types.Typ[types.Uint16]),
243256
types.NewVar(token.NoPos, nil, "fields", types.NewArray(c.getRuntimeType("structField"), int64(typ.NumFields()))),
257+
types.NewVar(token.NoPos, nil, "methods", methodSetType),
244258
)
245259
case *types.Interface:
246260
typeFieldTypes = append(typeFieldTypes,
247261
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
262+
types.NewVar(token.NoPos, nil, "methods", methodSetType),
248263
)
249-
// TODO: methods
250264
case *types.Signature:
251265
typeFieldTypes = append(typeFieldTypes,
252266
types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]),
@@ -297,6 +311,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
297311
c.getTypeCode(types.NewPointer(typ)), // ptrTo
298312
c.getTypeCode(typ.Underlying()), // underlying
299313
pkgPathPtr, // pkgpath pointer
314+
methodSetValue, // methods
300315
c.ctx.ConstString(pkgname+"."+name+"\x00", false), // name
301316
}
302317
metabyte |= 1 << 5 // "named" flag
@@ -326,6 +341,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
326341
typeFields = []llvm.Value{
327342
llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods
328343
c.getTypeCode(typ.Elem()),
344+
methodSetValue, // methods
329345
}
330346
case *types.Array:
331347
typeFields = []llvm.Value{
@@ -407,9 +423,12 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
407423
}))
408424
}
409425
typeFields = append(typeFields, llvm.ConstArray(structFieldType, fields))
426+
typeFields = append(typeFields, methodSetValue)
410427
case *types.Interface:
411-
typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))}
412-
// TODO: methods
428+
typeFields = []llvm.Value{
429+
c.getTypeCode(types.NewPointer(typ)),
430+
methodSetValue,
431+
}
413432
case *types.Signature:
414433
typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))}
415434
// TODO: params, return values, etc
@@ -696,17 +715,16 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
696715
// This type assertion always succeeds, so we can just set commaOk to true.
697716
commaOk = llvm.ConstInt(b.ctx.Int1Type(), 1, true)
698717
} else {
699-
// Type assert on interface type with methods.
700-
// This is a call to an interface type assert function.
701-
// The interface lowering pass will define this function by filling it
702-
// with a type switch over all concrete types that implement this
703-
// interface, and returning whether it's one of the matched types.
704-
// This is very different from how interface asserts are implemented in
705-
// the main Go compiler, where the runtime checks whether the type
706-
// implements each method of the interface. See:
718+
// Type assert using interface type with methods.
719+
// This is implemented using a runtime call, which checks that the
720+
// type implements each method of the interface.
721+
// For comparison, here is how the Go compiler does this (which is
722+
// very similar):
707723
// https://research.swtch.com/interfaces
708-
fn := b.getInterfaceImplementsFunc(expr.AssertedType)
709-
commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "")
724+
commaOk = b.createRuntimeCall("typeImplementsMethodSet", []llvm.Value{
725+
actualTypeNum,
726+
b.getInterfaceMethodSet(intf),
727+
}, "")
710728
}
711729
} else {
712730
name, _ := getTypeCodeName(expr.AssertedType)
@@ -783,20 +801,74 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
783801
return strings.Join(methods, "; ")
784802
}
785803

786-
// getInterfaceImplementsFunc returns a declared function that works as a type
787-
// switch. The interface lowering pass will define this function.
788-
func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value {
789-
s, _ := getTypeCodeName(assertedType.Underlying())
790-
fnName := s + ".$typeassert"
791-
llvmFn := c.mod.NamedFunction(fnName)
792-
if llvmFn.IsNil() {
793-
llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.dataPtrType}, false)
794-
llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType)
795-
c.addStandardDeclaredAttributes(llvmFn)
796-
methods := c.getMethodsString(assertedType.Underlying().(*types.Interface))
797-
llvmFn.AddFunctionAttr(c.ctx.CreateStringAttribute("tinygo-methods", methods))
804+
// getInterfaceMethodSet returns a global that contains the method set for an
805+
// interface type, creating it if needed.
806+
func (c *compilerContext) getInterfaceMethodSet(t *types.Interface) llvm.Value {
807+
s, _ := getTypeCodeName(t)
808+
methodSetName := s + "$itfmethods"
809+
methodSet := c.mod.NamedGlobal(methodSetName)
810+
if !methodSet.IsNil() {
811+
return methodSet
798812
}
799-
return llvmFn
813+
814+
var methods []*types.Func
815+
for i := 0; i < t.NumMethods(); i++ {
816+
methods = append(methods, t.Method(i))
817+
}
818+
if len(methods) == 0 {
819+
panic("unreachable: getInterfaceMethodSet called on empty interface")
820+
}
821+
822+
methodSetValue := c.getMethodSetValue(methods)
823+
methodSet = llvm.AddGlobal(c.mod, methodSetValue.Type(), methodSetName)
824+
methodSet.SetInitializer(methodSetValue)
825+
methodSet.SetGlobalConstant(true)
826+
methodSet.SetLinkage(llvm.LinkOnceODRLinkage)
827+
methodSet.SetAlignment(c.targetData.ABITypeAlignment(methodSetValue.Type()))
828+
methodSet.SetUnnamedAddr(true)
829+
830+
return methodSet
831+
}
832+
833+
// getMethodSetValue creates the method set struct value for a list of methods.
834+
// The struct contains a length and a sorted array of method signature pointers.
835+
func (c *compilerContext) getMethodSetValue(methods []*types.Func) llvm.Value {
836+
// Create a sorted list of method signature global names.
837+
type methodRef struct {
838+
name string
839+
value llvm.Value
840+
}
841+
var refs []methodRef
842+
for _, method := range methods {
843+
name := method.Name()
844+
if !token.IsExported(name) {
845+
name = method.Pkg().Path() + "." + name
846+
}
847+
s, _ := getTypeCodeName(method.Type())
848+
globalName := "reflect/types.signature:" + name + ":" + s
849+
value := c.mod.NamedGlobal(globalName)
850+
if value.IsNil() {
851+
value = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), globalName)
852+
value.SetInitializer(llvm.ConstNull(c.ctx.Int8Type()))
853+
value.SetGlobalConstant(true)
854+
value.SetLinkage(llvm.LinkOnceODRLinkage)
855+
value.SetAlignment(1)
856+
}
857+
refs = append(refs, methodRef{globalName, value})
858+
}
859+
sort.Slice(refs, func(i, j int) bool {
860+
return refs[i].name < refs[j].name
861+
})
862+
863+
var values []llvm.Value
864+
for _, ref := range refs {
865+
values = append(values, ref.value)
866+
}
867+
868+
return c.ctx.ConstStruct([]llvm.Value{
869+
llvm.ConstInt(c.uintptrType, uint64(len(values)), false),
870+
llvm.ConstArray(c.dataPtrType, values),
871+
}, false)
800872
}
801873

802874
// getInvokeFunction returns the thunk to call the given interface method. The

compiler/testdata/gc.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ target triple = "wasm32-unknown-wasi"
2323
@"runtime/gc.layout:62-0100000000000020" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00 " }
2424
@"runtime/gc.layout:62-0100000000000000" = linkonce_odr unnamed_addr constant { i32, [8 x i8] } { i32 62, [8 x i8] c"\01\00\00\00\00\00\00\00" }
2525
@"reflect/types.type:basic:complex128" = linkonce_odr constant { i8, ptr } { i8 80, ptr @"reflect/types.type:pointer:basic:complex128" }, align 4
26-
@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant { i8, i16, ptr } { i8 -43, i16 0, ptr @"reflect/types.type:basic:complex128" }, align 4
26+
@"reflect/types.type:pointer:basic:complex128" = linkonce_odr constant { i8, i16, ptr, { i32, [0 x ptr] } } { i8 -43, i16 0, ptr @"reflect/types.type:basic:complex128", { i32, [0 x ptr] } zeroinitializer }, align 4
2727

2828
; Function Attrs: allockind("alloc,zeroed") allocsize(0)
2929
declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0

0 commit comments

Comments
 (0)