@@ -4,6 +4,7 @@ package ast
44import (
55 "errors"
66 "fmt"
7+ "go/token"
78 "io/fs"
89 "os"
910 "path/filepath"
@@ -23,9 +24,11 @@ import (
2324const (
2425 builtinPkg = "builtin"
2526
26- genTypeSuffix = "_genType"
27- starGenTypeSuffix = "_starGenType"
28- testPkgSuffix = "_test"
27+ genTypeSuffix = "_genType"
28+ starGenTypeSuffix = "_starGenType"
29+ indexGenTypeSuffix = "_indexGenType"
30+ indexListGenTypeSuffix = "_indexListGenType"
31+ testPkgSuffix = "_test"
2932)
3033
3134//go:generate moqueries LoadFn
@@ -197,7 +200,7 @@ func (c *Cache) Type(id dst.Ident, contextPkg string, testImport bool) (TypeInfo
197200// IsComparable determines if an expression is comparable. The optional
198201// parentType can be used to supply type parameters.
199202func (c * Cache ) IsComparable (expr dst.Expr , parentType TypeInfo ) (bool , error ) {
200- return c .isDefaultComparable (expr , & parentType , true )
203+ return c .isDefaultComparable (expr , & parentType , true , false )
201204}
202205
203206// IsDefaultComparable determines if an expression is comparable. Returns the
@@ -206,7 +209,7 @@ func (c *Cache) IsComparable(expr dst.Expr, parentType TypeInfo) (bool, error) {
206209// map key will panic at runtime and by default pointers use a deep hash to be
207210// comparable).
208211func (c * Cache ) IsDefaultComparable (expr dst.Expr , parentType TypeInfo ) (bool , error ) {
209- return c .isDefaultComparable (expr , & parentType , false )
212+ return c .isDefaultComparable (expr , & parentType , false , false )
210213}
211214
212215// FindPackage finds the package for a given directory
@@ -368,32 +371,75 @@ func (c *Cache) isDefaultComparable(
368371 expr dst.Expr ,
369372 parentType * TypeInfo ,
370373 interfacePointerDefault bool ,
374+ genericType bool ,
371375) (bool , error ) {
376+ subInterfaceDefault := interfacePointerDefault
377+ if genericType {
378+ subInterfaceDefault = false
379+ }
372380 switch e := expr .(type ) {
373381 case * dst.ArrayType :
374382 if e .Len == nil {
375383 return false , nil
376384 }
377- return c .isDefaultComparable (e .Elt , parentType , interfacePointerDefault )
385+
386+ return c .isDefaultComparable (e .Elt , parentType , interfacePointerDefault , genericType )
387+ case * dst.BinaryExpr :
388+ comp , err := c .isDefaultComparable (e .X , parentType , interfacePointerDefault , genericType )
389+ if err != nil || ! comp {
390+ return comp , err
391+ }
392+
393+ return c .isDefaultComparable (e .Y , parentType , interfacePointerDefault , genericType )
378394 case * dst.Ellipsis :
379395 return false , nil
380396 case * dst.FuncType :
381397 return false , nil
382398 case * dst.InterfaceType :
383- return interfacePointerDefault , nil
384- case * dst.Ident :
385- if e .Obj != nil {
386- typ , ok := e .Obj .Decl .(* dst.TypeSpec )
387- if ! ok {
388- return false , fmt .Errorf ("%q: %w" , e .String (), ErrInvalidType )
399+ if e .Methods == nil || len (e .Methods .List ) == 0 {
400+ // Basically an "any" interface
401+ return subInterfaceDefault , nil
402+ }
403+ hasTypeConstraints := false
404+ for _ , m := range e .Methods .List {
405+ if _ , ok := m .Type .(* dst.FuncType ); ok {
406+ // Skip methods as they don't change whether something is
407+ // comparable
408+ continue
389409 }
390410
391- if typ .Name .Name == "string" && typ .Name .Path == "" {
392- return true , nil
411+ hasTypeConstraints = true
412+
413+ comp , err := c .isDefaultComparable (m .Type , parentType , subInterfaceDefault , genericType )
414+ if err != nil || ! comp {
415+ return comp , err
393416 }
417+ }
394418
395- return c .isDefaultComparable (typ .Type , parentType , interfacePointerDefault )
419+ if hasTypeConstraints {
420+ // If an interface has type constraints and none of them were not
421+ // comparable (none were because we would have returned early
422+ // above), then it is always comparable
423+ return true , nil
396424 }
425+
426+ return subInterfaceDefault , nil
427+ case * dst.Ident :
428+ // if e.Obj != nil {
429+ // var tExpr dst.Expr
430+ // switch typ := e.Obj.Decl.(type) {
431+ // case *dst.TypeSpec:
432+ // tExpr = typ.Type
433+ // case *dst.Field:
434+ // tExpr = typ.Type
435+ // default:
436+ // return false, fmt.Errorf("identity expression %q: %w", e.String(), ErrInvalidType)
437+ // }
438+ //
439+ // return c.isDefaultComparable(tExpr, parentType, "", interfacePointerDefault, false)
440+ // }
441+ // TODO: Generic type parameters should trump types in the cache (call
442+ // findGenericType first)
397443 pkgPath := e .Path
398444 typ , ok := c .typesByIdent [e .String ()]
399445 if ! ok && e .Path == "" && parentType != nil {
@@ -407,15 +453,27 @@ func (c *Cache) isDefaultComparable(
407453 Exported : isExported (e .Name , pkgPath ),
408454 Fabricated : false ,
409455 }
410- return c .isDefaultComparable (typ .typ .Type , tInfo , interfacePointerDefault )
456+ return c .isDefaultComparable (
457+ typ .typ .Type , tInfo , interfacePointerDefault , genericType )
411458 }
412459
413- // Builtin type?
414- if e .Path == "" {
415- // error is the one builtin type that may not be comparable (it's
460+ // Builtin or generic type?
461+ if e .Path == "" || (parentType != nil && parentType .Type != nil && e .Path == parentType .Type .Name .Path ) {
462+ // Precedence is given to a generic type
463+ gType := c .findGenericType (parentType , e .Name )
464+ if gType != nil {
465+ return c .isDefaultComparable (gType , parentType , interfacePointerDefault , true )
466+ }
467+
468+ // error is a builtin type that may not be comparable (it's
416469 // an interface so return the same result as an interface)
417470 if e .Name == "error" {
418- return interfacePointerDefault , nil
471+ return subInterfaceDefault , nil
472+ }
473+
474+ // any is an alias for interface{}, so again the default
475+ if e .Name == "any" {
476+ return subInterfaceDefault , nil
419477 }
420478
421479 return true , nil
@@ -434,7 +492,7 @@ func (c *Cache) isDefaultComparable(
434492 Exported : isExported (e .Name , e .Path ),
435493 Fabricated : false ,
436494 }
437- return c .isDefaultComparable (typ .typ .Type , tInfo , interfacePointerDefault )
495+ return c .isDefaultComparable (typ .typ .Type , tInfo , interfacePointerDefault , genericType )
438496 }
439497
440498 return true , nil
@@ -443,7 +501,7 @@ func (c *Cache) isDefaultComparable(
443501 case * dst.SelectorExpr :
444502 ex , ok := e .X .(* dst.Ident )
445503 if ! ok {
446- return false , fmt .Errorf ("%q: %w" , e .X , ErrInvalidType )
504+ return false , fmt .Errorf ("selector expression %q: %w" , e .X , ErrInvalidType )
447505 }
448506 path := ex .Name
449507 _ , err := c .loadPackage (path , false )
@@ -453,7 +511,7 @@ func (c *Cache) isDefaultComparable(
453511
454512 typ , ok := c .typesByIdent [IdPath (e .Sel .Name , path ).String ()]
455513 if ok {
456- return c .isDefaultComparable (typ .typ .Type , parentType , interfacePointerDefault )
514+ return c .isDefaultComparable (typ .typ .Type , nil , interfacePointerDefault , genericType )
457515 }
458516
459517 // Builtin type?
@@ -462,16 +520,96 @@ func (c *Cache) isDefaultComparable(
462520 return interfacePointerDefault , nil
463521 case * dst.StructType :
464522 for _ , f := range e .Fields .List {
465- comp , err := c .isDefaultComparable (f .Type , parentType , interfacePointerDefault )
523+ comp , err := c .isDefaultComparable (f .Type , parentType , interfacePointerDefault , genericType )
466524 if err != nil || ! comp {
467525 return false , err
468526 }
469527 }
528+ case * dst.UnaryExpr :
529+ if e .Op != token .TILDE {
530+ return false , fmt .Errorf (
531+ "unexpected unary operator %s: %w" , e .Op .String (), ErrInvalidType )
532+ }
533+ // This is a type constraint and for determining comparability, we
534+ // don't care if the constraint is for a type or underlying types
535+ return c .isDefaultComparable (e .X , parentType , interfacePointerDefault , genericType )
470536 }
471537
472538 return true , nil
473539}
474540
541+ func (c * Cache ) findGenericType (parentType * TypeInfo , paramTypeName string ) dst.Expr {
542+ if parentType == nil || parentType .Type == nil || parentType .Type .TypeParams == nil {
543+ return nil
544+ }
545+
546+ for _ , p := range parentType .Type .TypeParams .List {
547+ for _ , n := range p .Names {
548+ if n .Name == paramTypeName {
549+ return p .Type
550+ }
551+ }
552+ }
553+
554+ return nil
555+ }
556+
557+ // func (c *Cache) findMethodGenericType(fn *dst.FuncDecl, paramTypeName string) (dst.Expr, error) {
558+ // // Only handle methods here. Functions and structs have their Obj's intact
559+ // // and don't need to be looked up in another declaration
560+ // for _, r := range fn.Recv.List {
561+ // switch idxType := r.Type.(type) {
562+ // case *dst.IndexListExpr:
563+ // for n, iExpr := range idxType.Indices {
564+ // xId, ok := idxType.X.(*dst.Ident)
565+ // if !ok {
566+ // return nil, fmt.Errorf(
567+ // "expecting *dst.Ident in IndexListExpr.X: %w", ErrInvalidType)
568+ // }
569+ // gType, err := c.findIndexedGenericType(iExpr, paramTypeName, xId, n)
570+ // if err != nil || gType != nil {
571+ // return gType, err
572+ // }
573+ // }
574+ // case *dst.IndexExpr:
575+ // xId, ok := idxType.X.(*dst.Ident)
576+ // if !ok {
577+ // return nil, fmt.Errorf(
578+ // "expecting *dst.Ident in IndexExpr.X: %w", ErrInvalidType)
579+ // }
580+ // return c.findIndexedGenericType(idxType.Index, paramTypeName, xId, 0)
581+ // default:
582+ // return nil, fmt.Errorf(
583+ // "unexpected index type %#v: %w", idxType, ErrInvalidType)
584+ // }
585+ // }
586+ //
587+ // return nil, nil
588+ // }
589+
590+ // func (c *Cache) findIndexedGenericType(
591+ // iExpr dst.Expr, paramTypeName string, xId *dst.Ident, idx int,
592+ // ) (dst.Expr, error) {
593+ // if id, ok := iExpr.(*dst.Ident); ok && id.Name != paramTypeName {
594+ // return nil, nil
595+ // }
596+ //
597+ // if xId.Obj == nil {
598+ // return nil, fmt.Errorf(
599+ // "expecting Obj: %w", ErrInvalidType)
600+ // }
601+ // tSpec, ok := xId.Obj.Decl.(*dst.TypeSpec)
602+ // if !ok {
603+ // return nil, fmt.Errorf(
604+ // "expecting *dst.TypeSpec: %w", ErrInvalidType)
605+ // }
606+ // if tSpec.TypeParams == nil || len(tSpec.TypeParams.List) <= idx {
607+ // return nil, fmt.Errorf(
608+ // "base type to method type param mismatch: %w", ErrInvalidType)
609+ // }
610+ // return tSpec.TypeParams.List[idx].Type, nil
611+ // }
612+
475613func (c * Cache ) loadPackage (path string , testImport bool ) (string , error ) {
476614 indexPath := path
477615 if strings .HasPrefix (path , "." ) {
@@ -706,6 +844,14 @@ func (c *Cache) storeFuncDecl(decl *dst.FuncDecl, pkg *pkgInfo) {
706844 suffix = starGenTypeSuffix
707845 expr = sExpr .X
708846 }
847+ if iExpr , ok := expr .(* dst.IndexExpr ); ok {
848+ suffix = indexGenTypeSuffix
849+ expr = iExpr .X
850+ }
851+ if ilExpr , ok := expr .(* dst.IndexListExpr ); ok {
852+ suffix = indexListGenTypeSuffix
853+ expr = ilExpr .X
854+ }
709855 exprId , ok := expr .(* dst.Ident )
710856 if ! ok {
711857 logs .Panicf ("%s has a non-Ident (or StarExpr/Ident) receiver: %#v" ,
0 commit comments