Skip to content

Commit 3a4cc5f

Browse files
committed
Kitchen sink
1 parent 93d426d commit 3a4cc5f

29 files changed

Lines changed: 21083 additions & 7289 deletions

ast/cache.go

Lines changed: 204 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package ast
44
import (
55
"errors"
66
"fmt"
7+
"go/token"
78
"io/fs"
89
"os"
910
"path/filepath"
@@ -23,9 +24,11 @@ import (
2324
const (
2425
builtinPkg = "builtin"
2526

26-
genTypeSuffix = "_genType"
27-
starGenTypeSuffix = "_starGenType"
28-
testPkgSuffix = "_test"
27+
genTypeSuffix = "_genType"
28+
starGenTypeSuffix = "_starGenType"
29+
indexGenTypeSuffix = "_indexGenType"
30+
indexListGenTypeSuffix = "_indexListGenType"
31+
testPkgSuffix = "_test"
2932
)
3033

3134
//go:generate moqueries LoadFn
@@ -194,18 +197,19 @@ func (c *Cache) Type(id dst.Ident, contextPkg string, testImport bool) (TypeInfo
194197
}, nil
195198
}
196199

197-
// IsComparable determines if an expression is comparable
198-
func (c *Cache) IsComparable(expr dst.Expr) (bool, error) {
199-
return c.isDefaultComparable(expr, true)
200+
// IsComparable determines if an expression is comparable. The optional
201+
// parentType can be used to supply type parameters.
202+
func (c *Cache) IsComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
203+
return c.isDefaultComparable(expr, &parentType, true, false)
200204
}
201205

202206
// IsDefaultComparable determines if an expression is comparable. Returns the
203207
// same results as IsComparable but pointers and interfaces are not comparable
204208
// by default (interface implementations that are not comparable and put into a
205209
// map key will panic at runtime and by default pointers use a deep hash to be
206210
// comparable).
207-
func (c *Cache) IsDefaultComparable(expr dst.Expr) (bool, error) {
208-
return c.isDefaultComparable(expr, false)
211+
func (c *Cache) IsDefaultComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
212+
return c.isDefaultComparable(expr, &parentType, false, false)
209213
}
210214

211215
// FindPackage finds the package for a given directory
@@ -363,43 +367,113 @@ func isExported(name, pkgPath string) bool {
363367
return false
364368
}
365369

366-
func (c *Cache) isDefaultComparable(expr dst.Expr, interfacePointerDefault bool) (bool, error) {
370+
func (c *Cache) isDefaultComparable(
371+
expr dst.Expr,
372+
parentType *TypeInfo,
373+
interfacePointerDefault bool,
374+
genericType bool,
375+
) (bool, error) {
376+
subInterfaceDefault := interfacePointerDefault
377+
if genericType {
378+
subInterfaceDefault = false
379+
}
367380
switch e := expr.(type) {
368381
case *dst.ArrayType:
369382
if e.Len == nil {
370383
return false, nil
371384
}
372-
return c.isDefaultComparable(e.Elt, interfacePointerDefault)
373-
case *dst.MapType, *dst.Ellipsis, *dst.FuncType:
385+
386+
return c.isDefaultComparable(e.Elt, parentType, interfacePointerDefault, genericType)
387+
case *dst.BinaryExpr:
388+
comp, err := c.isDefaultComparable(e.X, parentType, interfacePointerDefault, genericType)
389+
if err != nil || !comp {
390+
return comp, err
391+
}
392+
393+
return c.isDefaultComparable(e.Y, parentType, interfacePointerDefault, genericType)
394+
case *dst.Ellipsis:
395+
return false, nil
396+
case *dst.FuncType:
374397
return false, nil
375-
case *dst.StarExpr:
376-
return interfacePointerDefault, nil
377398
case *dst.InterfaceType:
378-
return interfacePointerDefault, nil
379-
case *dst.Ident:
380-
if e.Obj != nil {
381-
typ, ok := e.Obj.Decl.(*dst.TypeSpec)
382-
if !ok {
383-
return false, fmt.Errorf("%q: %w", e.String(), ErrInvalidType)
399+
if e.Methods == nil || len(e.Methods.List) == 0 {
400+
// Basically an "any" interface
401+
return subInterfaceDefault, nil
402+
}
403+
hasTypeConstraints := false
404+
for _, m := range e.Methods.List {
405+
if _, ok := m.Type.(*dst.FuncType); ok {
406+
// Skip methods as the don't change whether something is
407+
// comparable
408+
continue
384409
}
385410

386-
if typ.Name.Name == "string" && typ.Name.Path == "" {
387-
return true, nil
411+
hasTypeConstraints = true
412+
413+
comp, err := c.isDefaultComparable(m.Type, parentType, subInterfaceDefault, genericType)
414+
if err != nil || !comp {
415+
return comp, err
388416
}
417+
}
389418

390-
return c.isDefaultComparable(typ.Type, interfacePointerDefault)
419+
if hasTypeConstraints {
420+
// If an interface has type constraints and none of them were not
421+
// comparable (none were because we would have returned early
422+
// above), then it is always comparable
423+
return true, nil
391424
}
425+
426+
return subInterfaceDefault, nil
427+
case *dst.Ident:
428+
// if e.Obj != nil {
429+
// var tExpr dst.Expr
430+
// switch typ := e.Obj.Decl.(type) {
431+
// case *dst.TypeSpec:
432+
// tExpr = typ.Type
433+
// case *dst.Field:
434+
// tExpr = typ.Type
435+
// default:
436+
// return false, fmt.Errorf("identity expression %q: %w", e.String(), ErrInvalidType)
437+
// }
438+
//
439+
// return c.isDefaultComparable(tExpr, parentType, "", interfacePointerDefault, false)
440+
// }
441+
// TODO: Generic type parameters should trump types in the cache (call
442+
// findGenericType first)
443+
pkgPath := e.Path
392444
typ, ok := c.typesByIdent[e.String()]
445+
if !ok && e.Path == "" && parentType != nil {
446+
pkgPath = parentType.PkgPath
447+
typ, ok = c.typesByIdent[IdPath(e.Name, parentType.PkgPath).String()]
448+
}
393449
if ok {
394-
return c.isDefaultComparable(typ.typ.Type, interfacePointerDefault)
450+
tInfo := &TypeInfo{
451+
Type: typ.typ,
452+
PkgPath: pkgPath,
453+
Exported: isExported(e.Name, pkgPath),
454+
Fabricated: false,
455+
}
456+
return c.isDefaultComparable(
457+
typ.typ.Type, tInfo, interfacePointerDefault, genericType)
395458
}
396459

397-
// Builtin type?
398-
if e.Path == "" {
399-
// error is the one builtin type that may not be comparable (it's
460+
// Builtin or generic type?
461+
if e.Path == "" || (parentType != nil && parentType.Type != nil && e.Path == parentType.Type.Name.Path) {
462+
// Precedence is given to a generic type
463+
gType := c.findGenericType(parentType, e.Name)
464+
if gType != nil {
465+
return c.isDefaultComparable(gType, parentType, interfacePointerDefault, true)
466+
}
467+
468+
// error is a builtin type that may not be comparable (it's
400469
// an interface so return the same result as an interface)
401470
if e.Name == "error" {
402-
return interfacePointerDefault, nil
471+
return subInterfaceDefault, nil
472+
}
473+
474+
// any is an alias for interface{}, so again the default
475+
if e.Name == "any" {
476+
return subInterfaceDefault, nil
403477
}
404478

405479
return true, nil
@@ -412,14 +486,22 @@ func (c *Cache) isDefaultComparable(expr dst.Expr, interfacePointerDefault bool)
412486

413487
typ, ok = c.typesByIdent[e.String()]
414488
if ok {
415-
return c.isDefaultComparable(typ.typ.Type, interfacePointerDefault)
489+
tInfo := &TypeInfo{
490+
Type: typ.typ,
491+
PkgPath: e.Path,
492+
Exported: isExported(e.Name, e.Path),
493+
Fabricated: false,
494+
}
495+
return c.isDefaultComparable(typ.typ.Type, tInfo, interfacePointerDefault, genericType)
416496
}
417497

418498
return true, nil
499+
case *dst.MapType:
500+
return false, nil
419501
case *dst.SelectorExpr:
420502
ex, ok := e.X.(*dst.Ident)
421503
if !ok {
422-
return false, fmt.Errorf("%q: %w", e.X, ErrInvalidType)
504+
return false, fmt.Errorf("selector expression %q: %w", e.X, ErrInvalidType)
423505
}
424506
path := ex.Name
425507
_, err := c.loadPackage(path, false)
@@ -429,23 +511,105 @@ func (c *Cache) isDefaultComparable(expr dst.Expr, interfacePointerDefault bool)
429511

430512
typ, ok := c.typesByIdent[IdPath(e.Sel.Name, path).String()]
431513
if ok {
432-
return c.isDefaultComparable(typ.typ.Type, interfacePointerDefault)
514+
return c.isDefaultComparable(typ.typ.Type, nil, interfacePointerDefault, genericType)
433515
}
434516

435517
// Builtin type?
436518
return true, nil
519+
case *dst.StarExpr:
520+
return interfacePointerDefault, nil
437521
case *dst.StructType:
438522
for _, f := range e.Fields.List {
439-
comp, err := c.isDefaultComparable(f.Type, interfacePointerDefault)
523+
comp, err := c.isDefaultComparable(f.Type, parentType, interfacePointerDefault, genericType)
440524
if err != nil || !comp {
441525
return false, err
442526
}
443527
}
528+
case *dst.UnaryExpr:
529+
if e.Op != token.TILDE {
530+
return false, fmt.Errorf(
531+
"unexpected unary operator %s: %w", e.Op.String(), ErrInvalidType)
532+
}
533+
// This is a type constraint and for determining comparability, we
534+
// don't care if the constraint is for a type or underlying types
535+
return c.isDefaultComparable(e.X, parentType, interfacePointerDefault, genericType)
444536
}
445537

446538
return true, nil
447539
}
448540

541+
func (c *Cache) findGenericType(parentType *TypeInfo, paramTypeName string) dst.Expr {
542+
if parentType == nil || parentType.Type == nil || parentType.Type.TypeParams == nil {
543+
return nil
544+
}
545+
546+
for _, p := range parentType.Type.TypeParams.List {
547+
for _, n := range p.Names {
548+
if n.Name == paramTypeName {
549+
return p.Type
550+
}
551+
}
552+
}
553+
554+
return nil
555+
}
556+
557+
// func (c *Cache) findMethodGenericType(fn *dst.FuncDecl, paramTypeName string) (dst.Expr, error) {
558+
// // Only handle methods here. Functions and structs have their Obj's intact
559+
// // and don't need to be looked up in another declaration
560+
// for _, r := range fn.Recv.List {
561+
// switch idxType := r.Type.(type) {
562+
// case *dst.IndexListExpr:
563+
// for n, iExpr := range idxType.Indices {
564+
// xId, ok := idxType.X.(*dst.Ident)
565+
// if !ok {
566+
// return nil, fmt.Errorf(
567+
// "expecting *dst.Ident in IndexListExpr.X: %w", ErrInvalidType)
568+
// }
569+
// gType, err := c.findIndexedGenericType(iExpr, paramTypeName, xId, n)
570+
// if err != nil || gType != nil {
571+
// return gType, err
572+
// }
573+
// }
574+
// case *dst.IndexExpr:
575+
// xId, ok := idxType.X.(*dst.Ident)
576+
// if !ok {
577+
// return nil, fmt.Errorf(
578+
// "expecting *dst.Ident in IndexExpr.X: %w", ErrInvalidType)
579+
// }
580+
// return c.findIndexedGenericType(idxType.Index, paramTypeName, xId, 0)
581+
// default:
582+
// return nil, fmt.Errorf(
583+
// "unexpected index type %#v: %w", idxType, ErrInvalidType)
584+
// }
585+
// }
586+
//
587+
// return nil, nil
588+
// }
589+
//
590+
func (c *Cache) findIndexedGenericType(
591+
iExpr dst.Expr, paramTypeName string, xId *dst.Ident, idx int,
592+
) (dst.Expr, error) {
593+
if id, ok := iExpr.(*dst.Ident); ok && id.Name != paramTypeName {
594+
return nil, nil
595+
}
596+
597+
if xId.Obj == nil {
598+
return nil, fmt.Errorf(
599+
"expecting Obj: %w", ErrInvalidType)
600+
}
601+
tSpec, ok := xId.Obj.Decl.(*dst.TypeSpec)
602+
if !ok {
603+
return nil, fmt.Errorf(
604+
"expecting *dst.TypeSpec: %w", ErrInvalidType)
605+
}
606+
if tSpec.TypeParams == nil || len(tSpec.TypeParams.List) <= idx {
607+
return nil, fmt.Errorf(
608+
"base type to method type param mismatch: %w", ErrInvalidType)
609+
}
610+
return tSpec.TypeParams.List[idx].Type, nil
611+
}
612+
449613
func (c *Cache) loadPackage(path string, testImport bool) (string, error) {
450614
indexPath := path
451615
if strings.HasPrefix(path, ".") {
@@ -501,19 +665,6 @@ func (c *Cache) loadTypes(loadPkg string, testImport bool) (string, error) {
501665
}
502666

503667
func (c *Cache) loadAST(loadPkg string, testImport bool) ([]*pkgInfo, error) {
504-
if dp, ok := c.loadedPkgs[loadPkg]; ok {
505-
// If we already loaded the test types or if the test types aren't
506-
// requested, we're done
507-
if dp.loadTestPkgs || !testImport {
508-
// If we direct loaded, we're done
509-
if dp.directLoaded {
510-
c.metrics.ASTTypeCacheHitsInc()
511-
return []*pkgInfo{dp}, nil
512-
}
513-
}
514-
}
515-
c.metrics.ASTTypeCacheMissesInc()
516-
517668
start := time.Now()
518669
pkgs, err := c.load(&packages.Config{
519670
Mode: packages.NeedName |
@@ -597,7 +748,7 @@ func (c *Cache) convert(pkg *packages.Package, testImport, directLoaded bool) (*
597748

598749
start := time.Now()
599750
p.pkg.Decorator = decorator.NewDecoratorFromPackage(pkg)
600-
p.pkg.Decorator.ResolveLocalPath = true
751+
// p.pkg.Decorator.ResolveLocalPath = true
601752
for _, f := range pkg.Syntax {
602753
fpath := pkg.Fset.File(f.Pos()).Name()
603754
if !goFiles[fpath] {
@@ -694,6 +845,14 @@ func (c *Cache) storeFuncDecl(decl *dst.FuncDecl, pkg *pkgInfo) {
694845
suffix = starGenTypeSuffix
695846
expr = sExpr.X
696847
}
848+
if iExpr, ok := expr.(*dst.IndexExpr); ok {
849+
suffix = indexGenTypeSuffix
850+
expr = iExpr.X
851+
}
852+
if ilExpr, ok := expr.(*dst.IndexListExpr); ok {
853+
suffix = indexListGenTypeSuffix
854+
expr = ilExpr.X
855+
}
697856
exprId, ok := expr.(*dst.Ident)
698857
if !ok {
699858
logs.Panicf("%s has a non-Ident (or StarExpr/Ident) receiver: %#v",

0 commit comments

Comments
 (0)