From 2ed5bac364469abd87379ef9f81a5dfeb4dd88ff Mon Sep 17 00:00:00 2001 From: Tom van der Woerdt Date: Fri, 12 Jun 2026 14:09:30 -0400 Subject: [PATCH] Re-insert type switch case conditions into the CFG The CFG builder (golang.org/x/tools/go/cfg) represents each case type of a type switch as an empty two-successor branching block, with no expression for the implicit per-case type assertion. NilAway re-inserts the lost conditions for value switches (markSwitchStatements) but had nothing equivalent for type switches, so `case nil` arms were not recognized as nil guards: ptr, err := f() switch err.(type) { case nil: _ = *ptr // FP: err is provably nil here, so ptr is non-nil } Real-world hits include gotest.tools/assert.NilError (reported inside gotest.tools, where it cannot be suppressed) and protoreflect.ValueOf(nil). The new markTypeSwitchStatements pass synthesizes `operand == nil` / `operand != nil` conditions for nil and interface-typed case arms, and replicates the binding assignment (`x := v.(type)`) into each case body so that assertions on `x` transfer to `v` while per-branch knowledge is available; see its doc comment for the full semantics and soundness reasoning. Arms are classified via the type checker (EnhancedPass.IsNil now consults TypesInfo before the name-based fallback needed for synthesized identifiers), which also hardens source-level nil checks against a user-shadowed `nil`. PR #366 previously attempted this by matching `(x.(type)) == nil`, but no such conditions are ever synthesized, so its testdata passed vacuously. That dead special case is removed and the testdata made genuinely fallible, then extended with soundness tests: default arms, multi-type clauses, concrete and type-parameter arms, parenthesized `case (nil)`, and a shadowed `nil` type (package shadownil). On the stdlib golden test this removes 18 errors (database/sql, go/printer, go/types, html/template, text/template) and adds 7, for a net 2269 -> 2258. The additions are the same go/printer nil sources re-surfacing at a different dereference point: the interface-typed arm (ast.Expr) of printNode's type switch is now discharged, so the flow through the concrete-typed *ast.File arm is reported instead. --- assertion/function/assertiontree/backprop.go | 5 + .../assertiontree/rich_check_effect.go | 15 +- assertion/function/preprocess/cfg.go | 195 +++++++++++++++++- nilaway_test.go | 2 +- .../typeswitch/shadownil/shadownil.go | 26 +++ .../errorreturn/typeswitch/typeswitch.go | 140 +++++++++++++ util/analysishelper/pass.go | 17 +- 7 files changed, 369 insertions(+), 31 deletions(-) create mode 100644 testdata/src/go.uber.org/errorreturn/typeswitch/shadownil/shadownil.go diff --git a/assertion/function/assertiontree/backprop.go b/assertion/function/assertiontree/backprop.go index 07f3c2a0..5cc36bc8 100644 --- a/assertion/function/assertiontree/backprop.go +++ b/assertion/function/assertiontree/backprop.go @@ -532,6 +532,11 @@ func backpropAcrossRange(rootNode *RootAssertionNode, lhs []ast.Expr, rhs ast.Ex // TypesInfo. Uses will give a fresh `types.Var` at every usage site. This is why we have to // inspect the assertion tree for any variables that match the symbolic type switch variable // without being able to compare the identity of `types.Var` instances as we usually do. +// +// NOTE: the preprocessor replicates the type-switch Assign node into every case body (see +// preprocess.markTypeSwitchStatements), so this function can process the same assignment +// multiple times along a single path. That replication relies on this function being +// idempotent: occurrences that find no assertions on the bound variable must remain no-ops. // nonnil(lhs, rhs) func backpropAcrossTypeSwitch(rootNode *RootAssertionNode, lhs *ast.Ident, rhs ast.Expr) error { // First, make a copy of the children array to iterate over, as we will mutate it. diff --git a/assertion/function/assertiontree/rich_check_effect.go b/assertion/function/assertiontree/rich_check_effect.go index 6b74ed5d..afe75eb2 100644 --- a/assertion/function/assertiontree/rich_check_effect.go +++ b/assertion/function/assertiontree/rich_check_effect.go @@ -542,17 +542,10 @@ func nodeAssignsAny(rootNode *RootAssertionNode, node ast.Node, exprs ...Trackab // variable `checksVar`. Note that because of preprocessing done in `restructureBlock` from // `preprocess_blocks.go`, this suffices to handle cases such as `nil != checksVar` as well. func exprIsPositiveNilCheck(rootNode *RootAssertionNode, expr ast.Expr, checksExpr TrackableExpr) bool { - if binExpr, ok := expr.(*ast.BinaryExpr); ok && binExpr.Op == token.EQL && asthelper.IsLiteral(binExpr.Y, "nil") { - // Standard case: X == nil - if exprMatchesTrackableExpr(rootNode, binExpr.X, checksExpr) { - return true - } - // Special case: type-switch guard rewritten as "(x.(type)) == nil". - // In such cases, the BinaryExpr.X will be a *ast.TypeAssertExpr whose Type is nil, - // and we should treat it as if we are checking "x == nil". - if ta, ok := binExpr.X.(*ast.TypeAssertExpr); ok && ta.Type == nil { - return exprMatchesTrackableExpr(rootNode, ta.X, checksExpr) - } + if binExpr, ok := expr.(*ast.BinaryExpr); ok && binExpr.Op == token.EQL && rootNode.Pass().IsNil(binExpr.Y) { + // Note that this also covers `case nil` arms of type switches: the preprocessor + // synthesizes a canonical `x == nil` condition for them (see markTypeSwitchStatements). + return exprMatchesTrackableExpr(rootNode, binExpr.X, checksExpr) } return false } diff --git a/assertion/function/preprocess/cfg.go b/assertion/function/preprocess/cfg.go index bde0598a..e3e4ea9a 100644 --- a/assertion/function/preprocess/cfg.go +++ b/assertion/function/preprocess/cfg.go @@ -18,6 +18,7 @@ import ( "fmt" "go/ast" "go/token" + "go/types" "go.uber.org/nilaway/hook" "go.uber.org/nilaway/util/asthelper" @@ -87,13 +88,15 @@ func (p *Preprocessor) CFG(graph *cfg.CFG, funcDecl *ast.FuncDecl) *cfg.CFG { } } - // Next, we need to re-insert information that is lost during CFG build for *ast.RangeStmt - // and *ast.SwitchStmt by iterating through all blocks. This requires knowing the links between - // the nodes contained within a block to their parents (*ast.RangeStmt or *ast.SwitchStmt nodes). + // Next, we need to re-insert information that is lost during CFG build for *ast.RangeStmt, + // *ast.SwitchStmt, and *ast.TypeSwitchStmt by iterating through all blocks. This requires + // knowing the links between the nodes contained within a block to their parents + // (*ast.RangeStmt, *ast.SwitchStmt, or *ast.TypeSwitchStmt nodes). // So, here establish the link and then do the work. - rangeChildren, switchChildren := collectChildren(funcDecl) + rangeChildren, switchChildren, typeSwitchChildren := collectChildren(funcDecl) markRangeStatements(graph, rangeChildren) markSwitchStatements(graph, switchChildren) + p.markTypeSwitchStatements(graph, typeSwitchChildren) // Please check the docstring of the following call to see why this is needed. // TODO: remove this once anonymous function support handles it naturally. @@ -414,12 +417,13 @@ func (p *Preprocessor) canonicalizeConditional(graph *cfg.CFG, thisBlock *cfg.Bl } } -// collectChildren establishes the links between the range / switch statement nodes and their child -// nodes. This is specifically designed for our preprocess function: when we rewrite the CFG to -// re-insert the lost information, we need to know if a block in CFG belongs to a certain range -// statement or switch statement AST node for retrieving lost information. -func collectChildren(funcDecl *ast.FuncDecl) (map[ast.Node]*ast.RangeStmt, map[ast.Node]*ast.SwitchStmt) { +// collectChildren establishes the links between the range / switch / type switch statement nodes +// and their child nodes. This is specifically designed for our preprocess function: when we +// rewrite the CFG to re-insert the lost information, we need to know if a block in CFG belongs to +// a certain range statement or (type) switch statement AST node for retrieving lost information. +func collectChildren(funcDecl *ast.FuncDecl) (map[ast.Node]*ast.RangeStmt, map[ast.Node]*ast.SwitchStmt, map[ast.Node]*ast.TypeSwitchStmt) { rangeChildren, switchChildren := make(map[ast.Node]*ast.RangeStmt), make(map[ast.Node]*ast.SwitchStmt) + typeSwitchChildren := make(map[ast.Node]*ast.TypeSwitchStmt) ast.Inspect(funcDecl, func(node ast.Node) bool { switch n := node.(type) { @@ -440,11 +444,16 @@ func collectChildren(funcDecl *ast.FuncDecl) (map[ast.Node]*ast.RangeStmt, map[a switchChildren[n.Tag] = n } switchChildren[n.Body] = n + case *ast.TypeSwitchStmt: + // The Assign field (either `x := v.(type)` or `v.(type)`) is the node that the CFG + // builder places in the branching block of the type switch, so it is the key we use to + // rediscover type switches in the CFG. + typeSwitchChildren[n.Assign] = n } return true }) - return rangeChildren, switchChildren + return rangeChildren, switchChildren, typeSwitchChildren } // markRangeStatements rewrites a cfg to reflect ranging loops - the assignments in a `for... range y {}` @@ -597,3 +606,169 @@ func markSwitchStatements(graph *cfg.CFG, switchChildren map[ast.Node]*ast.Switc } } } + +// markTypeSwitchStatements rewrites a CFG to re-insert the branch conditions of type switches. +// +// The CFG builder (golang.org/x/tools/go/cfg) represents a type switch +// +// switch x := v.(type) { +// case nil: ... +// case T1, T2: ... +// default: ... +// } +// +// as a chain of two-successor branching blocks -- one per case *type* -- where the first +// branching block ends with the Assign node (`x := v.(type)` or `v.(type)`) and the subsequent +// ones are empty, since the builder has no expression to represent the implicit type assertion +// performed by each case. This loses the nilability information carried by the case conditions. +// We re-insert that information by appending a synthesized condition to each branching block: +// +// - `case nil` arms get `v == nil`: a nil interface matches exactly this arm, and, conversely, +// any value reaching the false branch is non-nil. +// - case arms listing an *interface* type (excluding type parameters, see typeSwitchCaseCond) +// get `v != nil`: a nil interface value has no dynamic type, so it can match only `case nil` +// (or `default`); matching an interface type therefore guarantees the operand is non-nil. +// The false branch carries no information. Note that this would NOT be sound for the *bound +// variable* of concrete-typed arms (e.g., `case *int`), where the dynamic value can still be +// a typed nil pointer; since assertions on the bound variable are transferred to the operand +// `v` (see `backpropAcrossTypeSwitch`), we synthesize no condition for non-interface arms. +// +// Additionally, for type switches with a binding (`x := v.(type)`), the Assign node is +// replicated at the top of every case body. The CFG builder only places it once, *before* the +// branching: during backpropagation, assertions on the bound variable `x` would then only be +// transferred to the operand `v` (see `backpropAcrossTypeSwitch`) after the case branches have +// already merged, which is too late for the synthesized branch conditions above to act on them. +// Replicating the assignment into each case body mirrors the actual semantics -- each body +// begins with an implicit `x := v.(T)` -- and ensures the transfer happens while the per-branch +// knowledge is still available (this relies on `backpropAcrossTypeSwitch` being idempotent; see +// the note there). +// +// This allows the existing conditional-processing logic (`AddNilCheck` and the rich check +// effects, e.g., for `ptr, err := f(); switch err.(type) { case nil: ... }`) to handle type +// switches just like their `if err == nil` equivalents. +func (p *Preprocessor) markTypeSwitchStatements(graph *cfg.CFG, typeSwitchChildren map[ast.Node]*ast.TypeSwitchStmt) { + if len(typeSwitchChildren) == 0 { + return + } + for _, block := range graph.Blocks { + n := len(block.Nodes) + if n < 1 || len(block.Succs) != 2 { + continue + } + if stmt := typeSwitchChildren[block.Nodes[n-1]]; stmt != nil { + p.markTypeSwitch(block, stmt) + } + } +} + +// markTypeSwitch re-inserts the synthesized conditions (and replicated binding assignments) for +// a single type switch whose first branching block is `block`. See markTypeSwitchStatements for +// the semantics. +func (p *Preprocessor) markTypeSwitch(block *cfg.Block, stmt *ast.TypeSwitchStmt) { + operand := typeSwitchOperand(stmt) + if operand == nil { + return + } + // Only type switches with a binding need the Assign replicated into the case bodies. + bindingAssign, _ := stmt.Assign.(*ast.AssignStmt) + + // Walk the chain of branching blocks -- one per case type, in source order (mirroring the + // CFG builder) -- and append the synthesized conditions. + hasDefault := false + caseBlock := block + for _, clause := range stmt.Body.List { + cc := clause.(*ast.CaseClause) + if cc.List == nil { + // `default` clauses do not get their own branching block. + hasDefault = true + continue + } + for i, caseType := range cc.List { + if caseBlock != block && (len(caseBlock.Nodes) != 0 || len(caseBlock.Succs) != 2) { + // The CFG does not have the shape we expect (e.g., due to dead code + // elimination). Conservatively stop processing this type switch: this loses + // precision but never soundness. + return + } + if cond := p.typeSwitchCaseCond(operand, caseType); cond != nil { + caseBlock.Nodes = append(caseBlock.Nodes, cond) + } + // The true branch is the case body, which is shared by all case types of this + // clause (and only by them), so the binding only needs to be replicated once. + if bindingAssign != nil && i == 0 { + body := caseBlock.Succs[0] + body.Nodes = append([]ast.Node{bindingAssign}, body.Nodes...) + } + caseBlock = caseBlock.Succs[1] + } + } + // The false branch of the last case type is the default body, if one exists (the CFG + // builder always emits the default body there, regardless of its position in the source). + if bindingAssign != nil && hasDefault { + caseBlock.Nodes = append([]ast.Node{bindingAssign}, caseBlock.Nodes...) + } +} + +// typeSwitchOperand extracts the switched-on expression `v` from the Assign node of a type +// switch (`x := v.(type)` or a bare `v.(type)`), returning nil if the shape is unexpected. +func typeSwitchOperand(stmt *ast.TypeSwitchStmt) ast.Expr { + var expr ast.Expr + switch assign := stmt.Assign.(type) { + case *ast.AssignStmt: + if len(assign.Rhs) != 1 { + return nil + } + expr = assign.Rhs[0] + case *ast.ExprStmt: + expr = assign.X + default: + return nil + } + if typeAssert, ok := ast.Unparen(expr).(*ast.TypeAssertExpr); ok && typeAssert.Type == nil { + return typeAssert.X + } + return nil +} + +// typeSwitchCaseCond synthesizes the branch condition for a single case type of a type switch +// (see markTypeSwitchStatements for the semantics), or nil if the case type carries no nilability +// information about the operand. +// +// Note that the `operand != nil` conditions deliberately keep the non-canonical NEQ form, even +// though the preprocessor otherwise canonicalizes conditions to `x == nil` with swapped +// successors (see canonicalizeConditional): consumers of positive nil checks (e.g., the rich +// check effects via exprIsPositiveNilCheck) attach "operand is provably nil" semantics to the +// true branch of an `x == nil` condition, which would be unsound here -- *failing* to match an +// interface case does not imply the operand is nil. AddNilCheck handles the NEQ form directly, +// with a no-op false branch. +func (p *Preprocessor) typeSwitchCaseCond(operand, caseType ast.Expr) *ast.BinaryExpr { + // Note that pass.IsNil identifies `case nil` arms via the type checker, so a parenthesized + // `case (nil)` is handled, and a user-shadowed type named "nil" is correctly classified as a + // concrete arm instead. + if p.pass.IsNil(caseType) { + // We synthesize a bare `nil` identifier (rather than reusing caseType, which may be + // parenthesized) so that downstream consumers matching the canonical `x == nil` form + // recognize the condition. + return &ast.BinaryExpr{ + X: operand, + OpPos: caseType.Pos(), + Op: token.EQL, + Y: &ast.Ident{NamePos: caseType.Pos(), Name: "nil"}, + } + } + if t := p.pass.TypesInfo.TypeOf(caseType); t != nil && types.IsInterface(t) { + // A type parameter is an interface to the type system, but a value matching it can + // still be a typed nil pointer (e.g., T instantiated with *int) -- the same hazard as + // concrete-typed arms -- so it must not produce a non-nil condition. + if _, ok := types.Unalias(t).(*types.TypeParam); ok { + return nil + } + return &ast.BinaryExpr{ + X: operand, + OpPos: caseType.Pos(), + Op: token.NEQ, + Y: &ast.Ident{NamePos: caseType.Pos(), Name: "nil"}, + } + } + return nil +} diff --git a/nilaway_test.go b/nilaway_test.go index 56c83bfe..57b3c509 100644 --- a/nilaway_test.go +++ b/nilaway_test.go @@ -47,7 +47,7 @@ func TestNilAway(t *testing.T) { {name: "Inference", patterns: []string{"go.uber.org/inference"}}, {name: "Contracts", patterns: []string{"go.uber.org/contracts", "go.uber.org/contracts/namedtypes", "go.uber.org/contracts/inference"}}, {name: "TrustedFunc", patterns: []string{"go.uber.org/trustedfunc", "go.uber.org/trustedfunc/inference"}}, - {name: "ErrorReturn", patterns: []string{"go.uber.org/errorreturn", "go.uber.org/errorreturn/inference"}}, + {name: "ErrorReturn", patterns: []string{"go.uber.org/errorreturn", "go.uber.org/errorreturn/inference", "go.uber.org/errorreturn/typeswitch", "go.uber.org/errorreturn/typeswitch/shadownil"}}, {name: "Maps", patterns: []string{"go.uber.org/maps"}}, {name: "Slices", patterns: []string{"go.uber.org/slices", "go.uber.org/slices/inference"}}, {name: "Arrays", patterns: []string{"go.uber.org/arrays"}}, diff --git a/testdata/src/go.uber.org/errorreturn/typeswitch/shadownil/shadownil.go b/testdata/src/go.uber.org/errorreturn/typeswitch/shadownil/shadownil.go new file mode 100644 index 00000000..f5c59dc8 --- /dev/null +++ b/testdata/src/go.uber.org/errorreturn/typeswitch/shadownil/shadownil.go @@ -0,0 +1,26 @@ +// Package shadownil tests that the `case nil` arm detection in type switches is not fooled by a +// user-declared type named `nil` -- the predeclared identifier is shadowable. Here `case nil` +// matches the concrete struct type, so the operand is non-nil in that arm and, conversely, can +// still be nil in the default arm. +package shadownil + +type stringer interface{ String() string } + +type nil struct{} + +func (nil) String() string { return "type named nil" } + +func deref(v stringer) { + switch v.(type) { + case nil: + // v's dynamic type is the struct `nil`, i.e., v is non-nil here. This must NOT be + // confused with a check for the nil value: the other arms remain unguarded. + default: + println(v.String()) //want "called `String" + } +} + +func useDeref() { + var s stringer + deref(s) +} diff --git a/testdata/src/go.uber.org/errorreturn/typeswitch/typeswitch.go b/testdata/src/go.uber.org/errorreturn/typeswitch/typeswitch.go index fe9caac9..a0a36cc3 100644 --- a/testdata/src/go.uber.org/errorreturn/typeswitch/typeswitch.go +++ b/testdata/src/go.uber.org/errorreturn/typeswitch/typeswitch.go @@ -1,8 +1,13 @@ package typeswitch +import "errors" + var dummy int func aa() (*int, error) { + if dummy == 1 { + return nil, errors.New("some error") + } return new(int), nil } @@ -14,3 +19,138 @@ func bb() { } } +// Same as bb, but the dereference happens in the default arm, where err can be non-nil and hence +// ptr can be nil: the diagnostic must be preserved. +func bbDefault() { + ptr, err := aa() + switch err.(type) { + case nil: + default: + _ = *ptr //want "dereferenced" + } +} + +// A `case nil` arm must still be recognized as a guard when an ok-form type assertion on an +// *unrelated* variable precedes the switch (the gotest.tools/assert.NilError shape). +func check(t interface{}, v interface{}) { + if _, ok := t.(interface{ Helper() }); ok { + println("helper") + } + switch x := v.(type) { + case nil: + return + case error: + // Unreachable with v == nil: a nil interface value has no dynamic type, so it can only + // match `case nil` above. + println(x.Error()) + } +} + +func returnsNilError() error { + return nil +} + +func useCheck() { + check(0, returnsNilError()) +} + +// A literal nil argument flowing into a type switch with a `case nil` arm must not be reported. +func checkLitNil(v interface{}) { + switch x := v.(type) { + case nil: + return + case error: + println(x.Error()) + } +} + +func useCheckLitNil() { + checkLitNil(nil) +} + +// Matching any *interface* case arm (not just `case nil`) guarantees the operand is non-nil: a +// nil interface value has no dynamic type, so it can only match `case nil` or `default`. Hence no +// diagnostic here even without a `case nil` arm. +func checkNoNilArm(v interface{}) { + switch x := v.(type) { + case error: + println(x.Error()) + } +} + +func useCheckNoNilArm() { + checkNoNilArm(nil) +} + +// In the default arm, an ok-form type assertion on the bound variable keeps its use safe: no +// diagnostic. +func checkDefaultArm(v interface{}) { + switch x := v.(type) { + case error: + println(x.Error()) + default: + if y, ok := x.(interface{ String() string }); ok { + println(y.String()) + } + } +} + +type stringer interface{ String() string } + +// In the default arm, the operand can still be nil, so calling a method on the bound variable +// must be reported. +func checkDefaultArmDeref(v stringer) { + switch x := v.(type) { + case error: + println(x.Error()) + default: + println(x.String()) //want "called `String" + } +} + +func useCheckDefaultArmDeref() { + checkDefaultArmDeref(nil) +} + +// A multi-type clause containing nil must not mark the bound variable (which keeps the operand's +// static type) non-nil inside the body, while a subsequent interface arm is still recognized as +// non-nil. This also pins the chain walk over clauses listing multiple case types. +func checkMultiTypeClause(v stringer) { + switch x := v.(type) { + case nil, error: + println(x.(error).Error()) + case interface{ Len() int }: + println(x.Len()) // safe: nil can only match the `case nil, error` clause above + } +} + +// A concrete (non-interface) case arm must not mark the *bound variable* non-nil: the dynamic +// value can be a typed nil pointer even though the interface itself is non-nil. +func checkConcreteArm(v interface{}) *int { + switch x := v.(type) { + case *int: + return x + } + return new(int) +} + +// A parenthesized `case (nil):` must be treated identically to `case nil:`. +func bbParenNil() { + ptr, err := aa() + switch err.(type) { + case (nil): + _ = *ptr // safe: err is nil in this case, so ptr must be non-nil + } +} + +// A type-parameter case arm must not be treated as an interface arm: a value matching `case T` +// can still be a typed nil pointer (e.g., T instantiated with *int), just like a concrete-typed +// arm. This currently produces no diagnostic either way (NilAway's generics support does not yet +// track typed-nil flows through instantiations); the test pins that no condition is synthesized +// and nothing crashes. +func checkTypeParamArm[T stringer](v interface{}) { + switch x := v.(type) { + case T: + println(x.String()) + } +} diff --git a/util/analysishelper/pass.go b/util/analysishelper/pass.go index 4798fd41..5d1a0515 100644 --- a/util/analysishelper/pass.go +++ b/util/analysishelper/pass.go @@ -67,17 +67,16 @@ func (p *EnhancedPass) ConstInt(expr ast.Expr) (int64, bool) { return intValue, true } -// IsNil checks if the given expression evaluates to untyped nil at compile time. It also treats -// the identifier `nil` as nil too to support cases where we have inserted a fake identifier. +// IsNil checks if the given expression evaluates to untyped nil at compile time. The type +// checker is consulted first, so that expressions involving a user-shadowed `nil` (e.g., a type +// or variable named "nil") or extra parentheses are classified correctly; identifiers named +// `nil` that are absent from the type information (the fake identifiers we synthesize during +// preprocessing) are treated as nil as well. func (p *EnhancedPass) IsNil(expr ast.Expr) bool { - if asthelper.IsLiteral(expr, "nil") { - return true + if tv, ok := p.TypesInfo.Types[expr]; ok { + return tv.IsNil() } - tv, ok := p.TypesInfo.Types[expr] - if !ok { - return false - } - return tv.IsNil() + return asthelper.IsLiteral(expr, "nil") } // HumanReadablePosition modifies the Position's filename to be more human-friendly (truncated or relative to cwd).