@@ -10,6 +10,7 @@ import (
1010 "fmt"
1111 "go/token"
1212 "go/types"
13+ "sort"
1314 "strconv"
1415 "strings"
1516
@@ -183,6 +184,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
183184 typeFieldTypes := []* types.Var {
184185 types .NewVar (token .NoPos , nil , "kind" , types .Typ [types .Int8 ]),
185186 }
187+ // Compute the method set value for types that support methods.
188+ var methods []* types.Func
189+ for i := 0 ; i < ms .Len (); i ++ {
190+ methods = append (methods , ms .At (i ).Obj ().(* types.Func ))
191+ }
192+ methodSetType := types .NewStruct ([]* types.Var {
193+ types .NewVar (token .NoPos , nil , "length" , types .Typ [types .Uintptr ]),
194+ types .NewVar (token .NoPos , nil , "methods" , types .NewArray (types .Typ [types .UnsafePointer ], int64 (len (methods )))),
195+ }, nil )
196+ methodSetValue := c .getMethodSetValue (methods )
186197 switch typ := typ .(type ) {
187198 case * types.Basic :
188199 typeFieldTypes = append (typeFieldTypes ,
@@ -199,6 +210,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
199210 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
200211 types .NewVar (token .NoPos , nil , "underlying" , types .Typ [types .UnsafePointer ]),
201212 types .NewVar (token .NoPos , nil , "pkgpath" , types .Typ [types .UnsafePointer ]),
213+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
202214 types .NewVar (token .NoPos , nil , "name" , types .NewArray (types .Typ [types .Int8 ], int64 (len (pkgname )+ 1 + len (name )+ 1 ))),
203215 )
204216 case * types.Chan :
@@ -217,6 +229,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
217229 typeFieldTypes = append (typeFieldTypes ,
218230 types .NewVar (token .NoPos , nil , "numMethods" , types .Typ [types .Uint16 ]),
219231 types .NewVar (token .NoPos , nil , "elementType" , types .Typ [types .UnsafePointer ]),
232+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
220233 )
221234 case * types.Array :
222235 typeFieldTypes = append (typeFieldTypes ,
@@ -241,12 +254,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
241254 types .NewVar (token .NoPos , nil , "size" , types .Typ [types .Uint32 ]),
242255 types .NewVar (token .NoPos , nil , "numFields" , types .Typ [types .Uint16 ]),
243256 types .NewVar (token .NoPos , nil , "fields" , types .NewArray (c .getRuntimeType ("structField" ), int64 (typ .NumFields ()))),
257+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
244258 )
245259 case * types.Interface :
246260 typeFieldTypes = append (typeFieldTypes ,
247261 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
262+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
248263 )
249- // TODO: methods
250264 case * types.Signature :
251265 typeFieldTypes = append (typeFieldTypes ,
252266 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
@@ -297,6 +311,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
297311 c .getTypeCode (types .NewPointer (typ )), // ptrTo
298312 c .getTypeCode (typ .Underlying ()), // underlying
299313 pkgPathPtr , // pkgpath pointer
314+ methodSetValue , // methods
300315 c .ctx .ConstString (pkgname + "." + name + "\x00 " , false ), // name
301316 }
302317 metabyte |= 1 << 5 // "named" flag
@@ -326,6 +341,7 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
326341 typeFields = []llvm.Value {
327342 llvm .ConstInt (c .ctx .Int16Type (), uint64 (numMethods ), false ), // numMethods
328343 c .getTypeCode (typ .Elem ()),
344+ methodSetValue , // methods
329345 }
330346 case * types.Array :
331347 typeFields = []llvm.Value {
@@ -407,9 +423,12 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
407423 }))
408424 }
409425 typeFields = append (typeFields , llvm .ConstArray (structFieldType , fields ))
426+ typeFields = append (typeFields , methodSetValue )
410427 case * types.Interface :
411- typeFields = []llvm.Value {c .getTypeCode (types .NewPointer (typ ))}
412- // TODO: methods
428+ typeFields = []llvm.Value {
429+ c .getTypeCode (types .NewPointer (typ )),
430+ methodSetValue ,
431+ }
413432 case * types.Signature :
414433 typeFields = []llvm.Value {c .getTypeCode (types .NewPointer (typ ))}
415434 // TODO: params, return values, etc
@@ -696,17 +715,16 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
696715 // This type assertion always succeeds, so we can just set commaOk to true.
697716 commaOk = llvm .ConstInt (b .ctx .Int1Type (), 1 , true )
698717 } else {
699- // Type assert on interface type with methods.
700- // This is a call to an interface type assert function.
701- // The interface lowering pass will define this function by filling it
702- // with a type switch over all concrete types that implement this
703- // interface, and returning whether it's one of the matched types.
704- // This is very different from how interface asserts are implemented in
705- // the main Go compiler, where the runtime checks whether the type
706- // implements each method of the interface. See:
718+ // Type assert using interface type with methods.
719+ // This is implemented using a runtime call, which checks that the
720+ // type implements each method of the interface.
721+ // For comparison, here is how the Go compiler does this (which is
722+ // very similar):
707723 // https://research.swtch.com/interfaces
708- fn := b .getInterfaceImplementsFunc (expr .AssertedType )
709- commaOk = b .CreateCall (fn .GlobalValueType (), fn , []llvm.Value {actualTypeNum }, "" )
724+ commaOk = b .createRuntimeCall ("typeImplementsMethodSet" , []llvm.Value {
725+ actualTypeNum ,
726+ b .getInterfaceMethodSet (intf ),
727+ }, "" )
710728 }
711729 } else {
712730 name , _ := getTypeCodeName (expr .AssertedType )
@@ -783,20 +801,74 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
783801 return strings .Join (methods , "; " )
784802}
785803
786- // getInterfaceImplementsFunc returns a declared function that works as a type
787- // switch. The interface lowering pass will define this function.
788- func (c * compilerContext ) getInterfaceImplementsFunc (assertedType types.Type ) llvm.Value {
789- s , _ := getTypeCodeName (assertedType .Underlying ())
790- fnName := s + ".$typeassert"
791- llvmFn := c .mod .NamedFunction (fnName )
792- if llvmFn .IsNil () {
793- llvmFnType := llvm .FunctionType (c .ctx .Int1Type (), []llvm.Type {c .dataPtrType }, false )
794- llvmFn = llvm .AddFunction (c .mod , fnName , llvmFnType )
795- c .addStandardDeclaredAttributes (llvmFn )
796- methods := c .getMethodsString (assertedType .Underlying ().(* types.Interface ))
797- llvmFn .AddFunctionAttr (c .ctx .CreateStringAttribute ("tinygo-methods" , methods ))
804+ // getInterfaceMethodSet returns a global that contains the method set for an
805+ // interface type, creating it if needed.
806+ func (c * compilerContext ) getInterfaceMethodSet (t * types.Interface ) llvm.Value {
807+ s , _ := getTypeCodeName (t )
808+ methodSetName := s + "$itfmethods"
809+ methodSet := c .mod .NamedGlobal (methodSetName )
810+ if ! methodSet .IsNil () {
811+ return methodSet
798812 }
799- return llvmFn
813+
814+ var methods []* types.Func
815+ for i := 0 ; i < t .NumMethods (); i ++ {
816+ methods = append (methods , t .Method (i ))
817+ }
818+ if len (methods ) == 0 {
819+ panic ("unreachable: getInterfaceMethodSet called on empty interface" )
820+ }
821+
822+ methodSetValue := c .getMethodSetValue (methods )
823+ methodSet = llvm .AddGlobal (c .mod , methodSetValue .Type (), methodSetName )
824+ methodSet .SetInitializer (methodSetValue )
825+ methodSet .SetGlobalConstant (true )
826+ methodSet .SetLinkage (llvm .LinkOnceODRLinkage )
827+ methodSet .SetAlignment (c .targetData .ABITypeAlignment (methodSetValue .Type ()))
828+ methodSet .SetUnnamedAddr (true )
829+
830+ return methodSet
831+ }
832+
833+ // getMethodSetValue creates the method set struct value for a list of methods.
834+ // The struct contains a length and a sorted array of method signature pointers.
835+ func (c * compilerContext ) getMethodSetValue (methods []* types.Func ) llvm.Value {
836+ // Create a sorted list of method signature global names.
837+ type methodRef struct {
838+ name string
839+ value llvm.Value
840+ }
841+ var refs []methodRef
842+ for _ , method := range methods {
843+ name := method .Name ()
844+ if ! token .IsExported (name ) {
845+ name = method .Pkg ().Path () + "." + name
846+ }
847+ s , _ := getTypeCodeName (method .Type ())
848+ globalName := "reflect/types.signature:" + name + ":" + s
849+ value := c .mod .NamedGlobal (globalName )
850+ if value .IsNil () {
851+ value = llvm .AddGlobal (c .mod , c .ctx .Int8Type (), globalName )
852+ value .SetInitializer (llvm .ConstNull (c .ctx .Int8Type ()))
853+ value .SetGlobalConstant (true )
854+ value .SetLinkage (llvm .LinkOnceODRLinkage )
855+ value .SetAlignment (1 )
856+ }
857+ refs = append (refs , methodRef {globalName , value })
858+ }
859+ sort .Slice (refs , func (i , j int ) bool {
860+ return refs [i ].name < refs [j ].name
861+ })
862+
863+ var values []llvm.Value
864+ for _ , ref := range refs {
865+ values = append (values , ref .value )
866+ }
867+
868+ return c .ctx .ConstStruct ([]llvm.Value {
869+ llvm .ConstInt (c .uintptrType , uint64 (len (values )), false ),
870+ llvm .ConstArray (c .dataPtrType , values ),
871+ }, false )
800872}
801873
802874// getInvokeFunction returns the thunk to call the given interface method. The
0 commit comments