Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions hook/hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"regexp"

"go.uber.org/nilaway/util/analysishelper"
"go.uber.org/nilaway/util/asthelper"
"go.uber.org/nilaway/util/typeshelper"
)

Expand Down Expand Up @@ -54,18 +55,24 @@ type trustedSig struct {
// - _func: match enclosing "<pkg path>". E.g., for `assert.Error(err)`, path = github.com/stretchr/testify/assert
// - _method: match "<pkg path>.<struct name>". E.g., for `u.Require().Error(err)`, path = github.com/stretchr/testify/require.Assertions
//
// Functions of dot-imported packages are also matched: for `Error(err)` with
// `import . "github.com/stretchr/testify/assert"`, the called identifier still resolves to the
// function object in the imported package, so the path matching is identical. Function values
// (e.g., `f := assert.Error; f(t, err)`) resolve to variables rather than functions and are
// hence never matched.
//
// Trusted package-level variables (kind _var) are matched separately via matchSel, since they are
// read as bare selectors rather than calls.
func (t *trustedSig) matchCall(pass *analysishelper.EnhancedPass, call *ast.CallExpr) bool {
if t.kind != _func && t.kind != _method {
return false
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok || !t.nameRegex.MatchString(sel.Sel.Name) {
ident := asthelper.FuncIdentFromCallExpr(call)
if ident == nil || !token.IsExported(ident.Name) || !t.nameRegex.MatchString(ident.Name) {
return false
}

funcObj, ok := pass.TypesInfo.ObjectOf(sel.Sel).(*types.Func)
funcObj, ok := pass.TypesInfo.ObjectOf(ident).(*types.Func)
if !ok || funcObj.Pkg() == nil {
return false
}
Expand Down Expand Up @@ -99,6 +106,10 @@ func (t *trustedSig) matchCall(pass *analysishelper.EnhancedPass, call *ast.Call
// This is intentionally independent of how the variable is later used: a read like `os.Stdout`,
// `os.Stdout.Write(...)` (as a method receiver), or `os.Args[0]` (as an index operand) all parse the
// bare selector as a producer, so all are covered here without involving matchCall.
//
// Unlike matchCall, dot-imported variable reads (a bare `Stdout` under `import . "os"`) are
// intentionally out of scope: they parse as bare identifiers, which are routed elsewhere by the
// assertion tree and never reach this matcher.
func (t *trustedSig) matchSel(pass *analysishelper.EnhancedPass, sel *ast.SelectorExpr) bool {
if t.kind != _var || !t.nameRegex.MatchString(sel.Sel.Name) {
return false
Expand Down
35 changes: 35 additions & 0 deletions hook/replace_conditional.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,28 @@ var _errorAsAction replaceConditionalAction = func(_ *analysishelper.EnhancedPas
}
}

// _assertConditionalAction replaces bool-returning testify assertions used as conditionals, e.g.,
// `if assert.NoError(t, err) {...}`, with `<call> && <implied expr>` (here:
// `assert.NoError(t, err) && err == nil`). The implied expression is the same one `_splitBlockOn`
// derives for the call in statement position, so the per-assertion semantics (including the
// asserted argument's position) are defined only there. The assertion returns true iff it passed,
// so the implied expression holds in the then-branch; the else-branch gains no information from
// the conjunction, which is conservative. Note that, unlike the statement-position modeling in
// `_splitBlockOn`, no assumption about test termination is involved, so this is sound for the
// non-fatal `assert` package as well.
var _assertConditionalAction replaceConditionalAction = func(pass *analysishelper.EnhancedPass, call *ast.CallExpr) ast.Expr {
implied := SplitBlockOn(pass, call)
if implied == nil {
return nil
}
return &ast.BinaryExpr{
X: call,
Op: token.LAND,
OpPos: call.Pos(),
Y: implied,
}
}

var _replaceConditionals = map[trustedSig]replaceConditionalAction{
{
kind: _func,
Expand All @@ -80,4 +102,17 @@ var _replaceConditionals = map[trustedSig]replaceConditionalAction{
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/cockroachdb/errors$`),
nameRegex: regexp.MustCompile(`^As$`),
}: _errorAsAction,

// Bool-returning testify assertions used as conditionals. `require` is absent since its
// functions do not return values and hence cannot appear in a conditional.
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/stretchr/testify/assert$`),
nameRegex: regexp.MustCompile(`^(Nil(f)?|NotNil(f)?|NoError(f)?|Error(f)?|ErrorContains(f)?|EqualError(f)?|True(f)?|False(f)?)$`),
}: _assertConditionalAction,
{
kind: _method,
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/stretchr/testify/(suite\.Suite|assert\.Assertions)$`),
nameRegex: regexp.MustCompile(`^(Nil(f)?|NotNil(f)?|NoError(f)?|Error(f)?|ErrorContains(f)?|EqualError(f)?|True(f)?|False(f)?)$`),
}: _assertConditionalAction,
}
127 changes: 122 additions & 5 deletions hook/split_blocks_on.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"regexp"

"go.uber.org/nilaway/util/analysishelper"
"go.uber.org/nilaway/util/asthelper"
"go.uber.org/nilaway/util/typeshelper"
)

Expand Down Expand Up @@ -80,6 +81,93 @@ var negatedSelfExpr splitBlockOnAction = func(_ *analysishelper.EnhancedPass, ca
}
}

// boolOrErrorExpr handles assertion arguments declared as `interface{}`, e.g., gotest.tools'
// `assert.Assert(t, comparison BoolOrComparison)`, whose argument may be:
// - a boolean expression, e.g., `assert.Assert(t, x != nil)`: behaves like selfExpr;
// - an error value, e.g., `assert.Assert(t, err)`: behaves like nilBinaryExpr, since a nil error
// means success while a non-nil error fails the assertion. Only interface types (`error`
// itself or interfaces embedding it) qualify: a nil value of such a type stays nil when passed
// as `interface{}`, whereas a concrete error type would be wrapped in a non-nil interface and
// always fail the assertion (even a typed-nil pointer);
// - anything else (e.g., a `cmp.Comparison` closure): no narrowing is applied.
var boolOrErrorExpr splitBlockOnAction = func(pass *analysishelper.EnhancedPass, call *ast.CallExpr, argIndex int) ast.Expr {
if argIndex < 0 || argIndex >= len(call.Args) {
return nil
}
if isBoolExpr(pass, call.Args[argIndex]) {
return selfExpr(pass, call, argIndex)
}
t := pass.TypesInfo.TypeOf(call.Args[argIndex])
if t == nil {
return nil
}
if _, ok := t.Underlying().(*types.Interface); ok && typeshelper.ImplementsError(t) {
return nilBinaryExpr(pass, call, argIndex)
}
return nil
}

// _goconveyAssertions matches the package paths where goconvey's `Should*` assertions are
// defined: the `convey` package itself (which re-exports them as package-level variables, e.g.,
// `var ShouldBeNil = assertions.ShouldBeNil`) and the underlying assertions package (both its
// current `smarty` and historical `smartystreets` homes, for users importing it directly).
var _goconveyAssertions = regexp.MustCompile(`^(stubs/)?(github\.com/smartystreets/goconvey/convey|github\.com/smarty(streets)?/assertions)$`)

// goconveySoExpr handles goconvey's `So(actual, assertion, expected...)`, where the narrowing fact is
// determined by the assertion argument rather than the called function: e.g.,
// `So(err, ShouldBeNil)` implies `err == nil` afterwards. The assertion argument is resolved to
// its package-level object (a var re-exported by `convey`, or a function of the assertions
// package), and only the nilability-relevant assertions are modeled; any other assertion (or a
// locally-defined custom one) yields no narrowing.
var goconveySoExpr splitBlockOnAction = func(pass *analysishelper.EnhancedPass, call *ast.CallExpr, argIndex int) ast.Expr {
// The assertion argument sits right after the actual expression.
if argIndex+1 >= len(call.Args) {
return nil
}
var ident *ast.Ident
switch assert := call.Args[argIndex+1].(type) {
case *ast.Ident:
ident = assert
case *ast.SelectorExpr:
ident = assert.Sel
default:
return nil
}
obj := pass.TypesInfo.ObjectOf(ident)
if obj == nil || obj.Pkg() == nil || !_goconveyAssertions.MatchString(obj.Pkg().Path()) {
return nil
}

// For the boolean assertions, the actual argument is declared `interface{}`; only narrow
// when it is statically a boolean expression (anything else fails the assertion at runtime
// anyway).
switch obj.Name() {
case "ShouldBeNil":
return nilBinaryExpr(pass, call, argIndex)
case "ShouldNotBeNil", "ShouldBeError":
return nonnilBinaryExpr(pass, call, argIndex)
case "ShouldBeTrue":
if isBoolExpr(pass, call.Args[argIndex]) {
return selfExpr(pass, call, argIndex)
}
case "ShouldBeFalse":
if isBoolExpr(pass, call.Args[argIndex]) {
return negatedSelfExpr(pass, call, argIndex)
}
}
return nil
}

// isBoolExpr reports whether the expression is statically of boolean type.
func isBoolExpr(pass *analysishelper.EnhancedPass, expr ast.Expr) bool {
t := pass.TypesInfo.TypeOf(expr)
if t == nil {
return false
}
basic, ok := t.Underlying().(*types.Basic)
return ok && basic.Kind() == types.Bool
}

// The constant (enum) values below represent the possible values of an expected expression in a comparison
// E.g., `Equal(1, len(s))`, where `1` is the expected expression and is assigned the value `_greaterThanZero`.
// E.g., `Equal(nil, err)`, where `nil` is the expected expression and is assigned the value `_nil`.
Expand Down Expand Up @@ -196,11 +284,11 @@ var requireZeroComparators splitBlockOnAction = func(pass *analysishelper.Enhanc

// generateComparators generates comparators based on the semantics of the function.
func generateComparators(call *ast.CallExpr, actualExpr ast.Expr, actualExprIndex int, expectedVal expectedValue) ast.Expr {
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
ident := asthelper.FuncIdentFromCallExpr(call)
if ident == nil {
return nil
}
funcName := sel.Sel.Name
funcName := ident.Name

// Now, based on the semantics of the function, we can create artificial nonnil checks for
// the following cases.
Expand Down Expand Up @@ -299,7 +387,7 @@ var _splitBlockOn = map[trustedSig]struct {
{
kind: _method,
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/stretchr/testify/(suite\.Suite|assert\.Assertions|require\.Assertions)$`),
nameRegex: regexp.MustCompile(`^(NotNil(f)?|Error(f)?)$`),
nameRegex: regexp.MustCompile(`^(NotNil(f)?|Error(f)?|ErrorContains(f)?|EqualError(f)?)$`),
}: {action: nonnilBinaryExpr, argIndex: 0},
{
kind: _method,
Expand Down Expand Up @@ -331,7 +419,7 @@ var _splitBlockOn = map[trustedSig]struct {
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/stretchr/testify/(assert|require)$`),
nameRegex: regexp.MustCompile(`^(NotNil(f)?|Error(f)?)$`),
nameRegex: regexp.MustCompile(`^(NotNil(f)?|Error(f)?|ErrorContains(f)?|EqualError(f)?)$`),
}: {action: nonnilBinaryExpr, argIndex: 1},
{
kind: _func,
Expand Down Expand Up @@ -363,4 +451,33 @@ var _splitBlockOn = map[trustedSig]struct {
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/stretchr/testify/(suite\.Suite|assert\.Assertions|require\.Assertions)$`),
nameRegex: regexp.MustCompile(`^(Empty(f)?|NotEmpty(f)?)$`),
}: {action: requireZeroComparators, argIndex: 0},

// `gotest.tools/v3/assert`, as well as its legacy v1/v2 form `gotest.tools/assert` with
// identical semantics. Note that `ErrorIs` is deliberately NOT modeled with nonnil narrowing:
// `errors.Is(nil, nil)` is true, so `assert.ErrorIs(t, err, nil)` can pass with a nil error.
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?gotest\.tools(/v3)?/assert$`),
nameRegex: regexp.MustCompile(`^NilError$`),
}: {action: nilBinaryExpr, argIndex: 1},
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?gotest\.tools(/v3)?/assert$`),
nameRegex: regexp.MustCompile(`^(Error|ErrorContains)$`),
}: {action: nonnilBinaryExpr, argIndex: 1},
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?gotest\.tools(/v3)?/assert$`),
nameRegex: regexp.MustCompile(`^Assert$`),
}: {action: boolOrErrorExpr, argIndex: 1},

// `github.com/smartystreets/goconvey/convey`, which is typically dot-imported. Under the
// default `FailureHalts` mode, a failed `So` panics and is recovered by the `Convey` runner,
// halting the enclosing scope; the opt-in `FailureContinues` mode is over-approximated the
// same way as testify's non-fatal `assert` (see the comment on this table).
{
kind: _func,
enclosingRegex: regexp.MustCompile(`^(stubs/)?github\.com/smartystreets/goconvey/convey$`),
nameRegex: regexp.MustCompile(`^So$`),
}: {action: goconveySoExpr, argIndex: 0},
}
81 changes: 81 additions & 0 deletions testdata/src/go.uber.org/trustedfunc/conditional.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright (c) 2026 Uber Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package trustedfunc

import (
"testing"

"stubs/github.com/stretchr/testify/assert"
)

// testAssertConditional tests bool-returning testify assertions used as conditionals, e.g.,
// `if assert.NoError(t, err) {...}`: the assertion returns true iff it passed, so the implied
// fact holds inside the then-branch (with no assumption about test termination), while the
// else-branch and the code after the conditional gain no information.
//
// nilable(x)
func testAssertConditional(t *testing.T, x any, a *assert.Assertions) any {
switch 0 {
case 1:
y, err := errs()
if assert.NoError(t, err) {
return y
}
case 2:
// No narrowing survives past the conditional.
y, err := errs()
if assert.NoError(t, err) {
consume(y)
}
return y //want "returned"
case 3:
_, err := errs()
if assert.Error(t, err) {
takesNonnil(err)
}
case 4:
if assert.NotNil(t, x) {
takesNonnil(x)
}
takesNonnil(x) //want "passed"
case 5:
// The narrowing direction for `Nil` is nil, so x must not be treated as nonnil.
if assert.Nil(t, x) {
takesNonnil(x) //want "passed"
}
case 6:
if assert.True(t, x != nil) {
return x
}
case 7:
if assert.False(t, x == nil) {
return x
}
case 8:
// Method form on `assert.Assertions`.
y, err := errs()
if a.NoError(err) {
return y
}
case 9:
// The `ok := ...; if ok` form is recognized as well.
y, err := errs()
ok := assert.NoError(t, err)
if ok {
return y
}
}
return 0
}
Loading