diff --git a/assertion/function/assertiontree/backprop.go b/assertion/function/assertiontree/backprop.go index 07f3c2a0..00cb7581 100644 --- a/assertion/function/assertiontree/backprop.go +++ b/assertion/function/assertiontree/backprop.go @@ -491,7 +491,7 @@ func backpropAcrossRange(rootNode *RootAssertionNode, lhs []ast.Expr, rhs ast.Ex case 1: // If we are ranging over a map slice or string with only a single lhs operand, then that // operand will be int-valued. - if typeshelper.IsDeeplyMap(rhsType) || typeshelper.IsDeeplySlice(rhsType) || typeshelper.IsDeeplyArray(rhsType) || typeIsString(rhsType) { + if typeshelper.IsDeeplyMap(rhsType) || typeshelper.IsDeeplySlice(rhsType) || typeshelper.IsDeeplyArrayOrArrayPtr(rhsType) || typeIsString(rhsType) { produceAsIndex(0) return nil } diff --git a/assertion/function/assertiontree/parse_expr_producer.go b/assertion/function/assertiontree/parse_expr_producer.go index 6801a2e4..7879c303 100644 --- a/assertion/function/assertiontree/parse_expr_producer.go +++ b/assertion/function/assertiontree/parse_expr_producer.go @@ -479,7 +479,7 @@ func (r *RootAssertionNode) ParseExprAsProducer(expr ast.Expr, doNotTrack bool) // storage. For example, `var a [4]int; _ = a[:0]` is a nonnil (empty) slice. This holds // regardless of the indices, so we must check it before the `b[_:0:_]` case below (which // would otherwise wrongly treat `a[:0]` as a nilable empty slice). - if typeshelper.IsDeeplyArray(r.Pass().TypesInfo.Types[expr.X].Type) { + if typeshelper.IsDeeplyArrayOrArrayPtr(r.Pass().TypesInfo.TypeOf(expr.X)) { // Returning nil to indicate the slice expression results in a nonnil slice. return nil, nil } diff --git a/testdata/src/go.uber.org/slices/inference/slices-with-inference.go b/testdata/src/go.uber.org/slices/inference/slices-with-inference.go index 07c1d878..c5938a8e 100644 --- a/testdata/src/go.uber.org/slices/inference/slices-with-inference.go +++ b/testdata/src/go.uber.org/slices/inference/slices-with-inference.go @@ -96,3 +96,26 @@ func helperReturnNonZeroSlicingNonNilProducerForNilableParam(b []int) []int { func helperReturnNonZeroSlicingNonNilProducerForNonNilParam(b []int) []int { return b[1:3] } + +// Aliases of slice types must behave like the aliased slice type itself: indexing a nilable +// alias-of-slice value gets the same slice-access check (aliases are materialized as +// *types.Alias since Go 1.23, so they must be explicitly resolved when classifying the operand). +type sliceAlias = []int + +var sliceAliasDummy bool + +func mkSliceAlias() sliceAlias { + if sliceAliasDummy { + return nil + } + return []int{1} +} + +func testSliceAliasIndex() int { + s := mkSliceAlias() + if s != nil { + return s[0] + } + t := mkSliceAlias() + return t[0] //want "sliced into" +} diff --git a/util/typeshelper/typeshelper.go b/util/typeshelper/typeshelper.go index 7ae5cf31..7044717e 100644 --- a/util/typeshelper/typeshelper.go +++ b/util/typeshelper/typeshelper.go @@ -75,99 +75,137 @@ func IsSlice(t types.Type) bool { } } -// IsDeeplyArray returns true if `t` is of array type, including -// transitively through Named types +// IsDeeplyArray returns true if `t` is of array type, including transitively through named +// types and aliases, as well as type parameters whose type sets contain only array types. func IsDeeplyArray(t types.Type) bool { - switch tt := UnwrapPtr(t).(type) { - case *types.Array: - return true - case *types.Named: - return IsDeeplyArray(tt.Underlying()) - } - return false + return underlyingIs[*types.Array](t) } -// IsDeeplySlice returns true if `t` is of slice type, including -// transitively through Named types -func IsDeeplySlice(t types.Type) bool { - if IsSlice(t) { +// IsDeeplyArrayOrArrayPtr is like IsDeeplyArray, but additionally accepts pointers to arrays +// (again resolving named types, aliases, and type parameters). Slice expressions and range +// statements auto-dereference pointers to arrays, so for them an operand of either type +// behaves like an array. +func IsDeeplyArrayOrArrayPtr(t types.Type) bool { + return underlyingAlwaysSatisfies(t, func(u types.Type) bool { + if ptr, ok := u.(*types.Pointer); ok { + u = ptr.Elem().Underlying() + } + _, ok := u.(*types.Array) + return ok + }) +} + +// underlyingIs reports whether the underlying type of `t` (resolved as described in +// underlyingAlwaysSatisfies) is a T. +func underlyingIs[T types.Type](t types.Type) bool { + return underlyingAlwaysSatisfies(t, func(u types.Type) bool { + _, ok := u.(T) + return ok + }) +} + +// underlyingAlwaysSatisfies reports whether the underlying type of `t` satisfies pred. Named +// types and aliases are resolved via Underlying(). For type parameters, the underlying type of +// every term in the constraint's type set must satisfy pred. Since the elements of a constraint +// interface intersect, this is conservative: the actual type set can only be smaller than the +// enumerated terms, and type sets we cannot enumerate (such as method-only constraints) yield +// false. +func underlyingAlwaysSatisfies(t types.Type, pred func(types.Type) bool) bool { + if t == nil { + return false + } + if tp, ok := types.Unalias(t).(*types.TypeParam); ok { + iface, isIface := tp.Constraint().Underlying().(*types.Interface) + if !isIface { + return false + } + terms := constraintTerms(iface) + if len(terms) == 0 { + return false + } + for _, term := range terms { + if !pred(term.Underlying()) { + return false + } + } return true } - if t, ok := t.(*types.Named); ok { - return IsDeeplySlice(t.Underlying()) + return pred(t.Underlying()) +} + +// constraintTerms returns the types of all type terms of the constraint interface `iface`, +// recursing into embedded interfaces, both standalone (`interface{ Elem }`) and as union terms +// (`interface{ Elem | ~[8]int }` -- go/types does not flatten interface-typed union terms, and +// such terms are necessarily method-less per the spec). Method-only elements contribute no terms. +func constraintTerms(iface *types.Interface) []types.Type { + var terms []types.Type + for i := 0; i < iface.NumEmbeddeds(); i++ { + switch e := types.Unalias(iface.EmbeddedType(i)).(type) { + case *types.Union: + for j := 0; j < e.Len(); j++ { + if emb, isIface := e.Term(j).Type().Underlying().(*types.Interface); isIface { + terms = append(terms, constraintTerms(emb)...) + } else { + terms = append(terms, e.Term(j).Type()) + } + } + default: + if emb, isIface := e.Underlying().(*types.Interface); isIface { + terms = append(terms, constraintTerms(emb)...) + } else { + terms = append(terms, e) + } + } } - return false + return terms +} + +// IsDeeplySlice returns true if `t` is of slice type, including transitively through named +// types and aliases, as well as type parameters whose type sets contain only slice types. +func IsDeeplySlice(t types.Type) bool { + return underlyingIs[*types.Slice](t) } -// IsDeeplyMap returns true if `t` is of map type, including -// transitively through Named types +// IsDeeplyMap returns true if `t` is of map type, including transitively through named types +// and aliases, as well as type parameters whose type sets contain only map types. func IsDeeplyMap(t types.Type) bool { - if _, ok := t.(*types.Map); ok { - return true - } - if t, ok := t.(*types.Named); ok { - return IsDeeplyMap(t.Underlying()) - } - return false + return underlyingIs[*types.Map](t) } -// IsDeeplyPtr returns true if `t` is of pointer type, including -// transitively through Named types +// IsDeeplyPtr returns true if `t` is of pointer type, including transitively through named +// types and aliases, as well as type parameters whose type sets contain only pointer types. func IsDeeplyPtr(t types.Type) bool { - if _, ok := t.(*types.Pointer); ok { - return true - } - if t, ok := t.(*types.Named); ok { - return IsDeeplyPtr(t.Underlying()) - } - return false + return underlyingIs[*types.Pointer](t) } -// IsDeeplyChan returns true if `t` is of channel type, including -// transitively through Named types +// IsDeeplyChan returns true if `t` is of channel type, including transitively through named +// types and aliases, as well as type parameters whose type sets contain only channel types. func IsDeeplyChan(t types.Type) bool { - if _, ok := t.(*types.Chan); ok { - return true - } - if t, ok := t.(*types.Named); ok { - return IsDeeplyChan(t.Underlying()) - } - return false + return underlyingIs[*types.Chan](t) } -// AsDeeplyStruct returns underlying struct type if the type is struct type or a pointer to a struct type -// returns nil otherwise +// AsDeeplyStruct returns the underlying struct type if `typ` is a struct or a pointer to a +// named struct (resolving named types and aliases). Returns nil otherwise. +// Note: pointer-to-anonymous-struct is intentionally excluded — the struct-init analyzer does +// not yet handle anonymous struct initialization. func AsDeeplyStruct(typ types.Type) *types.Struct { - if typ, ok := typ.(*types.Struct); ok { - return typ - } - - if typ, ok := typ.(*types.Named); ok { - if resType, ok := typ.Underlying().(*types.Struct); ok { - return resType - } + if s, ok := typ.Underlying().(*types.Struct); ok { + return s } - - if ptType, ok := typ.(*types.Pointer); ok { - if namedType, ok := types.Unalias(ptType.Elem()).(*types.Named); ok { - if resType, ok := namedType.Underlying().(*types.Struct); ok { - return resType + if ptr, ok := types.Unalias(typ).(*types.Pointer); ok { + if named, ok := types.Unalias(ptr.Elem()).(*types.Named); ok { + if s, ok := named.Underlying().(*types.Struct); ok { + return s } } } return nil } -// IsDeeplyInterface returns true if `t` is of struct type, including -// transitively through Named types +// IsDeeplyInterface returns true if `t` is of interface type, including transitively through +// named types and aliases, as well as type parameters whose type sets contain only interface types. func IsDeeplyInterface(t types.Type) bool { - if _, ok := t.(*types.Interface); ok { - return true - } - if t, ok := t.(*types.Named); ok { - return IsDeeplyInterface(t.Underlying()) - } - return false + return underlyingIs[*types.Interface](t) } // IsPointer checks whether the type `t` is an explicit or implicit pointer type, which could also be of deep type. diff --git a/util/typeshelper/typeshelper_test.go b/util/typeshelper/typeshelper_test.go index 31c9b06a..fe50a988 100644 --- a/util/typeshelper/typeshelper_test.go +++ b/util/typeshelper/typeshelper_test.go @@ -15,6 +15,8 @@ package typeshelper import ( + "go/ast" + "go/parser" "go/token" "go/types" "testing" @@ -22,6 +24,93 @@ import ( "github.com/stretchr/testify/require" ) +func TestIsDeeplyArray(t *testing.T) { + t.Parallel() + + const src = `package testpkg + +type NamedArray [8]int +type NamedArray2 NamedArray +type AliasArray = [8]int +type NamedArrayPtr *[8]int +type ArrayConstraint interface{ ~[8]int } +type ArrayConstraint16 interface{ ~[16]int } + +var ( + Array [8]int + Slice []int + NamedArr NamedArray + NamedArr2 NamedArray2 + AliasArr AliasArray + Ptr *[8]int + NamedPtr NamedArrayPtr + Int int + PtrToSlice *[]int +) + +func Generic[A ~[8]int, E ArrayConstraint, U ~[8]int | ~[16]int, X ~[8]int | ~[]int, S ~[]int, M any, IU ArrayConstraint | ArrayConstraint16, IX ArrayConstraint | ~[]int, P ~*[8]int](a A, e E, u U, x X, s S, m M, iu IU, ix IX, p P) {} +` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", src, 0) + require.NoError(t, err) + pkg, err := (&types.Config{}).Check("testpkg", fset, []*ast.File{f}, nil) + require.NoError(t, err) + + typeOf := func(name string) types.Type { + obj := pkg.Scope().Lookup(name) + require.NotNil(t, obj, "object %q not found", name) + return obj.Type() + } + params := pkg.Scope().Lookup("Generic").(*types.Func).Signature().Params() + typeParamOf := func(name string) types.Type { + for i := 0; i < params.Len(); i++ { + if p := params.At(i); p.Name() == name { + return p.Type() + } + } + require.Failf(t, "parameter not found", "parameter %q", name) + return nil + } + + tests := []struct { + name string + typ types.Type + wantArray bool + wantOrArrayPtr bool + }{ + {"Nil", nil, false, false}, + {"Array", typeOf("Array"), true, true}, + {"Slice", typeOf("Slice"), false, false}, + {"NamedArray", typeOf("NamedArr"), true, true}, + {"NamedArrayOfNamedArray", typeOf("NamedArr2"), true, true}, + {"AliasArray", typeOf("AliasArr"), true, true}, + {"PtrToArray", typeOf("Ptr"), false, true}, + {"NamedPtrToArray", typeOf("NamedPtr"), false, true}, + {"Int", typeOf("Int"), false, false}, + {"PtrToSlice", typeOf("PtrToSlice"), false, false}, + {"TypeParamArray", typeParamOf("a"), true, true}, + {"TypeParamEmbeddedArrayConstraint", typeParamOf("e"), true, true}, + {"TypeParamArrayUnion", typeParamOf("u"), true, true}, + {"TypeParamMixedUnion", typeParamOf("x"), false, false}, + {"TypeParamSlice", typeParamOf("s"), false, false}, + {"TypeParamAny", typeParamOf("m"), false, false}, + // Unions whose terms are themselves (method-less) interfaces are not flattened by + // go/types, so constraintTerms must recurse into them. + {"TypeParamInterfaceUnionArrays", typeParamOf("iu"), true, true}, + {"TypeParamInterfaceUnionMixed", typeParamOf("ix"), false, false}, + {"TypeParamPtrToArray", typeParamOf("p"), false, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.wantArray, IsDeeplyArray(tt.typ), "IsDeeplyArray(%v)", tt.typ) + require.Equal(t, tt.wantOrArrayPtr, IsDeeplyArrayOrArrayPtr(tt.typ), "IsDeeplyArrayOrArrayPtr(%v)", tt.typ) + }) + } +} + func TestIsIterType(t *testing.T) { t.Parallel()