@@ -4,6 +4,7 @@ package ast
44import (
55 "errors"
66 "fmt"
7+ "go/token"
78 "io/fs"
89 "os"
910 "path/filepath"
@@ -23,9 +24,11 @@ import (
2324const (
2425 builtinPkg = "builtin"
2526
26- genTypeSuffix = "_genType"
27- starGenTypeSuffix = "_starGenType"
28- testPkgSuffix = "_test"
27+ genTypeSuffix = "_genType"
28+ starGenTypeSuffix = "_starGenType"
29+ indexGenTypeSuffix = "_indexGenType"
30+ indexListGenTypeSuffix = "_indexListGenType"
31+ testPkgSuffix = "_test"
2932)
3033
3134//go:generate moqueries LoadFn
@@ -194,18 +197,19 @@ func (c *Cache) Type(id dst.Ident, contextPkg string, testImport bool) (TypeInfo
194197 }, nil
195198}
196199
197- // IsComparable determines if an expression is comparable
198- func (c * Cache ) IsComparable (expr dst.Expr ) (bool , error ) {
199- return c .isDefaultComparable (expr , true )
200+ // IsComparable determines if an expression is comparable. The optional
201+ // parentType can be used to supply type parameters.
202+ func (c * Cache ) IsComparable (expr dst.Expr , parentType TypeInfo ) (bool , error ) {
203+ return c .isDefaultComparable (expr , & parentType , true , false )
200204}
201205
202206// IsDefaultComparable determines if an expression is comparable. Returns the
203207// same results as IsComparable but pointers and interfaces are not comparable
204208// by default (interface implementations that are not comparable and put into a
205209// map key will panic at runtime and by default pointers use a deep hash to be
206210// comparable).
207- func (c * Cache ) IsDefaultComparable (expr dst.Expr ) (bool , error ) {
208- return c .isDefaultComparable (expr , false )
211+ func (c * Cache ) IsDefaultComparable (expr dst.Expr , parentType TypeInfo ) (bool , error ) {
212+ return c .isDefaultComparable (expr , & parentType , false , false )
209213}
210214
211215// FindPackage finds the package for a given directory
@@ -363,43 +367,113 @@ func isExported(name, pkgPath string) bool {
363367 return false
364368}
365369
366- func (c * Cache ) isDefaultComparable (expr dst.Expr , interfacePointerDefault bool ) (bool , error ) {
370+ func (c * Cache ) isDefaultComparable (
371+ expr dst.Expr ,
372+ parentType * TypeInfo ,
373+ interfacePointerDefault bool ,
374+ genericType bool ,
375+ ) (bool , error ) {
376+ subInterfaceDefault := interfacePointerDefault
377+ if genericType {
378+ subInterfaceDefault = false
379+ }
367380 switch e := expr .(type ) {
368381 case * dst.ArrayType :
369382 if e .Len == nil {
370383 return false , nil
371384 }
372- return c .isDefaultComparable (e .Elt , interfacePointerDefault )
373- case * dst.MapType , * dst.Ellipsis , * dst.FuncType :
385+
386+ return c .isDefaultComparable (e .Elt , parentType , interfacePointerDefault , genericType )
387+ case * dst.BinaryExpr :
388+ comp , err := c .isDefaultComparable (e .X , parentType , interfacePointerDefault , genericType )
389+ if err != nil || ! comp {
390+ return comp , err
391+ }
392+
393+ return c .isDefaultComparable (e .Y , parentType , interfacePointerDefault , genericType )
394+ case * dst.Ellipsis :
395+ return false , nil
396+ case * dst.FuncType :
374397 return false , nil
375- case * dst.StarExpr :
376- return interfacePointerDefault , nil
377398 case * dst.InterfaceType :
378- return interfacePointerDefault , nil
379- case * dst.Ident :
380- if e .Obj != nil {
381- typ , ok := e .Obj .Decl .(* dst.TypeSpec )
382- if ! ok {
383- return false , fmt .Errorf ("%q: %w" , e .String (), ErrInvalidType )
399+ if e .Methods == nil || len (e .Methods .List ) == 0 {
400+ // Basically an "any" interface
401+ return subInterfaceDefault , nil
402+ }
403+ hasTypeConstraints := false
404+ for _ , m := range e .Methods .List {
405+ if _ , ok := m .Type .(* dst.FuncType ); ok {
406+ // Skip methods as the don't change whether something is
407+ // comparable
408+ continue
384409 }
385410
386- if typ .Name .Name == "string" && typ .Name .Path == "" {
387- return true , nil
411+ hasTypeConstraints = true
412+
413+ comp , err := c .isDefaultComparable (m .Type , parentType , subInterfaceDefault , genericType )
414+ if err != nil || ! comp {
415+ return comp , err
388416 }
417+ }
389418
390- return c .isDefaultComparable (typ .Type , interfacePointerDefault )
419+ if hasTypeConstraints {
420+ // If an interface has type constraints and none of them were not
421+ // comparable (none were because we would have returned early
422+ // above), then it is always comparable
423+ return true , nil
391424 }
425+
426+ return subInterfaceDefault , nil
427+ case * dst.Ident :
428+ // if e.Obj != nil {
429+ // var tExpr dst.Expr
430+ // switch typ := e.Obj.Decl.(type) {
431+ // case *dst.TypeSpec:
432+ // tExpr = typ.Type
433+ // case *dst.Field:
434+ // tExpr = typ.Type
435+ // default:
436+ // return false, fmt.Errorf("identity expression %q: %w", e.String(), ErrInvalidType)
437+ // }
438+ //
439+ // return c.isDefaultComparable(tExpr, parentType, "", interfacePointerDefault, false)
440+ // }
441+ // TODO: Generic type parameters should trump types in the cache (call
442+ // findGenericType first)
443+ pkgPath := e .Path
392444 typ , ok := c .typesByIdent [e .String ()]
445+ if ! ok && e .Path == "" && parentType != nil {
446+ pkgPath = parentType .PkgPath
447+ typ , ok = c .typesByIdent [IdPath (e .Name , parentType .PkgPath ).String ()]
448+ }
393449 if ok {
394- return c .isDefaultComparable (typ .typ .Type , interfacePointerDefault )
450+ tInfo := & TypeInfo {
451+ Type : typ .typ ,
452+ PkgPath : pkgPath ,
453+ Exported : isExported (e .Name , pkgPath ),
454+ Fabricated : false ,
455+ }
456+ return c .isDefaultComparable (
457+ typ .typ .Type , tInfo , interfacePointerDefault , genericType )
395458 }
396459
397- // Builtin type?
398- if e .Path == "" {
399- // error is the one builtin type that may not be comparable (it's
460+ // Builtin or generic type?
461+ if e .Path == "" || (parentType != nil && parentType .Type != nil && e .Path == parentType .Type .Name .Path ) {
462+ // Precedence is given to a generic type
463+ gType := c .findGenericType (parentType , e .Name )
464+ if gType != nil {
465+ return c .isDefaultComparable (gType , parentType , interfacePointerDefault , true )
466+ }
467+
468+ // error is a builtin type that may not be comparable (it's
400469 // an interface so return the same result as an interface)
401470 if e .Name == "error" {
402- return interfacePointerDefault , nil
471+ return subInterfaceDefault , nil
472+ }
473+
474+ // any is an alias for interface{}, so again the default
475+ if e .Name == "any" {
476+ return subInterfaceDefault , nil
403477 }
404478
405479 return true , nil
@@ -412,14 +486,22 @@ func (c *Cache) isDefaultComparable(expr dst.Expr, interfacePointerDefault bool)
412486
413487 typ , ok = c .typesByIdent [e .String ()]
414488 if ok {
415- return c .isDefaultComparable (typ .typ .Type , interfacePointerDefault )
489+ tInfo := & TypeInfo {
490+ Type : typ .typ ,
491+ PkgPath : e .Path ,
492+ Exported : isExported (e .Name , e .Path ),
493+ Fabricated : false ,
494+ }
495+ return c .isDefaultComparable (typ .typ .Type , tInfo , interfacePointerDefault , genericType )
416496 }
417497
418498 return true , nil
499+ case * dst.MapType :
500+ return false , nil
419501 case * dst.SelectorExpr :
420502 ex , ok := e .X .(* dst.Ident )
421503 if ! ok {
422- return false , fmt .Errorf ("%q: %w" , e .X , ErrInvalidType )
504+ return false , fmt .Errorf ("selector expression %q: %w" , e .X , ErrInvalidType )
423505 }
424506 path := ex .Name
425507 _ , err := c .loadPackage (path , false )
@@ -429,23 +511,105 @@ func (c *Cache) isDefaultComparable(expr dst.Expr, interfacePointerDefault bool)
429511
430512 typ , ok := c .typesByIdent [IdPath (e .Sel .Name , path ).String ()]
431513 if ok {
432- return c .isDefaultComparable (typ .typ .Type , interfacePointerDefault )
514+ return c .isDefaultComparable (typ .typ .Type , nil , interfacePointerDefault , genericType )
433515 }
434516
435517 // Builtin type?
436518 return true , nil
519+ case * dst.StarExpr :
520+ return interfacePointerDefault , nil
437521 case * dst.StructType :
438522 for _ , f := range e .Fields .List {
439- comp , err := c .isDefaultComparable (f .Type , interfacePointerDefault )
523+ comp , err := c .isDefaultComparable (f .Type , parentType , interfacePointerDefault , genericType )
440524 if err != nil || ! comp {
441525 return false , err
442526 }
443527 }
528+ case * dst.UnaryExpr :
529+ if e .Op != token .TILDE {
530+ return false , fmt .Errorf (
531+ "unexpected unary operator %s: %w" , e .Op .String (), ErrInvalidType )
532+ }
533+ // This is a type constraint and for determining comparability, we
534+ // don't care if the constraint is for a type or underlying types
535+ return c .isDefaultComparable (e .X , parentType , interfacePointerDefault , genericType )
444536 }
445537
446538 return true , nil
447539}
448540
541+ func (c * Cache ) findGenericType (parentType * TypeInfo , paramTypeName string ) dst.Expr {
542+ if parentType == nil || parentType .Type == nil || parentType .Type .TypeParams == nil {
543+ return nil
544+ }
545+
546+ for _ , p := range parentType .Type .TypeParams .List {
547+ for _ , n := range p .Names {
548+ if n .Name == paramTypeName {
549+ return p .Type
550+ }
551+ }
552+ }
553+
554+ return nil
555+ }
556+
557+ // func (c *Cache) findMethodGenericType(fn *dst.FuncDecl, paramTypeName string) (dst.Expr, error) {
558+ // // Only handle methods here. Functions and structs have their Obj's intact
559+ // // and don't need to be looked up in another declaration
560+ // for _, r := range fn.Recv.List {
561+ // switch idxType := r.Type.(type) {
562+ // case *dst.IndexListExpr:
563+ // for n, iExpr := range idxType.Indices {
564+ // xId, ok := idxType.X.(*dst.Ident)
565+ // if !ok {
566+ // return nil, fmt.Errorf(
567+ // "expecting *dst.Ident in IndexListExpr.X: %w", ErrInvalidType)
568+ // }
569+ // gType, err := c.findIndexedGenericType(iExpr, paramTypeName, xId, n)
570+ // if err != nil || gType != nil {
571+ // return gType, err
572+ // }
573+ // }
574+ // case *dst.IndexExpr:
575+ // xId, ok := idxType.X.(*dst.Ident)
576+ // if !ok {
577+ // return nil, fmt.Errorf(
578+ // "expecting *dst.Ident in IndexExpr.X: %w", ErrInvalidType)
579+ // }
580+ // return c.findIndexedGenericType(idxType.Index, paramTypeName, xId, 0)
581+ // default:
582+ // return nil, fmt.Errorf(
583+ // "unexpected index type %#v: %w", idxType, ErrInvalidType)
584+ // }
585+ // }
586+ //
587+ // return nil, nil
588+ // }
589+ //
590+ func (c * Cache ) findIndexedGenericType (
591+ iExpr dst.Expr , paramTypeName string , xId * dst.Ident , idx int ,
592+ ) (dst.Expr , error ) {
593+ if id , ok := iExpr .(* dst.Ident ); ok && id .Name != paramTypeName {
594+ return nil , nil
595+ }
596+
597+ if xId .Obj == nil {
598+ return nil , fmt .Errorf (
599+ "expecting Obj: %w" , ErrInvalidType )
600+ }
601+ tSpec , ok := xId .Obj .Decl .(* dst.TypeSpec )
602+ if ! ok {
603+ return nil , fmt .Errorf (
604+ "expecting *dst.TypeSpec: %w" , ErrInvalidType )
605+ }
606+ if tSpec .TypeParams == nil || len (tSpec .TypeParams .List ) <= idx {
607+ return nil , fmt .Errorf (
608+ "base type to method type param mismatch: %w" , ErrInvalidType )
609+ }
610+ return tSpec .TypeParams .List [idx ].Type , nil
611+ }
612+
449613func (c * Cache ) loadPackage (path string , testImport bool ) (string , error ) {
450614 indexPath := path
451615 if strings .HasPrefix (path , "." ) {
@@ -501,19 +665,6 @@ func (c *Cache) loadTypes(loadPkg string, testImport bool) (string, error) {
501665}
502666
503667func (c * Cache ) loadAST (loadPkg string , testImport bool ) ([]* pkgInfo , error ) {
504- if dp , ok := c .loadedPkgs [loadPkg ]; ok {
505- // If we already loaded the test types or if the test types aren't
506- // requested, we're done
507- if dp .loadTestPkgs || ! testImport {
508- // If we direct loaded, we're done
509- if dp .directLoaded {
510- c .metrics .ASTTypeCacheHitsInc ()
511- return []* pkgInfo {dp }, nil
512- }
513- }
514- }
515- c .metrics .ASTTypeCacheMissesInc ()
516-
517668 start := time .Now ()
518669 pkgs , err := c .load (& packages.Config {
519670 Mode : packages .NeedName |
@@ -597,7 +748,7 @@ func (c *Cache) convert(pkg *packages.Package, testImport, directLoaded bool) (*
597748
598749 start := time .Now ()
599750 p .pkg .Decorator = decorator .NewDecoratorFromPackage (pkg )
600- p .pkg .Decorator .ResolveLocalPath = true
751+ // p.pkg.Decorator.ResolveLocalPath = true
601752 for _ , f := range pkg .Syntax {
602753 fpath := pkg .Fset .File (f .Pos ()).Name ()
603754 if ! goFiles [fpath ] {
@@ -694,6 +845,14 @@ func (c *Cache) storeFuncDecl(decl *dst.FuncDecl, pkg *pkgInfo) {
694845 suffix = starGenTypeSuffix
695846 expr = sExpr .X
696847 }
848+ if iExpr , ok := expr .(* dst.IndexExpr ); ok {
849+ suffix = indexGenTypeSuffix
850+ expr = iExpr .X
851+ }
852+ if ilExpr , ok := expr .(* dst.IndexListExpr ); ok {
853+ suffix = indexListGenTypeSuffix
854+ expr = ilExpr .X
855+ }
697856 exprId , ok := expr .(* dst.Ident )
698857 if ! ok {
699858 logs .Panicf ("%s has a non-Ident (or StarExpr/Ident) receiver: %#v" ,
0 commit comments