-
-
Notifications
You must be signed in to change notification settings - Fork 1k
compiler: implement method-set based AssignableTo and Implements #5304
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9dd4964
a9d5c29
1cb67a2
1b13c9f
9fb3536
d73f9b8
2767f51
a4eb3f8
bcd81c0
39f7e8e
40ac70e
60e1f18
06c884c
84579fc
9929018
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,13 +10,20 @@ import ( | |
| "fmt" | ||
| "go/token" | ||
| "go/types" | ||
| "sort" | ||
| "strconv" | ||
| "strings" | ||
|
|
||
| "golang.org/x/tools/go/ssa" | ||
| "tinygo.org/x/go-llvm" | ||
| ) | ||
|
|
||
| // numMethodHasMethodSet is a flag in bit 15 of the numMethod field (uint16) in | ||
| // Named, Pointer, and Struct type descriptors. When set, an inline method set | ||
| // is present in the type descriptor. Must match the constant in | ||
| // src/internal/reflectlite/type.go. | ||
| const numMethodHasMethodSet = 0x8000 | ||
|
|
||
| // Type kinds for basic types. | ||
| // They must match the constants for the Kind type in src/reflect/type.go. | ||
| var basicTypes = [...]uint8{ | ||
|
|
@@ -183,6 +190,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| typeFieldTypes := []*types.Var{ | ||
| types.NewVar(token.NoPos, nil, "kind", types.Typ[types.Int8]), | ||
| } | ||
| // Compute the method set value for types that support methods. | ||
| var methods []*types.Func | ||
| for i := 0; i < ms.Len(); i++ { | ||
| methods = append(methods, ms.At(i).Obj().(*types.Func)) | ||
| } | ||
| methodSetType := types.NewStruct([]*types.Var{ | ||
| types.NewVar(token.NoPos, nil, "length", types.Typ[types.Uintptr]), | ||
| types.NewVar(token.NoPos, nil, "methods", types.NewArray(types.Typ[types.UnsafePointer], int64(len(methods)))), | ||
| }, nil) | ||
| methodSetValue := c.getMethodSetValue(methods) | ||
| switch typ := typ.(type) { | ||
| case *types.Basic: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
|
|
@@ -199,6 +216,13 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "underlying", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "pkgpath", types.Typ[types.UnsafePointer]), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "name", types.NewArray(types.Typ[types.Int8], int64(len(pkgname)+1+len(name)+1))), | ||
| ) | ||
| case *types.Chan: | ||
|
|
@@ -218,6 +242,11 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), | ||
| types.NewVar(token.NoPos, nil, "elementType", types.Typ[types.UnsafePointer]), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| case *types.Array: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "numMethods", types.Typ[types.Uint16]), | ||
|
|
@@ -242,11 +271,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| types.NewVar(token.NoPos, nil, "numFields", types.Typ[types.Uint16]), | ||
| types.NewVar(token.NoPos, nil, "fields", types.NewArray(c.getRuntimeType("structField"), int64(typ.NumFields()))), | ||
| ) | ||
| if len(methods) > 0 { | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| } | ||
| case *types.Interface: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
| types.NewVar(token.NoPos, nil, "methods", methodSetType), | ||
| ) | ||
| // TODO: methods | ||
| case *types.Signature: | ||
| typeFieldTypes = append(typeFieldTypes, | ||
| types.NewVar(token.NoPos, nil, "ptrTo", types.Typ[types.UnsafePointer]), | ||
|
|
@@ -292,14 +326,24 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| pkgname = pkg.Name() | ||
| } | ||
| pkgPathPtr := c.pkgPathPtr(pkgpath) | ||
| namedNumMethods := uint64(numMethods) | ||
| if namedNumMethods&numMethodHasMethodSet != 0 { | ||
| panic("numMethods overflow: too many exported methods on named type " + name) | ||
| } | ||
| if len(methods) > 0 { | ||
| namedNumMethods |= numMethodHasMethodSet | ||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| c.getTypeCode(typ.Underlying()), // underlying | ||
| pkgPathPtr, // pkgpath pointer | ||
| c.ctx.ConstString(pkgname+"."+name+"\x00", false), // name | ||
| llvm.ConstInt(c.ctx.Int16Type(), namedNumMethods, false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| c.getTypeCode(typ.Underlying()), // underlying | ||
| pkgPathPtr, // pkgpath pointer | ||
| } | ||
| metabyte |= 1 << 5 // "named" flag | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) // methods | ||
| } | ||
| typeFields = append(typeFields, c.ctx.ConstString(pkgname+"."+name+"\x00", false)) // name | ||
| metabyte |= 1 << 5 // "named" flag | ||
| case *types.Chan: | ||
| var dir reflectChanDir | ||
| switch typ.Dir() { | ||
|
|
@@ -323,10 +367,20 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| c.getTypeCode(typ.Elem()), // elementType | ||
| } | ||
| case *types.Pointer: | ||
| ptrNumMethods := uint64(numMethods) | ||
| if ptrNumMethods&numMethodHasMethodSet != 0 { | ||
| panic("numMethods overflow: too many exported methods on pointer type") | ||
| } | ||
| if len(methods) > 0 { | ||
| ptrNumMethods |= numMethodHasMethodSet | ||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| llvm.ConstInt(c.ctx.Int16Type(), ptrNumMethods, false), // numMethods | ||
| c.getTypeCode(typ.Elem()), | ||
| } | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) | ||
| } | ||
| case *types.Array: | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), 0, false), // numMethods | ||
|
|
@@ -353,9 +407,16 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
|
|
||
| llvmStructType := c.getLLVMType(typ) | ||
| size := c.targetData.TypeStoreSize(llvmStructType) | ||
| structNumMethods := uint64(numMethods) | ||
| if structNumMethods&numMethodHasMethodSet != 0 { | ||
| panic("numMethods overflow: too many exported methods on struct type") | ||
| } | ||
| if len(methods) > 0 { | ||
| structNumMethods |= numMethodHasMethodSet | ||
| } | ||
| typeFields = []llvm.Value{ | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(numMethods), false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| llvm.ConstInt(c.ctx.Int16Type(), structNumMethods, false), // numMethods | ||
| c.getTypeCode(types.NewPointer(typ)), // ptrTo | ||
| pkgPathPtr, | ||
| llvm.ConstInt(c.ctx.Int32Type(), uint64(size), false), // size | ||
| llvm.ConstInt(c.ctx.Int16Type(), uint64(typ.NumFields()), false), // numFields | ||
|
|
@@ -407,9 +468,14 @@ func (c *compilerContext) getTypeCode(typ types.Type) llvm.Value { | |
| })) | ||
| } | ||
| typeFields = append(typeFields, llvm.ConstArray(structFieldType, fields)) | ||
| if len(methods) > 0 { | ||
| typeFields = append(typeFields, methodSetValue) | ||
| } | ||
| case *types.Interface: | ||
| typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} | ||
| // TODO: methods | ||
| typeFields = []llvm.Value{ | ||
| c.getTypeCode(types.NewPointer(typ)), | ||
| methodSetValue, | ||
| } | ||
| case *types.Signature: | ||
| typeFields = []llvm.Value{c.getTypeCode(types.NewPointer(typ))} | ||
| // TODO: params, return values, etc | ||
|
|
@@ -696,17 +762,11 @@ func (b *builder) createTypeAssert(expr *ssa.TypeAssert) llvm.Value { | |
| // This type assertion always succeeds, so we can just set commaOk to true. | ||
| commaOk = llvm.ConstInt(b.ctx.Int1Type(), 1, true) | ||
| } else { | ||
| // Type assert on interface type with methods. | ||
| // This is a call to an interface type assert function. | ||
| // The interface lowering pass will define this function by filling it | ||
| // with a type switch over all concrete types that implement this | ||
| // interface, and returning whether it's one of the matched types. | ||
| // This is very different from how interface asserts are implemented in | ||
| // the main Go compiler, where the runtime checks whether the type | ||
| // implements each method of the interface. See: | ||
| // https://research.swtch.com/interfaces | ||
| fn := b.getInterfaceImplementsFunc(expr.AssertedType) | ||
| commaOk = b.CreateCall(fn.GlobalValueType(), fn, []llvm.Value{actualTypeNum}, "") | ||
| // Type assert on an interface type with methods. | ||
| // Create a call to a declared-but-not-defined function that will | ||
| // be lowered by the interface lowering pass into a type-ID | ||
| // comparison chain. | ||
| commaOk = b.createInterfaceTypeAssert(intf, actualTypeNum) | ||
| } | ||
| } else { | ||
| name, _ := getTypeCodeName(expr.AssertedType) | ||
|
|
@@ -783,20 +843,58 @@ func (c *compilerContext) getMethodsString(itf *types.Interface) string { | |
| return strings.Join(methods, "; ") | ||
| } | ||
|
|
||
| // getInterfaceImplementsFunc returns a declared function that works as a type | ||
| // switch. The interface lowering pass will define this function. | ||
| func (c *compilerContext) getInterfaceImplementsFunc(assertedType types.Type) llvm.Value { | ||
| s, _ := getTypeCodeName(assertedType.Underlying()) | ||
| fnName := s + ".$typeassert" | ||
| llvmFn := c.mod.NamedFunction(fnName) | ||
| if llvmFn.IsNil() { | ||
| llvmFnType := llvm.FunctionType(c.ctx.Int1Type(), []llvm.Type{c.dataPtrType}, false) | ||
| llvmFn = llvm.AddFunction(c.mod, fnName, llvmFnType) | ||
| c.addStandardDeclaredAttributes(llvmFn) | ||
| methods := c.getMethodsString(assertedType.Underlying().(*types.Interface)) | ||
| llvmFn.AddFunctionAttr(c.ctx.CreateStringAttribute("tinygo-methods", methods)) | ||
| // getMethodSetValue creates the method set struct value for a list of methods. | ||
| // The struct contains a length and a sorted array of method signature pointers. | ||
| func (c *compilerContext) getMethodSetValue(methods []*types.Func) llvm.Value { | ||
| // Create a sorted list of method signature global names. | ||
| type methodRef struct { | ||
| name string | ||
| value llvm.Value | ||
| } | ||
| return llvmFn | ||
| var refs []methodRef | ||
| for _, method := range methods { | ||
| name := method.Name() | ||
| if !token.IsExported(name) { | ||
| name = method.Pkg().Path() + "." + name | ||
| } | ||
| s, _ := getTypeCodeName(method.Type()) | ||
| globalName := "reflect/types.signature:" + name + ":" + s | ||
| value := c.mod.NamedGlobal(globalName) | ||
| if value.IsNil() { | ||
| value = llvm.AddGlobal(c.mod, c.ctx.Int8Type(), globalName) | ||
| value.SetInitializer(llvm.ConstNull(c.ctx.Int8Type())) | ||
| value.SetGlobalConstant(true) | ||
| value.SetLinkage(llvm.LinkOnceODRLinkage) | ||
| value.SetAlignment(1) | ||
|
Comment on lines
+864
to
+868
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There might be ways to optimize this, since all it really needs is unique IDs. Anyway, just ideas for the future it looks good enough for now.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, so I did think about using IDs, but there were some challenges that made me not do that. Namely that I wasn't sure that it would be easy to DCE because with the pointers, at least I think the LLVM stack knows when something is unused, but with the IDs, not so much? For the debug info, I think I could try and do that quick, but if you don't mind it later I'm happy to wait (not sure if there's any rebasing or something needed for this PR or if it's going to just get squashed).
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, adding the debug info is acutally very easy, it's effectively just copy-paste from
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you! |
||
| if c.Debug { | ||
| file := c.getDIFile("<Go type>") | ||
| diglobal := c.dibuilder.CreateGlobalVariableExpression(file, llvm.DIGlobalVariableExpression{ | ||
| Name: globalName, | ||
| File: file, | ||
| Line: 1, | ||
| Type: c.getDIType(types.Typ[types.Uint8]), | ||
| LocalToUnit: false, | ||
| Expr: c.dibuilder.CreateExpression(nil), | ||
| AlignInBits: 8, | ||
| }) | ||
| value.AddMetadata(0, diglobal) | ||
| } | ||
| } | ||
| refs = append(refs, methodRef{globalName, value}) | ||
| } | ||
| sort.Slice(refs, func(i, j int) bool { | ||
| return refs[i].name < refs[j].name | ||
| }) | ||
|
|
||
| var values []llvm.Value | ||
| for _, ref := range refs { | ||
| values = append(values, ref.value) | ||
| } | ||
|
|
||
| return c.ctx.ConstStruct([]llvm.Value{ | ||
| llvm.ConstInt(c.uintptrType, uint64(len(values)), false), | ||
| llvm.ConstArray(c.dataPtrType, values), | ||
| }, false) | ||
| } | ||
|
|
||
| // getInvokeFunction returns the thunk to call the given interface method. The | ||
|
|
@@ -823,6 +921,24 @@ func (c *compilerContext) getInvokeFunction(instr *ssa.CallCommon) llvm.Value { | |
| return llvmFn | ||
| } | ||
|
|
||
| // createInterfaceTypeAssert creates a call to a declared-but-not-defined | ||
| // $typeassert function for the given interface. This function will be defined | ||
| // by the interface lowering pass as a type-ID comparison chain, avoiding the | ||
| // need for runtime.typeImplementsMethodSet at compile time. | ||
| func (b *builder) createInterfaceTypeAssert(intf *types.Interface, actualType llvm.Value) llvm.Value { | ||
| s, _ := getTypeCodeName(intf) | ||
| fnName := s + ".$typeassert" | ||
| llvmFn := b.mod.NamedFunction(fnName) | ||
| if llvmFn.IsNil() { | ||
| llvmFnType := llvm.FunctionType(b.ctx.Int1Type(), []llvm.Type{b.dataPtrType}, false) | ||
| llvmFn = llvm.AddFunction(b.mod, fnName, llvmFnType) | ||
| b.addStandardDeclaredAttributes(llvmFn) | ||
| methods := b.getMethodsString(intf) | ||
| llvmFn.AddFunctionAttr(b.ctx.CreateStringAttribute("tinygo-methods", methods)) | ||
| } | ||
| return b.CreateCall(llvmFn.GlobalValueType(), llvmFn, []llvm.Value{actualType}, "") | ||
| } | ||
|
|
||
| // getInterfaceInvokeWrapper returns a wrapper for the given method so it can be | ||
| // invoked from an interface. The wrapper takes in a pointer to the underlying | ||
| // value, dereferences or unpacks it if necessary, and calls the real method. | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.