@@ -10,13 +10,20 @@ import (
1010 "fmt"
1111 "go/token"
1212 "go/types"
13+ "sort"
1314 "strconv"
1415 "strings"
1516
1617 "golang.org/x/tools/go/ssa"
1718 "tinygo.org/x/go-llvm"
1819)
1920
21+ // numMethodHasMethodSet is a flag in bit 15 of the numMethod field (uint16) in
22+ // Named, Pointer, and Struct type descriptors. When set, an inline method set
23+ // is present in the type descriptor. Must match the constant in
24+ // src/internal/reflectlite/type.go.
25+ const numMethodHasMethodSet = 0x8000
26+
2027// Type kinds for basic types.
2128// They must match the constants for the Kind type in src/reflect/type.go.
2229var basicTypes = [... ]uint8 {
@@ -183,6 +190,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
183190 typeFieldTypes := []* types.Var {
184191 types .NewVar (token .NoPos , nil , "kind" , types .Typ [types .Int8 ]),
185192 }
193+ // Compute the method set value for types that support methods.
194+ var methods []* types.Func
195+ for i := 0 ; i < ms .Len (); i ++ {
196+ methods = append (methods , ms .At (i ).Obj ().(* types.Func ))
197+ }
198+ methodSetType := types .NewStruct ([]* types.Var {
199+ types .NewVar (token .NoPos , nil , "length" , types .Typ [types .Uintptr ]),
200+ types .NewVar (token .NoPos , nil , "methods" , types .NewArray (types .Typ [types .UnsafePointer ], int64 (len (methods )))),
201+ }, nil )
202+ methodSetValue := c .getMethodSetValue (methods )
186203 switch typ := typ .(type ) {
187204 case * types.Basic :
188205 typeFieldTypes = append (typeFieldTypes ,
@@ -199,6 +216,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
199216 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
200217 types .NewVar (token .NoPos , nil , "underlying" , types .Typ [types .UnsafePointer ]),
201218 types .NewVar (token .NoPos , nil , "pkgpath" , types .Typ [types .UnsafePointer ]),
219+ )
220+ if len (methods ) > 0 {
221+ typeFieldTypes = append (typeFieldTypes ,
222+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
223+ )
224+ }
225+ typeFieldTypes = append (typeFieldTypes ,
202226 types .NewVar (token .NoPos , nil , "name" , types .NewArray (types .Typ [types .Int8 ], int64 (len (pkgname )+ 1 + len (name )+ 1 ))),
203227 )
204228 case * types.Chan :
@@ -218,6 +242,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
218242 types .NewVar (token .NoPos , nil , "numMethods" , types .Typ [types .Uint16 ]),
219243 types .NewVar (token .NoPos , nil , "elementType" , types .Typ [types .UnsafePointer ]),
220244 )
245+ if len (methods ) > 0 {
246+ typeFieldTypes = append (typeFieldTypes ,
247+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
248+ )
249+ }
221250 case * types.Array :
222251 typeFieldTypes = append (typeFieldTypes ,
223252 types .NewVar (token .NoPos , nil , "numMethods" , types .Typ [types .Uint16 ]),
@@ -242,11 +271,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
242271 types .NewVar (token .NoPos , nil , "numFields" , types .Typ [types .Uint16 ]),
243272 types .NewVar (token .NoPos , nil , "fields" , types .NewArray (c .getRuntimeType ("structField" ), int64 (typ .NumFields ()))),
244273 )
274+ if len (methods ) > 0 {
275+ typeFieldTypes = append (typeFieldTypes ,
276+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
277+ )
278+ }
245279 case * types.Interface :
246280 typeFieldTypes = append (typeFieldTypes ,
247281 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
282+ types .NewVar (token .NoPos , nil , "methods" , methodSetType ),
248283 )
249- // TODO: methods
250284 case * types.Signature :
251285 typeFieldTypes = append (typeFieldTypes ,
252286 types .NewVar (token .NoPos , nil , "ptrTo" , types .Typ [types .UnsafePointer ]),
@@ -292,14 +326,24 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
292326 pkgname = pkg .Name ()
293327 }
294328 pkgPathPtr := c .pkgPathPtr (pkgpath )
329+ namedNumMethods := uint64 (numMethods )
330+ if namedNumMethods & numMethodHasMethodSet != 0 {
331+ panic ("numMethods overflow: too many exported methods on named type " + name )
332+ }
333+ if len (methods ) > 0 {
334+ namedNumMethods |= numMethodHasMethodSet
335+ }
295336 typeFields = []llvm.Value {
296- llvm .ConstInt (c .ctx .Int16Type (), uint64 (numMethods ), false ), // numMethods
297- c .getTypeCode (types .NewPointer (typ )), // ptrTo
298- c .getTypeCode (typ .Underlying ()), // underlying
299- pkgPathPtr , // pkgpath pointer
300- c .ctx .ConstString (pkgname + "." + name + "\x00 " , false ), // name
337+ llvm .ConstInt (c .ctx .Int16Type (), namedNumMethods , false ), // numMethods
338+ c .getTypeCode (types .NewPointer (typ )), // ptrTo
339+ c .getTypeCode (typ .Underlying ()), // underlying
340+ pkgPathPtr , // pkgpath pointer
301341 }
302- metabyte |= 1 << 5 // "named" flag
342+ if len (methods ) > 0 {
343+ typeFields = append (typeFields , methodSetValue ) // methods
344+ }
345+ typeFields = append (typeFields , c .ctx .ConstString (pkgname + "." + name + "\x00 " , false )) // name
346+ metabyte |= 1 << 5 // "named" flag
303347 case * types.Chan :
304348 var dir reflectChanDir
305349 switch typ .Dir () {
@@ -323,10 +367,20 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
323367 c .getTypeCode (typ .Elem ()), // elementType
324368 }
325369 case * types.Pointer :
370+ ptrNumMethods := uint64 (numMethods )
371+ if ptrNumMethods & numMethodHasMethodSet != 0 {
372+ panic ("numMethods overflow: too many exported methods on pointer type" )
373+ }
374+ if len (methods ) > 0 {
375+ ptrNumMethods |= numMethodHasMethodSet
376+ }
326377 typeFields = []llvm.Value {
327- llvm .ConstInt (c .ctx .Int16Type (), uint64 ( numMethods ) , false ), // numMethods
378+ llvm .ConstInt (c .ctx .Int16Type (), ptrNumMethods , false ), // numMethods
328379 c .getTypeCode (typ .Elem ()),
329380 }
381+ if len (methods ) > 0 {
382+ typeFields = append (typeFields , methodSetValue )
383+ }
330384 case * types.Array :
331385 typeFields = []llvm.Value {
332386 llvm .ConstInt (c .ctx .Int16Type (), 0 , false ), // numMethods
@@ -353,9 +407,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
353407
354408 llvmStructType := c .getLLVMType (typ )
355409 size := c .targetData .TypeStoreSize (llvmStructType )
410+ structNumMethods := uint64 (numMethods )
411+ if structNumMethods & numMethodHasMethodSet != 0 {
412+ panic ("numMethods overflow: too many exported methods on struct type" )
413+ }
414+ if len (methods ) > 0 {
415+ structNumMethods |= numMethodHasMethodSet
416+ }
356417 typeFields = []llvm.Value {
357- llvm .ConstInt (c .ctx .Int16Type (), uint64 ( numMethods ) , false ), // numMethods
358- c .getTypeCode (types .NewPointer (typ )), // ptrTo
418+ llvm .ConstInt (c .ctx .Int16Type (), structNumMethods , false ), // numMethods
419+ c .getTypeCode (types .NewPointer (typ )), // ptrTo
359420 pkgPathPtr ,
360421 llvm .ConstInt (c .ctx .Int32Type (), uint64 (size ), false ), // size
361422 llvm .ConstInt (c .ctx .Int16Type (), uint64 (typ .NumFields ()), false ), // numFields
@@ -407,9 +468,14 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value {
407468 }))
408469 }
409470 typeFields = append (typeFields , llvm .ConstArray (structFieldType , fields ))
471+ if len (methods ) > 0 {
472+ typeFields = append (typeFields , methodSetValue )
473+ }
410474 case * types.Interface :
411- typeFields = []llvm.Value {c .getTypeCode (types .NewPointer (typ ))}
412- // TODO: methods
475+ typeFields = []llvm.Value {
476+ c .getTypeCode (types .NewPointer (typ )),
477+ methodSetValue ,
478+ }
413479 case * types.Signature :
414480 typeFields = []llvm.Value {c .getTypeCode (types .NewPointer (typ ))}
415481 // TODO: params, return values, etc
@@ -696,17 +762,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value {
696762 // This type assertion always succeeds, so we can just set commaOk to true.
697763 commaOk = llvm .ConstInt (b .ctx .Int1Type (), 1 , true )
698764 } else {
699- // Type assert on interface type with methods.
700- // This is a call to an interface type assert function.
701- // The interface lowering pass will define this function by filling it
702- // with a type switch over all concrete types that implement this
703- // interface, and returning whether it's one of the matched types.
704- // This is very different from how interface asserts are implemented in
705- // the main Go compiler, where the runtime checks whether the type
706- // implements each method of the interface. See:
707- // https://research.swtch.com/interfaces
708- fn := b .getInterfaceImplementsFunc (expr .AssertedType )
709- commaOk = b .CreateCall (fn .GlobalValueType (), fn , []llvm.Value {actualTypeNum }, "" )
765+ // Type assert on an interface type with methods.
766+ // Create a call to a declared-but-not-defined function that will
767+ // be lowered by the interface lowering pass into a type-ID
768+ // comparison chain.
769+ commaOk = b .createInterfaceTypeAssert (intf , actualTypeNum )
710770 }
711771 } else {
712772 name , _ := getTypeCodeName (expr .AssertedType )
@@ -783,20 +843,58 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string {
783843 return strings .Join (methods , "; " )
784844}
785845
786- // getInterfaceImplementsFunc returns a declared function that works as a type
787- // switch. The interface lowering pass will define this function.
788- func (c * compilerContext ) getInterfaceImplementsFunc (assertedType types.Type ) llvm.Value {
789- s , _ := getTypeCodeName (assertedType .Underlying ())
790- fnName := s + ".$typeassert"
791- llvmFn := c .mod .NamedFunction (fnName )
792- if llvmFn .IsNil () {
793- llvmFnType := llvm .FunctionType (c .ctx .Int1Type (), []llvm.Type {c .dataPtrType }, false )
794- llvmFn = llvm .AddFunction (c .mod , fnName , llvmFnType )
795- c .addStandardDeclaredAttributes (llvmFn )
796- methods := c .getMethodsString (assertedType .Underlying ().(* types.Interface ))
797- llvmFn .AddFunctionAttr (c .ctx .CreateStringAttribute ("tinygo-methods" , methods ))
846+ // getMethodSetValue creates the method set struct value for a list of methods.
847+ // The struct contains a length and a sorted array of method signature pointers.
848+ func (c * compilerContext ) getMethodSetValue (methods []* types.Func ) llvm.Value {
849+ // Create a sorted list of method signature global names.
850+ type methodRef struct {
851+ name string
852+ value llvm.Value
798853 }
799- return llvmFn
854+ var refs []methodRef
855+ for _ , method := range methods {
856+ name := method .Name ()
857+ if ! token .IsExported (name ) {
858+ name = method .Pkg ().Path () + "." + name
859+ }
860+ s , _ := getTypeCodeName (method .Type ())
861+ globalName := "reflect/types.signature:" + name + ":" + s
862+ value := c .mod .NamedGlobal (globalName )
863+ if value .IsNil () {
864+ value = llvm .AddGlobal (c .mod , c .ctx .Int8Type (), globalName )
865+ value .SetInitializer (llvm .ConstNull (c .ctx .Int8Type ()))
866+ value .SetGlobalConstant (true )
867+ value .SetLinkage (llvm .LinkOnceODRLinkage )
868+ value .SetAlignment (1 )
869+ if c .Debug {
870+ file := c .getDIFile ("<Go type>" )
871+ diglobal := c .dibuilder .CreateGlobalVariableExpression (file , llvm.DIGlobalVariableExpression {
872+ Name : globalName ,
873+ File : file ,
874+ Line : 1 ,
875+ Type : c .getDIType (types .Typ [types .Uint8 ]),
876+ LocalToUnit : false ,
877+ Expr : c .dibuilder .CreateExpression (nil ),
878+ AlignInBits : 8 ,
879+ })
880+ value .AddMetadata (0 , diglobal )
881+ }
882+ }
883+ refs = append (refs , methodRef {globalName , value })
884+ }
885+ sort .Slice (refs , func (i , j int ) bool {
886+ return refs [i ].name < refs [j ].name
887+ })
888+
889+ var values []llvm.Value
890+ for _ , ref := range refs {
891+ values = append (values , ref .value )
892+ }
893+
894+ return c .ctx .ConstStruct ([]llvm.Value {
895+ llvm .ConstInt (c .uintptrType , uint64 (len (values )), false ),
896+ llvm .ConstArray (c .dataPtrType , values ),
897+ }, false )
800898}
801899
802900// getInvokeFunction returns the thunk to call the given interface method. The
@@ -823,6 +921,24 @@ func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value {
823921 return llvmFn
824922}
825923
924+ // createInterfaceTypeAssert creates a call to a declared-but-not-defined
925+ // $typeassert function for the given interface. This function will be defined
926+ // by the interface lowering pass as a type-ID comparison chain, avoiding the
927+ // need for runtime.typeImplementsMethodSet at compile time.
928+ func (b * builder ) createInterfaceTypeAssert (intf * types.Interface , actualType llvm.Value ) llvm.Value {
929+ s , _ := getTypeCodeName (intf )
930+ fnName := s + ".$typeassert"
931+ llvmFn := b .mod .NamedFunction (fnName )
932+ if llvmFn .IsNil () {
933+ llvmFnType := llvm .FunctionType (b .ctx .Int1Type (), []llvm.Type {b .dataPtrType }, false )
934+ llvmFn = llvm .AddFunction (b .mod , fnName , llvmFnType )
935+ b .addStandardDeclaredAttributes (llvmFn )
936+ methods := b .getMethodsString (intf )
937+ llvmFn .AddFunctionAttr (b .ctx .CreateStringAttribute ("tinygo-methods" , methods ))
938+ }
939+ return b .CreateCall (llvmFn .GlobalValueType (), llvmFn , []llvm.Value {actualType }, "" )
940+ }
941+
826942// getInterfaceInvokeWrapper returns a wrapper for the given method so it can be
827943// invoked from an interface. The wrapper takes in a pointer to the underlying
828944// value, dereferences or unpacks it if necessary, and calls the real method.
0 commit comments