diff --git a/fixer/fixer.go b/fixer/fixer.go new file mode 100644 index 0000000..b25daa0 --- /dev/null +++ b/fixer/fixer.go @@ -0,0 +1,271 @@ +package fixer + +import ( + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "strings" + + "github.com/unsaid-dev/goperf/rules" +) + +// Fix represents an auto-fixable change +type Fix struct { + File string + Line int + Original string + Fixed string + Rule string + Applied bool +} + +// Fixer handles automatic code fixes +type Fixer struct { + DryRun bool + Verbose bool +} + +// NewFixer creates a new fixer +func NewFixer(dryRun, verbose bool) *Fixer { + return &Fixer{ + DryRun: dryRun, + Verbose: verbose, + } +} + +// FixIssues attempts to fix the given issues +func (f *Fixer) FixIssues(issues []rules.Issue) []Fix { + var fixes []Fix + + // Group issues by file + byFile := make(map[string][]rules.Issue) + for _, issue := range issues { + byFile[issue.File] = append(byFile[issue.File], issue) + } + + for file, fileIssues := range byFile { + fileFixes := f.fixFile(file, fileIssues) + fixes = append(fixes, fileFixes...) 
+ } + + return fixes +} + +func (f *Fixer) fixFile(filename string, issues []rules.Issue) []Fix { + var fixes []Fix + + src, err := os.ReadFile(filename) + if err != nil { + return fixes + } + + fset := token.NewFileSet() + astFile, err := parser.ParseFile(fset, filename, src, parser.ParseComments) + if err != nil { + return fixes + } + + modified := false + lines := strings.Split(string(src), "\n") + + for _, issue := range issues { + fix := f.attemptFix(issue, astFile, fset, lines) + if fix != nil { + fixes = append(fixes, *fix) + if !f.DryRun && fix.Fixed != "" { + modified = true + } + } + } + + if modified && !f.DryRun { + // Format and write back + var buf bytes.Buffer + if err := format.Node(&buf, fset, astFile); err == nil { + os.WriteFile(filename, buf.Bytes(), 0644) + } + } + + return fixes +} + +func (f *Fixer) attemptFix(issue rules.Issue, file *ast.File, fset *token.FileSet, lines []string) *Fix { + switch issue.Rule { + case "string-concat-loop": + return f.fixStringConcat(issue, file, fset, lines) + case "unpreallocated-slice": + return f.fixUnpreallocatedSlice(issue, file, fset, lines) + case "missing-body-close": + return f.fixMissingBodyClose(issue, file, fset, lines) + case "context-leak": + return f.fixContextLeak(issue, file, fset, lines) + default: + // Return suggestion-only fix + return &Fix{ + File: issue.File, + Line: issue.Line, + Original: getLine(lines, issue.Line), + Fixed: "", // No auto-fix available + Rule: issue.Rule, + Applied: false, + } + } +} + +func (f *Fixer) fixStringConcat(issue rules.Issue, file *ast.File, fset *token.FileSet, lines []string) *Fix { + // Find the function containing this line + line := issue.Line + original := getLine(lines, line) + + // Suggest using strings.Builder + fix := &Fix{ + File: issue.File, + Line: line, + Original: original, + Rule: issue.Rule, + Applied: false, + } + + // Generate suggestion (actual AST modification is complex) + fix.Fixed = "// TODO: Replace with strings.Builder\n" + + "// 
var b strings.Builder\n" + + "// for ... { b.WriteString(s) }\n" + + "// result := b.String()" + + return fix +} + +func (f *Fixer) fixUnpreallocatedSlice(issue rules.Issue, file *ast.File, fset *token.FileSet, lines []string) *Fix { + line := issue.Line + original := getLine(lines, line) + + fix := &Fix{ + File: issue.File, + Line: line, + Original: original, + Rule: issue.Rule, + Applied: false, + } + + // Extract slice name from message + msg := issue.Message + start := strings.Index(msg, "'") + end := strings.LastIndex(msg, "'") + if start >= 0 && end > start { + sliceName := msg[start+1 : end] + fix.Fixed = fmt.Sprintf("%s = make([]T, 0, expectedSize) // Preallocate %s", sliceName, sliceName) + } + + return fix +} + +func (f *Fixer) fixMissingBodyClose(issue rules.Issue, file *ast.File, fset *token.FileSet, lines []string) *Fix { + line := issue.Line + original := getLine(lines, line) + + // Find the variable name from the message + msg := issue.Message + start := strings.Index(msg, "'") + end := strings.LastIndex(msg, "'") + + varName := "resp" + if start >= 0 && end > start { + varName = msg[start+1 : end] + } + + fix := &Fix{ + File: issue.File, + Line: line, + Original: original, + Rule: issue.Rule, + Applied: false, + Fixed: fmt.Sprintf("defer %s.Body.Close()", varName), + } + + return fix +} + +func (f *Fixer) fixContextLeak(issue rules.Issue, file *ast.File, fset *token.FileSet, lines []string) *Fix { + line := issue.Line + original := getLine(lines, line) + + // Extract cancel function name from message + msg := issue.Message + start := strings.Index(msg, "'") + end := strings.LastIndex(msg, "'") + + cancelName := "cancel" + if start >= 0 && end > start { + cancelName = msg[start+1 : end] + } + + fix := &Fix{ + File: issue.File, + Line: line, + Original: original, + Rule: issue.Rule, + Applied: false, + Fixed: fmt.Sprintf("defer %s()", cancelName), + } + + return fix +} + +func getLine(lines []string, lineNum int) string { + if lineNum > 0 && lineNum 
<= len(lines) { + return lines[lineNum-1] + } + return "" +} + +// PrintFixes displays the fixes in a readable format +func PrintFixes(fixes []Fix, dryRun bool) { + if len(fixes) == 0 { + fmt.Println("No auto-fixes available for the detected issues.") + return + } + + if dryRun { + fmt.Println("=== DRY RUN: Suggested fixes (no files modified) ===\n") + } else { + fmt.Println("=== Applied fixes ===\n") + } + + for _, fix := range fixes { + fmt.Printf("File: %s:%d\n", fix.File, fix.Line) + fmt.Printf("Rule: %s\n", fix.Rule) + if fix.Original != "" { + fmt.Printf("Original: %s\n", strings.TrimSpace(fix.Original)) + } + if fix.Fixed != "" { + fmt.Printf("Fix: %s\n", fix.Fixed) + } else { + fmt.Println("Fix: Manual intervention required - see issue suggestion") + } + fmt.Println() + } +} + +// GenerateDiff creates a unified diff for review +func GenerateDiff(fixes []Fix) string { + var buf bytes.Buffer + + for _, fix := range fixes { + if fix.Fixed == "" { + continue + } + + buf.WriteString(fmt.Sprintf("--- a/%s\n", fix.File)) + buf.WriteString(fmt.Sprintf("+++ b/%s\n", fix.File)) + buf.WriteString(fmt.Sprintf("@@ -%d,1 +%d,1 @@\n", fix.Line, fix.Line)) + buf.WriteString(fmt.Sprintf("-%s\n", fix.Original)) + buf.WriteString(fmt.Sprintf("+%s\n", fix.Fixed)) + buf.WriteString("\n") + } + + return buf.String() +} diff --git a/main.go b/main.go index 0657135..95b0cca 100644 --- a/main.go +++ b/main.go @@ -8,18 +8,21 @@ import ( "path/filepath" "strings" + "github.com/unsaid-dev/goperf/fixer" "github.com/unsaid-dev/goperf/reporter" "github.com/unsaid-dev/goperf/rules" ) var ( - rulesFlag = flag.String("rules", "all", "Comma-separated rules to run: algorithm,allocation,database,concurrency,io,cache,all") - formatFlag = flag.String("format", "console", "Output format: console, json") + rulesFlag = flag.String("rules", "all", "Comma-separated rules to run: algorithm,allocation,database,concurrency,io,cache,context,memory,benchmark,all") + formatFlag = flag.String("format", 
"console", "Output format: console, json, diff") failOnFlag = flag.String("fail-on", "", "Exit with code 1 if issues at this level or higher: low, medium, high, critical") contextFlag = flag.Int("context", 3, "Lines of context to show around issues") ignoreFlag = flag.String("ignore", "", "Comma-separated paths to ignore") verboseFlag = flag.Bool("verbose", false, "Show verbose output") versionFlag = flag.Bool("version", false, "Show version") + fixFlag = flag.Bool("fix", false, "Automatically fix issues where possible") + dryRunFlag = flag.Bool("dry-run", false, "Show fixes without applying them (use with --fix)") ) var version = "0.1.0" @@ -38,6 +41,8 @@ Examples: goperf --rules=algorithm ./internal/ # Only algorithm rules goperf --format=json ./... # JSON output for CI goperf --fail-on=high ./... # Exit 1 if high+ issues + goperf --fix --dry-run ./... # Preview auto-fixes + goperf --fix ./... # Apply auto-fixes Flags: `) @@ -45,17 +50,27 @@ Flags: fmt.Fprintf(os.Stderr, ` Rule Categories: algorithm - O(n²) loops, repeated linear searches - allocation - Unpreallocated slices, string concatenation - database - N+1 queries, SQL in loops + allocation - Unpreallocated slices, string concatenation, interface boxing + database - N+1 queries, SQL in loops, connection pool issues concurrency - Lock contention, unbuffered channels - io - Unbuffered I/O, sequential operations - cache - Repeated computations, missing memoization + io - Unbuffered I/O, HTTP body handling, response buffering + cache - Repeated regex/template compilation, JSON schema in loops + context - Missing timeouts, context leaks, context.Background in handlers + memory - Large struct copying, pprof in hot paths, heap escapes + benchmark - Functions with performance patterns that need benchmarks Severity Levels: critical - Will cause production issues high - Significant performance impact medium - Moderate impact, should fix low - Minor optimization opportunity + +Auto-Fix Support: + The following rules 
support auto-fix: + - string-concat-loop → strings.Builder suggestion + - unpreallocated-slice → make() with capacity + - missing-body-close → defer Body.Close() + - context-leak → defer cancel() `) } flag.Parse() @@ -105,11 +120,42 @@ Severity Levels: // Run analysis issues := analyzer.Analyze(files) + // Handle fix mode + if *fixFlag { + f := fixer.NewFixer(*dryRunFlag, *verboseFlag) + fixes := f.FixIssues(issues) + + if *formatFlag == "diff" { + fmt.Println(fixer.GenerateDiff(fixes)) + } else { + fixer.PrintFixes(fixes, *dryRunFlag) + } + + // Still show summary + if len(issues) > 0 { + fmt.Printf("\nTotal issues found: %d\n", len(issues)) + fixable := 0 + for _, fix := range fixes { + if fix.Fixed != "" { + fixable++ + } + } + fmt.Printf("Auto-fixable: %d\n", fixable) + } + return + } + // Report results var rep reporter.Reporter switch *formatFlag { case "json": rep = &reporter.JSONReporter{} + case "diff": + // Generate diff output even without --fix + f := fixer.NewFixer(true, *verboseFlag) + fixes := f.FixIssues(issues) + fmt.Println(fixer.GenerateDiff(fixes)) + return default: rep = &reporter.ConsoleReporter{Context: *contextFlag} } @@ -130,7 +176,7 @@ Severity Levels: func parseRules(rulesStr string) []string { if rulesStr == "all" { - return []string{"algorithm", "allocation", "database", "concurrency", "io", "cache"} + return []string{"algorithm", "allocation", "database", "concurrency", "io", "cache", "context", "memory", "benchmark"} } return strings.Split(rulesStr, ",") } diff --git a/rules/algorithm.go b/rules/algorithm.go index 740633a..4c683f7 100644 --- a/rules/algorithm.go +++ b/rules/algorithm.go @@ -11,6 +11,7 @@ func init() { } // NestedRangeRule detects O(n²) nested range loops +// Now smarter: recognizes map-based optimizations type NestedRangeRule struct{} func (r *NestedRangeRule) Name() string { return "nested-range" } @@ -20,42 +21,75 @@ func (r *NestedRangeRule) Check(file *ast.File, fset *token.FileSet, src []byte) var issues []Issue 
ast.Inspect(file, func(n ast.Node) bool { - outerRange, ok := n.(*ast.RangeStmt) + funcDecl, ok := n.(*ast.FuncDecl) if !ok { return true } - outerVar := getRangeVar(outerRange) + if funcDecl.Body == nil { + return true + } + + // Find maps that are populated before loops (lookup optimization) + lookupMaps := findLookupMaps(funcDecl.Body) - // Look for nested range inside this one - ast.Inspect(outerRange.Body, func(inner ast.Node) bool { - innerRange, ok := inner.(*ast.RangeStmt) + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + outerRange, ok := inner.(*ast.RangeStmt) if !ok { return true } - // Check if inner loop iterates over same or related collection - severity := SeverityMedium - innerVar := getRangeVar(innerRange) + outerVar := getRangeVar(outerRange) - // Same variable = definitely O(n²) - if outerVar != "" && outerVar == innerVar { - severity = SeverityHigh - } + // Look for nested range inside this one + ast.Inspect(outerRange.Body, func(innerNode ast.Node) bool { + innerRange, ok := innerNode.(*ast.RangeStmt) + if !ok { + return true + } + + innerVar := getRangeVar(innerRange) + + // Check if the inner loop uses a map lookup instead of linear search + // This is O(n*m) where m is O(1) = O(n), not O(n²) + if usesMapLookup(innerRange.Body, lookupMaps) { + // This is actually optimized - don't flag + return false + } + + // Check if inner loop iterates over same or related collection + severity := SeverityMedium + + // Same variable = definitely O(n²) + if outerVar != "" && outerVar == innerVar { + severity = SeverityHigh + } + + // Check if the loop body is trivial (few operations) + if isTrivalLoopBody(innerRange.Body) { + if severity == SeverityHigh { + severity = SeverityMedium + } else { + severity = SeverityLow + } + } + + pos := fset.Position(innerRange.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "Nested range loop detected - 
potential O(n²) complexity", + Why: "Nested iteration over collections scales quadratically. With 1000 items, this runs 1,000,000 times. With 10,000 items, 100,000,000 times.", + Fix: "Consider: (1) Building a map for O(1) lookups, (2) Using incremental/delta computation, (3) Sorting + binary search, (4) Breaking early when possible", + }) - pos := fset.Position(innerRange.Pos()) - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: severity, - Line: pos.Line, - Column: pos.Column, - Message: "Nested range loop detected - potential O(n²) complexity", - Why: "Nested iteration over collections scales quadratically. With 1000 items, this runs 1,000,000 times. With 10,000 items, 100,000,000 times.", - Fix: "Consider: (1) Building a map for O(1) lookups, (2) Using incremental/delta computation, (3) Sorting + binary search, (4) Breaking early when possible", + return false // Don't recurse into inner range }) - return false // Don't recurse into inner range + return true }) return true @@ -64,6 +98,81 @@ func (r *NestedRangeRule) Check(file *ast.File, fset *token.FileSet, src []byte) return issues } +// findLookupMaps finds maps that are populated before being used as lookups +func findLookupMaps(body *ast.BlockStmt) map[string]bool { + maps := make(map[string]bool) + + // Find map declarations + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + continue + } + + if len(call.Args) < 1 { + continue + } + + // Check if it's a map type + if _, isMap := call.Args[0].(*ast.MapType); isMap { + if i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + maps[lhsIdent.Name] = true + } + } + } + } + + return true + }) + + return maps +} + +// usesMapLookup checks if a loop body uses map lookup 
instead of iteration +func usesMapLookup(body *ast.BlockStmt, maps map[string]bool) bool { + usesMap := false + + ast.Inspect(body, func(n ast.Node) bool { + // Look for map[key] access pattern + indexExpr, ok := n.(*ast.IndexExpr) + if !ok { + return true + } + + if ident, ok := indexExpr.X.(*ast.Ident); ok { + if maps[ident.Name] { + usesMap = true + return false + } + } + + return true + }) + + return usesMap +} + +// isTrivalLoopBody checks if loop body is very simple (few statements) +func isTrivalLoopBody(body *ast.BlockStmt) bool { + if body == nil { + return true + } + // Consider trivial if <= 3 statements + return len(body.List) <= 3 +} + func getRangeVar(r *ast.RangeStmt) string { if ident, ok := r.X.(*ast.Ident); ok { return ident.Name @@ -80,44 +189,65 @@ func (r *LinearSearchInLoopRule) Category() string { return "algorithm" } func (r *LinearSearchInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var issues []Issue - // Find slice/array contains checks in loops ast.Inspect(file, func(n ast.Node) bool { - var loopBody *ast.BlockStmt - switch stmt := n.(type) { - case *ast.RangeStmt: - loopBody = stmt.Body - case *ast.ForStmt: - loopBody = stmt.Body - default: + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { return true } - if loopBody == nil { + if funcDecl.Body == nil { return true } - // Look for inner range loops that look like linear searches - ast.Inspect(loopBody, func(inner ast.Node) bool { - innerRange, ok := inner.(*ast.RangeStmt) - if !ok { + // Find existing lookup maps + lookupMaps := findLookupMaps(funcDecl.Body) + + // Find slice/array contains checks in loops + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := inner.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: return true } - // Check if the inner loop body contains a comparison and break/return - if containsSearchPattern(innerRange.Body) { - pos 
:= fset.Position(innerRange.Pos()) - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityMedium, - Line: pos.Line, - Column: pos.Column, - Message: "Linear search inside loop - consider using a map", - Why: "Searching a slice/array is O(n). Inside a loop, this becomes O(n*m). Building a map once is O(n), then lookups are O(1).", - Fix: "Build a map[key]value or map[key]struct{} before the loop for O(1) lookups", - }) + if loopBody == nil { + return true } + // Look for inner range loops that look like linear searches + ast.Inspect(loopBody, func(stmt ast.Node) bool { + innerRange, ok := stmt.(*ast.RangeStmt) + if !ok { + return true + } + + // Check if already using map lookup + if usesMapLookup(innerRange.Body, lookupMaps) { + return true // Already optimized + } + + // Check if the inner loop body contains a comparison and break/return + if containsSearchPattern(innerRange.Body) { + pos := fset.Position(innerRange.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "Linear search inside loop - consider using a map", + Why: "Searching a slice/array is O(n). Inside a loop, this becomes O(n*m). 
Building a map once is O(n), then lookups are O(1).", + Fix: "Build a map[key]value or map[key]struct{} before the loop for O(1) lookups", + }) + } + + return true + }) + return true }) @@ -133,12 +263,16 @@ func containsSearchPattern(body *ast.BlockStmt) bool { hasBreakOrReturn := false ast.Inspect(body, func(n ast.Node) bool { - switch n.(type) { + switch node := n.(type) { case *ast.BinaryExpr: - // Has a comparison - hasComparison = true + // Look for equality comparison (== or !=) + if node.Op == token.EQL || node.Op == token.NEQ { + hasComparison = true + } case *ast.BranchStmt: - hasBreakOrReturn = true + if node.Tok == token.BREAK { + hasBreakOrReturn = true + } case *ast.ReturnStmt: hasBreakOrReturn = true } diff --git a/rules/allocation.go b/rules/allocation.go index 1639e42..7723970 100644 --- a/rules/allocation.go +++ b/rules/allocation.go @@ -12,6 +12,7 @@ func init() { } // UnpreallocatedSliceRule detects slice append in loops without preallocation +// Now smarter: tracks make() calls with capacity before loops type UnpreallocatedSliceRule struct{} func (r *UnpreallocatedSliceRule) Name() string { return "unpreallocated-slice" } @@ -20,23 +21,164 @@ func (r *UnpreallocatedSliceRule) Category() string { return "allocation" } func (r *UnpreallocatedSliceRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var issues []Issue - results := FindAppendInLoop(file, fset) - for _, result := range results { - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityLow, - Line: result.Pos.Line, - Column: result.Pos.Column, - Message: "append() in loop without preallocation", - Why: "Slice grows dynamically, causing repeated memory allocations and copies. 
Each reallocation typically doubles capacity, wasting memory and CPU.", - Fix: "Preallocate with make([]T, 0, expectedSize) before the loop if size is known or estimable", + // For each function, track preallocated slices + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Track preallocated slices in this function + preallocated := findPreallocatedSlices(funcDecl.Body) + + // Now find append in loops + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + var loopBody *ast.BlockStmt + var loopNode ast.Node + switch stmt := inner.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + loopNode = stmt + case *ast.ForStmt: + loopBody = stmt.Body + loopNode = stmt + default: + return true + } + + if loopBody == nil { + return true + } + + loopBound := getLoopBound(loopNode) + + // Find append calls in the loop body + ast.Inspect(loopBody, func(stmt ast.Node) bool { + assign, ok := stmt.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "append" { + continue + } + + // Get the slice variable being appended to + if i >= len(assign.Lhs) { + continue + } + lhsIdent, ok := assign.Lhs[i].(*ast.Ident) + if !ok { + continue + } + + // Check if this slice was preallocated + if preallocated[lhsIdent.Name] { + continue // Skip - preallocated, not an issue + } + + severity := SeverityLow + // If loop is large or unbounded, increase severity + if loopBound < 0 || loopBound > 100 { + severity = SeverityMedium + } + + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "append() in loop without preallocation for '" + lhsIdent.Name + "'", + Why: "Slice grows dynamically, causing repeated 
memory allocations and copies. Each reallocation typically doubles capacity, wasting memory and CPU.", + Fix: "Preallocate with make([]T, 0, expectedSize) before the loop if size is known or estimable", + }) + } + return true + }) + + return true }) - } + + return true + }) return issues } +// findPreallocatedSlices finds slice variables that were created with make() and a capacity +func findPreallocatedSlices(body *ast.BlockStmt) map[string]bool { + preallocated := make(map[string]bool) + + ast.Inspect(body, func(n ast.Node) bool { + // Look for: s := make([]T, 0, cap) or s := make([]T, size) + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + continue + } + + if len(call.Args) < 1 { + continue + } + + // Check if it's a slice type + _, isSlice := call.Args[0].(*ast.ArrayType) + if !isSlice { + continue + } + + // Has capacity if 3 args: make([]T, len, cap) + // Or if 2 args with non-zero len: make([]T, size) + hasCapacity := false + if len(call.Args) >= 3 { + hasCapacity = true + } else if len(call.Args) == 2 { + // Check if size is non-zero or a variable (assumed sized) + if lit, ok := call.Args[1].(*ast.BasicLit); ok { + if lit.Kind == token.INT && lit.Value != "0" { + hasCapacity = true + } + } else { + // It's a variable or expression - assume it's sized + hasCapacity = true + } + } + + if hasCapacity && i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + preallocated[lhsIdent.Name] = true + } + } + } + + return true + }) + + return preallocated +} + // StringConcatInLoopRule detects string += concatenation in loops type StringConcatInLoopRule struct{} @@ -46,24 +188,132 @@ func (r *StringConcatInLoopRule) Category() string { return "allocation" } func (r *StringConcatInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var 
issues []Issue - positions := FindStringConcatInLoop(file, fset) - for _, pos := range positions { - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityMedium, - Line: pos.Line, - Column: pos.Column, - Message: "String concatenation in loop creates O(n²) allocations", - Why: "Strings are immutable in Go. Each += creates a new string, copying all previous content. Building a 1000-char string this way allocates ~500KB total.", - Fix: "Use strings.Builder: var b strings.Builder; for ... { b.WriteString(s) }; result := b.String()", + // Track strings.Builder usage + builders := findStringBuilders(file) + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Find += on strings + ast.Inspect(loopBody, func(inner ast.Node) bool { + assign, ok := inner.(*ast.AssignStmt) + if !ok || assign.Tok != token.ADD_ASSIGN { + return true + } + + // Get the variable being concatenated + if len(assign.Lhs) > 0 { + if ident, ok := assign.Lhs[0].(*ast.Ident); ok { + // Skip if using strings.Builder + if builders[ident.Name] { + return true + } + } + } + + // Check if RHS involves strings + for _, rhs := range assign.Rhs { + if isStringExpr(rhs) { + pos := fset.Position(assign.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "String concatenation in loop creates O(n²) allocations", + Why: "Strings are immutable in Go. Each += creates a new string, copying all previous content. Building a 1000-char string this way allocates ~500KB total.", + Fix: "Use strings.Builder: var b strings.Builder; for ... 
{ b.WriteString(s) }; result := b.String()", + }) + } + } + return true }) - } + + return true + }) return issues } -// MapWithoutSizeRule detects map creation without size hint when size is known +// findStringBuilders finds variables declared as strings.Builder +func findStringBuilders(file *ast.File) map[string]bool { + builders := make(map[string]bool) + + ast.Inspect(file, func(n ast.Node) bool { + // var b strings.Builder + genDecl, ok := n.(*ast.GenDecl) + if ok && genDecl.Tok == token.VAR { + for _, spec := range genDecl.Specs { + valueSpec, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + if sel, ok := valueSpec.Type.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "strings" && sel.Sel.Name == "Builder" { + for _, name := range valueSpec.Names { + builders[name.Name] = true + } + } + } + } + } + } + + // Also check for short declarations: b := strings.Builder{} + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + compLit, ok := rhs.(*ast.CompositeLit) + if !ok { + continue + } + if sel, ok := compLit.Type.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "strings" && sel.Sel.Name == "Builder" { + if i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + builders[lhsIdent.Name] = true + } + } + } + } + } + } + + return true + }) + + return builders +} + +func isStringExpr(expr ast.Expr) bool { + switch e := expr.(type) { + case *ast.BasicLit: + return e.Kind == token.STRING + case *ast.BinaryExpr: + return isStringExpr(e.X) || isStringExpr(e.Y) + } + return false +} + +// MapWithoutSizeRule detects map creation without size hint when populated in a loop type MapWithoutSizeRule struct{} func (r *MapWithoutSizeRule) Name() string { return "map-without-size" } @@ -72,52 +322,126 @@ func (r *MapWithoutSizeRule) Category() string { return "allocation" } func (r *MapWithoutSizeRule) Check(file *ast.File, 
fset *token.FileSet, src []byte) []Issue { var issues []Issue + // Only flag maps that are actually populated in loops ast.Inspect(file, func(n ast.Node) bool { - // Look for patterns like: - // m := make(map[K]V) - // for _, item := range items { m[k] = v } + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } - assign, ok := n.(*ast.AssignStmt) - if !ok || len(assign.Rhs) != 1 { + if funcDecl.Body == nil { return true } - call, ok := assign.Rhs[0].(*ast.CallExpr) + // Find maps without size hints + unsizedMaps := findUnsizedMaps(funcDecl.Body, fset) + + // Find maps that are populated in loops + populatedInLoop := findMapsPopulatedInLoop(funcDecl.Body) + + // Only report maps that are both unsized AND populated in a loop + for name, pos := range unsizedMaps { + if populatedInLoop[name] { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "Map '" + name + "' created without size hint and populated in loop", + Why: "Maps without size hints start small and rehash as they grow. 
If you know the approximate size, providing it avoids rehashing overhead.", + Fix: "Use make(map[K]V, expectedSize) if the size is known or estimable from the loop source", + }) + } + } + + return true + }) + + return issues +} + +func findUnsizedMaps(body *ast.BlockStmt, fset *token.FileSet) map[string]token.Position { + maps := make(map[string]token.Position) + + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) if !ok { return true } - ident, ok := call.Fun.(*ast.Ident) - if !ok || ident.Name != "make" { + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + continue + } + + if len(call.Args) < 1 { + continue + } + + // Check if it's a map type without size + _, isMap := call.Args[0].(*ast.MapType) + if !isMap || len(call.Args) > 1 { + continue // Has size hint or not a map + } + + if i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + maps[lhsIdent.Name] = fset.Position(call.Pos()) + } + } + } + + return true + }) + + return maps +} + +func findMapsPopulatedInLoop(body *ast.BlockStmt) map[string]bool { + populated := make(map[string]bool) + + ast.Inspect(body, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: return true } - if len(call.Args) < 1 { + if loopBody == nil { return true } - // Check if it's a map type without size - _, isMap := call.Args[0].(*ast.MapType) - if !isMap || len(call.Args) > 1 { - return true // Has size hint or not a map - } - - // Check if there's a loop that populates it nearby - // This is a simplified heuristic - pos := fset.Position(call.Pos()) - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityLow, - Line: pos.Line, - Column: pos.Column, - Message: "Map created without size hint", - Why: 
"Maps without size hints start small and rehash as they grow. If you know the approximate size, providing it avoids rehashing overhead.", - Fix: "Use make(map[K]V, expectedSize) if the size is known or estimable", + // Find map assignments in the loop: m[k] = v + ast.Inspect(loopBody, func(inner ast.Node) bool { + assign, ok := inner.(*ast.AssignStmt) + if !ok { + return true + } + + for _, lhs := range assign.Lhs { + if indexExpr, ok := lhs.(*ast.IndexExpr); ok { + if ident, ok := indexExpr.X.(*ast.Ident); ok { + populated[ident.Name] = true + } + } + } + + return true }) return true }) - return issues + return populated } diff --git a/rules/analyzer.go b/rules/analyzer.go index bc636a8..d971f1c 100644 --- a/rules/analyzer.go +++ b/rules/analyzer.go @@ -213,16 +213,6 @@ func FindStringConcatInLoop(file *ast.File, fset *token.FileSet) []token.Positio return results } -func isStringExpr(expr ast.Expr) bool { - switch e := expr.(type) { - case *ast.BasicLit: - return e.Kind == token.STRING - case *ast.BinaryExpr: - return isStringExpr(e.X) || isStringExpr(e.Y) - } - return false -} - // FindSQLInLoop finds database query patterns inside loops func FindSQLInLoop(file *ast.File, fset *token.FileSet) []SQLInLoopInfo { var results []SQLInLoopInfo diff --git a/rules/benchmark.go b/rules/benchmark.go new file mode 100644 index 0000000..d0ab3ad --- /dev/null +++ b/rules/benchmark.go @@ -0,0 +1,188 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("benchmark", &BenchmarkSuggestionRule{}) +} + +// BenchmarkSuggestionRule suggests benchmarks for functions with detected issues +type BenchmarkSuggestionRule struct{} + +func (r *BenchmarkSuggestionRule) Name() string { return "benchmark-suggestion" } +func (r *BenchmarkSuggestionRule) Category() string { return "benchmark" } + +func (r *BenchmarkSuggestionRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find functions that have 
performance-sensitive patterns + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + funcName := funcDecl.Name.Name + + // Skip if already a benchmark + if len(funcName) > 9 && funcName[:9] == "Benchmark" { + return true + } + + // Check for performance-sensitive patterns + patterns := checkPerformancePatterns(funcDecl.Body) + + if len(patterns) > 0 { + pos := fset.Position(funcDecl.Pos()) + + // Generate benchmark suggestion + benchCode := generateBenchmarkCode(funcName, funcDecl) + + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "Function '" + funcName + "' has " + itoa(len(patterns)) + " performance-sensitive pattern(s) - consider adding benchmark", + Why: "This function contains: " + joinPatterns(patterns) + ". Benchmarking helps track performance regressions.", + Fix: "Add benchmark:\n" + benchCode, + }) + } + + return true + }) + + return issues +} + +type perfPattern struct { + name string + count int +} + +func checkPerformancePatterns(body *ast.BlockStmt) []perfPattern { + var patterns []perfPattern + + loopCount := 0 + sqlCount := 0 + allocCount := 0 + reflectCount := 0 + + ast.Inspect(body, func(n ast.Node) bool { + switch node := n.(type) { + case *ast.RangeStmt, *ast.ForStmt: + loopCount++ + case *ast.CallExpr: + if sel, ok := node.Fun.(*ast.SelectorExpr); ok { + if ident, ok := sel.X.(*ast.Ident); ok { + // Check for SQL + sqlMethods := map[string]bool{ + "Query": true, "QueryRow": true, "Exec": true, + "QueryContext": true, "ExecContext": true, + } + if sqlMethods[sel.Sel.Name] { + sqlCount++ + } + + // Check for reflection + if ident.Name == "reflect" { + reflectCount++ + } + + // Check for JSON + if ident.Name == "json" && (sel.Sel.Name == "Marshal" || sel.Sel.Name == "Unmarshal") { + allocCount++ + } + } + + // Check for make + if ident, 
ok := node.Fun.(*ast.Ident); ok && ident.Name == "make" { + allocCount++ + } + } + } + return true + }) + + if loopCount > 0 { + patterns = append(patterns, perfPattern{"loops", loopCount}) + } + if sqlCount > 0 { + patterns = append(patterns, perfPattern{"database calls", sqlCount}) + } + if allocCount > 0 { + patterns = append(patterns, perfPattern{"allocations", allocCount}) + } + if reflectCount > 0 { + patterns = append(patterns, perfPattern{"reflection", reflectCount}) + } + + return patterns +} + +func generateBenchmarkCode(funcName string, funcDecl *ast.FuncDecl) string { + // Generate basic benchmark scaffold + benchName := "Benchmark" + capitalizeFirst(funcName) + + code := "func " + benchName + "(b *testing.B) {\n" + code += "\t// Setup: initialize test data\n" + + // Add parameter hints based on function signature + if funcDecl.Type.Params != nil && len(funcDecl.Type.Params.List) > 0 { + code += "\t// params: " + for i, param := range funcDecl.Type.Params.List { + if i > 0 { + code += ", " + } + for j, name := range param.Names { + if j > 0 { + code += ", " + } + code += name.Name + } + } + code += "\n" + } + + code += "\n\tb.ResetTimer()\n" + code += "\tfor i := 0; i < b.N; i++ {\n" + code += "\t\t" + funcName + "(...) 
// Add arguments\n" + code += "\t}\n" + code += "}" + + return code +} + +func capitalizeFirst(s string) string { + if s == "" { + return s + } + first := s[0] + if first >= 'a' && first <= 'z' { + return string(first-32) + s[1:] + } + return s +} + +func joinPatterns(patterns []perfPattern) string { + if len(patterns) == 0 { + return "" + } + + result := "" + for i, p := range patterns { + if i > 0 { + result += ", " + } + result += itoa(p.count) + " " + p.name + } + return result +} diff --git a/rules/cache.go b/rules/cache.go index 595b266..040fee7 100644 --- a/rules/cache.go +++ b/rules/cache.go @@ -8,6 +8,8 @@ import ( func init() { RegisterRule("cache", &RepeatedRegexpCompileRule{}) RegisterRule("cache", &RepeatedTemplateParseRule{}) + RegisterRule("cache", &RegexpMatchStringRule{}) + RegisterRule("cache", &JSONSchemaValidationRule{}) } // RepeatedRegexpCompileRule detects regexp.Compile inside functions (should be package-level) @@ -124,3 +126,134 @@ func (r *RepeatedTemplateParseRule) Check(file *ast.File, fset *token.FileSet, s return issues } + +// RegexpMatchStringRule detects regexp.MatchString in loops (compiles each time) +type RegexpMatchStringRule struct{} + +func (r *RegexpMatchStringRule) Name() string { return "regexp-match-string-loop" } +func (r *RegexpMatchStringRule) Category() string { return "cache" } + +func (r *RegexpMatchStringRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + // 
regexp.MatchString, regexp.Match compile the pattern each time + if ident.Name == "regexp" { + switch sel.Sel.Name { + case "MatchString", "Match", "ReplaceAllString", "ReplaceAll", + "FindString", "FindAllString", "FindStringSubmatch": + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityHigh, + Line: pos.Line, + Column: pos.Column, + Message: "regexp." + sel.Sel.Name + "() in loop - compiles regex on EVERY call", + Why: "regexp.MatchString and similar functions compile the pattern each time. This is O(n*m) where m is pattern complexity.", + Fix: "Compile once: var re = regexp.MustCompile(`pattern`); then use re.MatchString(s) in the loop", + }) + } + } + + return true + }) + + return true + }) + + return issues +} + +// JSONSchemaValidationRule detects JSON schema validation in loops +type JSONSchemaValidationRule struct{} + +func (r *JSONSchemaValidationRule) Name() string { return "json-schema-in-loop" } +func (r *JSONSchemaValidationRule) Category() string { return "cache" } + +func (r *JSONSchemaValidationRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + // Check for schema compilation/validation patterns + if sel.Sel.Name == "Compile" || sel.Sel.Name == "NewSchema" || sel.Sel.Name == "Validate" { + // Look for jsonschema, gojsonschema, etc. 
+ ident, ok := sel.X.(*ast.Ident) + if ok && (ident.Name == "jsonschema" || ident.Name == "gojsonschema" || ident.Name == "schema") { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "JSON schema compilation/validation in loop - compile schema once", + Why: "Compiling JSON schemas is expensive. Recompiling for each item wastes CPU.", + Fix: "Compile the schema once outside the loop, then call Validate() in the loop", + }) + } + } + + return true + }) + + return true + }) + + return issues +} diff --git a/rules/callgraph.go b/rules/callgraph.go new file mode 100644 index 0000000..a974fdb --- /dev/null +++ b/rules/callgraph.go @@ -0,0 +1,314 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("database", &IndirectSQLInLoopRule{}) +} + +// IndirectSQLInLoopRule detects when functions containing SQL are called in loops +// This catches the N+1 pattern even when SQL is wrapped in helper functions +type IndirectSQLInLoopRule struct{} + +func (r *IndirectSQLInLoopRule) Name() string { return "indirect-sql-in-loop" } +func (r *IndirectSQLInLoopRule) Category() string { return "database" } + +func (r *IndirectSQLInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // First pass: find functions that contain direct SQL calls + sqlFuncs := findFunctionsWithSQL(file) + + // Second pass: find loops that call these functions + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + var loopNode ast.Node + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + loopNode = stmt + case *ast.ForStmt: + loopBody = stmt.Body + loopNode = stmt + default: + return true + } + + if loopBody == nil { + return true + } + + loopBound := getLoopBound(loopNode) + + // Find function calls in the loop body + ast.Inspect(loopBody, 
func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + // Get the function name being called + var funcName string + switch fn := call.Fun.(type) { + case *ast.Ident: + funcName = fn.Name + case *ast.SelectorExpr: + // Method call like s.Save() - check method name + funcName = fn.Sel.Name + } + + if funcName == "" { + return true + } + + // Check if this function contains SQL + if sqlInfo, ok := sqlFuncs[funcName]; ok { + severity := SeverityHigh + if loopBound > 0 && loopBound <= 10 { + severity = SeverityMedium + } + + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "Function '" + funcName + "' called in loop contains SQL (" + sqlInfo.method + ")", + Why: "The function " + funcName + "() contains database operations. Calling it in a loop creates N+1 query patterns even though the SQL isn't directly visible.", + Fix: "Refactor to batch: (1) Collect items first, (2) Pass slice to function, (3) Use single batch query inside function", + }) + } + + return true + }) + + return true + }) + + return issues +} + +type sqlFuncInfo struct { + method string + line int +} + +// findFunctionsWithSQL finds all functions in the file that contain SQL operations +func findFunctionsWithSQL(file *ast.File) map[string]sqlFuncInfo { + sqlFuncs := make(map[string]sqlFuncInfo) + + sqlMethods := map[string]bool{ + "Query": true, "QueryRow": true, "Exec": true, + "QueryRowContext": true, "QueryContext": true, "ExecContext": true, + "Get": true, "Select": true, + } + + for _, decl := range file.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil { + continue + } + + funcName := funcDecl.Name.Name + + // Skip methods that are themselves SQL operations + if sqlMethods[funcName] { + continue + } + + // Check if function body contains SQL calls + ast.Inspect(funcDecl.Body, func(n ast.Node) 
bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if sqlMethods[sel.Sel.Name] { + sqlFuncs[funcName] = sqlFuncInfo{ + method: sel.Sel.Name, + } + return false // Found one, stop searching this function + } + + return true + }) + } + + return sqlFuncs +} + +// ReflectionInLoopRule detects reflection usage in loops (advanced) +type ReflectionInLoopRule struct{} + +func init() { + RegisterRule("io", &ReflectionInLoopRule{}) +} + +func (r *ReflectionInLoopRule) Name() string { return "reflection-in-loop" } +func (r *ReflectionInLoopRule) Category() string { return "io" } + +func (r *ReflectionInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Find reflect.ValueOf, reflect.TypeOf calls in loops + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if ident.Name == "reflect" { + reflectMethods := map[string]bool{ + "ValueOf": true, "TypeOf": true, "New": true, + "MakeSlice": true, "MakeMap": true, "MakeFunc": true, + } + + if reflectMethods[sel.Sel.Name] { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "reflect." + sel.Sel.Name + "() inside loop - significant overhead", + Why: "Reflection is slow compared to direct type access. 
In loops, this overhead multiplies significantly.", + Fix: "Consider: (1) Caching reflection results outside loop, (2) Using type assertions, (3) Code generation for type-specific operations", + }) + } + } + + return true + }) + + return true + }) + + return issues +} + +// SyncPoolOpportunityRule detects repeated allocations that could use sync.Pool +type SyncPoolOpportunityRule struct{} + +func init() { + RegisterRule("allocation", &SyncPoolOpportunityRule{}) +} + +func (r *SyncPoolOpportunityRule) Name() string { return "sync-pool-opportunity" } +func (r *SyncPoolOpportunityRule) Category() string { return "allocation" } + +func (r *SyncPoolOpportunityRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find functions that are called frequently (heuristic: called in loops) + // and allocate buffers/slices that could be pooled + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Find allocations that could be pooled + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok { + return true + } + + // Check for make([]byte, ...) - common buffer allocation + if ident.Name == "make" && len(call.Args) >= 1 { + if arrayType, ok := call.Args[0].(*ast.ArrayType); ok { + if elemIdent, ok := arrayType.Elt.(*ast.Ident); ok { + if elemIdent.Name == "byte" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "Byte slice allocation in loop - consider sync.Pool", + Why: "Allocating buffers in a loop creates GC pressure. 
For high-frequency operations, sync.Pool can reuse allocations.", + Fix: "Use sync.Pool: var bufPool = sync.Pool{New: func() any { return make([]byte, size) }}; buf := bufPool.Get().([]byte); defer bufPool.Put(buf)", + }) + } + } + } + } + + // Check for bytes.Buffer creation + if sel, ok := call.Fun.(*ast.SelectorExpr); ok { + if x, ok := sel.X.(*ast.Ident); ok { + if x.Name == "bytes" && sel.Sel.Name == "NewBuffer" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "bytes.Buffer creation in loop - consider sync.Pool", + Why: "Creating new buffers in a loop creates GC pressure. sync.Pool can reuse buffer instances.", + Fix: "Pool bytes.Buffer instances and Reset() before reuse", + }) + } + } + } + + return true + }) + + return true + }) + + return issues +} diff --git a/rules/concurrency.go b/rules/concurrency.go index 38eb4fd..89bfde4 100644 --- a/rules/concurrency.go +++ b/rules/concurrency.go @@ -12,6 +12,7 @@ func init() { } // UnbufferedChannelRule detects unbuffered channel creation +// Now smarter: checks for intentional synchronization patterns type UnbufferedChannelRule struct{} func (r *UnbufferedChannelRule) Name() string { return "unbuffered-channel" } @@ -20,23 +21,251 @@ func (r *UnbufferedChannelRule) Category() string { return "concurrency" } func (r *UnbufferedChannelRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var issues []Issue - positions := FindUnbufferedChannels(file, fset) - for _, pos := range positions { - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityLow, - Line: pos.Line, - Column: pos.Column, - Message: "Unbuffered channel - may cause goroutine blocking", - Why: "Unbuffered channels block the sender until a receiver is ready. 
This can cause deadlocks or reduce parallelism if not carefully designed.", - Fix: "Consider adding a buffer: make(chan T, bufferSize). Use unbuffered only when synchronization is intentional.", - }) - } + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Find unbuffered channels and their variable names + unbufferedChans := findUnbufferedChannelVars(funcDecl.Body, fset) + + // Find channels used in select statements (intentional synchronization) + selectChans := findChannelsInSelect(funcDecl.Body) + + // Find channels used with proper goroutine coordination (done/signal patterns) + signalChans := findSignalChannels(funcDecl.Body) + + // Find channels that are struct{} type (typically signals) + emptyStructChans := findEmptyStructChannels(funcDecl.Body) + + // Report only channels that are NOT in select and NOT used as signals + for name, pos := range unbufferedChans { + // Skip if used in select (intentional blocking) + if selectChans[name] { + continue + } + + // Skip signal/done pattern channels + if signalChans[name] { + continue + } + + // Skip chan struct{} - almost always intentional signals + if emptyStructChans[name] { + continue + } + + // Lower severity - could be intentional + severity := SeverityLow + message := "Unbuffered channel '" + name + "' - verify synchronization is intentional" + why := "Unbuffered channels block the sender until a receiver is ready. This is correct for synchronization but may cause deadlocks if misused." + fix := "If synchronization is intentional, add comment: // Intentional: synchronization point. 
Otherwise, consider adding a buffer: make(chan T, size)" + + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: message, + Why: why, + Fix: fix, + }) + } + + return true + }) return issues } +// findUnbufferedChannelVars finds variables that hold unbuffered channels +func findUnbufferedChannelVars(body *ast.BlockStmt, fset *token.FileSet) map[string]token.Position { + chans := make(map[string]token.Position) + + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + continue + } + + if len(call.Args) < 1 { + continue + } + + // Check if it's a channel type + _, ok = call.Args[0].(*ast.ChanType) + if !ok { + continue + } + + // Unbuffered if no second argument or second arg is 0 + isUnbuffered := false + if len(call.Args) == 1 { + isUnbuffered = true + } else if len(call.Args) >= 2 { + if lit, ok := call.Args[1].(*ast.BasicLit); ok { + if lit.Kind == token.INT && lit.Value == "0" { + isUnbuffered = true + } + } + } + + if isUnbuffered && i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + chans[lhsIdent.Name] = fset.Position(call.Pos()) + } + } + } + + return true + }) + + return chans +} + +// findChannelsInSelect finds channel variables used in select statements +func findChannelsInSelect(body *ast.BlockStmt) map[string]bool { + selectChans := make(map[string]bool) + + ast.Inspect(body, func(n ast.Node) bool { + selectStmt, ok := n.(*ast.SelectStmt) + if !ok { + return true + } + + // Check each case in the select + for _, stmt := range selectStmt.Body.List { + commClause, ok := stmt.(*ast.CommClause) + if !ok { + continue + } + + // Extract channel from comm statement + switch comm := commClause.Comm.(type) { + case 
*ast.SendStmt: + if ident, ok := comm.Chan.(*ast.Ident); ok { + selectChans[ident.Name] = true + } + case *ast.ExprStmt: + if unary, ok := comm.X.(*ast.UnaryExpr); ok && unary.Op == token.ARROW { + if ident, ok := unary.X.(*ast.Ident); ok { + selectChans[ident.Name] = true + } + } + case *ast.AssignStmt: + for _, rhs := range comm.Rhs { + if unary, ok := rhs.(*ast.UnaryExpr); ok && unary.Op == token.ARROW { + if ident, ok := unary.X.(*ast.Ident); ok { + selectChans[ident.Name] = true + } + } + } + } + } + + return true + }) + + return selectChans +} + +// findSignalChannels finds channels named with common signal patterns +func findSignalChannels(body *ast.BlockStmt) map[string]bool { + signals := make(map[string]bool) + + signalNames := map[string]bool{ + "done": true, "quit": true, "stop": true, "cancel": true, + "sig": true, "signal": true, "shutdown": true, "closed": true, + "ready": true, "started": true, "finished": true, + } + + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i := range assign.Rhs { + if i < len(assign.Lhs) { + if ident, ok := assign.Lhs[i].(*ast.Ident); ok { + if signalNames[ident.Name] { + signals[ident.Name] = true + } + } + } + } + + return true + }) + + return signals +} + +// findEmptyStructChannels finds channels of type chan struct{} +func findEmptyStructChannels(body *ast.BlockStmt) map[string]bool { + emptyStructChans := make(map[string]bool) + + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + ident, ok := call.Fun.(*ast.Ident) + if !ok || ident.Name != "make" { + continue + } + + if len(call.Args) < 1 { + continue + } + + chanType, ok := call.Args[0].(*ast.ChanType) + if !ok { + continue + } + + // Check if it's chan struct{} + if structType, ok := chanType.Value.(*ast.StructType); ok { + if 
structType.Fields == nil || len(structType.Fields.List) == 0 { + if i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + emptyStructChans[lhsIdent.Name] = true + } + } + } + } + } + + return true + }) + + return emptyStructChans +} + // MutexInLoopRule detects mutex Lock() calls inside loops type MutexInLoopRule struct{} @@ -48,11 +277,14 @@ func (r *MutexInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) ast.Inspect(file, func(n ast.Node) bool { var loopBody *ast.BlockStmt + var loopNode ast.Node switch stmt := n.(type) { case *ast.RangeStmt: loopBody = stmt.Body + loopNode = stmt case *ast.ForStmt: loopBody = stmt.Body + loopNode = stmt default: return true } @@ -61,6 +293,8 @@ func (r *MutexInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) return true } + loopBound := getLoopBound(loopNode) + // Find Lock() calls in the loop body ast.Inspect(loopBody, func(inner ast.Node) bool { call, ok := inner.(*ast.CallExpr) @@ -74,11 +308,18 @@ func (r *MutexInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) } if sel.Sel.Name == "Lock" || sel.Sel.Name == "RLock" { + severity := SeverityMedium + + // Small bounded loops are less severe + if loopBound > 0 && loopBound <= 10 { + severity = SeverityLow + } + pos := fset.Position(call.Pos()) issues = append(issues, Issue{ Rule: r.Name(), Category: r.Category(), - Severity: SeverityMedium, + Severity: severity, Line: pos.Line, Column: pos.Column, Message: "Mutex " + sel.Sel.Name + "() inside loop - potential contention", @@ -116,7 +357,7 @@ func (r *GoroutineLeakRule) Check(file *ast.File, fset *token.FileSet, src []byt ast.Inspect(goStmt.Call, func(inner ast.Node) bool { switch node := inner.(type) { case *ast.Ident: - if node.Name == "ctx" || node.Name == "done" || node.Name == "cancel" || node.Name == "quit" { + if node.Name == "ctx" || node.Name == "done" || node.Name == "cancel" || node.Name == "quit" || node.Name == "stop" { hasContextOrDone = true } 
case *ast.SelectorExpr: @@ -125,6 +366,9 @@ func (r *GoroutineLeakRule) Check(file *ast.File, fset *token.FileSet, src []byt hasContextOrDone = true } } + case *ast.SelectStmt: + // Has a select statement - likely has proper termination + hasContextOrDone = true } return true }) diff --git a/rules/context.go b/rules/context.go new file mode 100644 index 0000000..d9e2e20 --- /dev/null +++ b/rules/context.go @@ -0,0 +1,312 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("context", &ContextBackgroundInHandlerRule{}) + RegisterRule("context", &MissingContextTimeoutRule{}) + RegisterRule("context", &ContextLeakRule{}) +} + +// ContextBackgroundInHandlerRule detects context.Background() in HTTP handlers +type ContextBackgroundInHandlerRule struct{} + +func (r *ContextBackgroundInHandlerRule) Name() string { return "context-background-in-handler" } +func (r *ContextBackgroundInHandlerRule) Category() string { return "context" } + +func (r *ContextBackgroundInHandlerRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + // Check if this looks like an HTTP handler + if !isHTTPHandler(funcDecl) { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Look for context.Background() or context.TODO() calls + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if ident.Name == "context" && (sel.Sel.Name == "Background" || sel.Sel.Name == "TODO") { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "context." 
+ sel.Sel.Name + "() in HTTP handler - use request context instead", + Why: "HTTP handlers should use r.Context() which is cancelled when the client disconnects. context.Background() ignores client disconnection.", + Fix: "Use ctx := r.Context() instead, or derive from it: ctx, cancel := context.WithTimeout(r.Context(), timeout)", + }) + } + + return true + }) + + return true + }) + + return issues +} + +// MissingContextTimeoutRule detects external calls without context timeout +type MissingContextTimeoutRule struct{} + +func (r *MissingContextTimeoutRule) Name() string { return "missing-context-timeout" } +func (r *MissingContextTimeoutRule) Category() string { return "context" } + +func (r *MissingContextTimeoutRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Track contexts with timeouts in each function + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Find context variables with timeouts + ctxsWithTimeout := findContextsWithTimeout(funcDecl.Body) + + // Check external service calls + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + // Check for HTTP client calls, gRPC calls, database calls + externalMethods := map[string]bool{ + "Do": true, // http.Client.Do + "Get": true, // http.Get or client.Get + "Post": true, + "PostForm": true, + "Head": true, + "Invoke": true, // gRPC + "NewRequest": true, // http.NewRequest (should use NewRequestWithContext) + } + + if externalMethods[sel.Sel.Name] { + // Check if any argument is a context with timeout + hasTimeoutCtx := false + for _, arg := range call.Args { + if ident, ok := arg.(*ast.Ident); ok { + if ctxsWithTimeout[ident.Name] { + hasTimeoutCtx = true + break + } + } + } + + // Special check for 
http.NewRequest - should be NewRequestWithContext + if sel.Sel.Name == "NewRequest" { + if ident, ok := sel.X.(*ast.Ident); ok && ident.Name == "http" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "http.NewRequest() doesn't accept context - use NewRequestWithContext()", + Why: "NewRequest creates a request without context, making it impossible to cancel or timeout.", + Fix: "Use http.NewRequestWithContext(ctx, method, url, body) instead", + }) + } + } + + // For Do, Get, etc. - check if context seems to have timeout + if !hasTimeoutCtx && (sel.Sel.Name == "Do" || sel.Sel.Name == "Get" || sel.Sel.Name == "Invoke") { + // This is a heuristic - we can't be 100% sure without type info + // Only flag if the receiver looks like an HTTP client + } + } + + return true + }) + + return true + }) + + return issues +} + +// ContextLeakRule detects context cancel functions that aren't called +type ContextLeakRule struct{} + +func (r *ContextLeakRule) Name() string { return "context-leak" } +func (r *ContextLeakRule) Category() string { return "context" } + +func (r *ContextLeakRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Find cancel functions from WithCancel/WithTimeout/WithDeadline + cancelFuncs := make(map[string]token.Position) + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + assign, ok := inner.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + continue + } + + if ident.Name == "context" { + switch 
sel.Sel.Name { + case "WithCancel", "WithTimeout", "WithDeadline": + // Second return value is the cancel function + if len(assign.Lhs) > i+1 { + if cancelIdent, ok := assign.Lhs[i+1].(*ast.Ident); ok { + if cancelIdent.Name != "_" { + cancelFuncs[cancelIdent.Name] = fset.Position(assign.Pos()) + } + } + } + } + } + } + + return true + }) + + // Check if cancel functions are called (either directly or via defer) + calledCancels := make(map[string]bool) + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + switch node := inner.(type) { + case *ast.CallExpr: + if ident, ok := node.Fun.(*ast.Ident); ok { + calledCancels[ident.Name] = true + } + case *ast.DeferStmt: + if call, ok := node.Call.Fun.(*ast.Ident); ok { + calledCancels[call.Name] = true + } + } + return true + }) + + // Report uncalled cancel functions + for name, pos := range cancelFuncs { + if !calledCancels[name] { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityHigh, + Line: pos.Line, + Column: pos.Column, + Message: "Context cancel function '" + name + "' is never called - resource leak", + Why: "Not calling cancel() leaks goroutines and resources associated with the context. 
The context will never be cancelled.", + Fix: "Add 'defer " + name + "()' immediately after creating the context", + }) + } + } + + return true + }) + + return issues +} + +// Helper function to find contexts that have timeout/deadline +func findContextsWithTimeout(body *ast.BlockStmt) map[string]bool { + ctxs := make(map[string]bool) + + ast.Inspect(body, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + continue + } + + if ident.Name == "context" { + switch sel.Sel.Name { + case "WithTimeout", "WithDeadline": + if i < len(assign.Lhs) { + if ctxIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + ctxs[ctxIdent.Name] = true + } + } + } + } + } + + return true + }) + + return ctxs +} diff --git a/rules/database.go b/rules/database.go index f9e88aa..60b014a 100644 --- a/rules/database.go +++ b/rules/database.go @@ -10,7 +10,7 @@ func init() { RegisterRule("database", &UnbatchedInsertRule{}) } -// SQLInLoopRule detects N+1 query patterns +// SQLInLoopRule detects N+1 query patterns with smart prepared statement detection type SQLInLoopRule struct{} func (r *SQLInLoopRule) Name() string { return "sql-in-loop" } @@ -19,28 +19,235 @@ func (r *SQLInLoopRule) Category() string { return "database" } func (r *SQLInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var issues []Issue - results := FindSQLInLoop(file, fset) - for _, result := range results { - severity := SeverityHigh - if result.Method == "Exec" || result.Method == "ExecContext" { - severity = SeverityCritical // Writes in loops are worse + // Track prepared statements declared in the current scope + preparedStmts := findPreparedStatements(file) + + // Track transaction variables + txVars := findTransactionVariables(file) + + 
ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + var loopNode ast.Node + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + loopNode = stmt + case *ast.ForStmt: + loopBody = stmt.Body + loopNode = stmt + default: + return true } - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: severity, - Line: result.Pos.Line, - Column: result.Pos.Column, - Message: "Database " + result.Method + "() called inside loop - N+1 query pattern", - Why: "Each iteration makes a separate database round-trip. With 100 items, that's 100 queries instead of 1. Network latency dominates, making this extremely slow.", - Fix: "Use batch operations: SELECT ... WHERE id IN (...), bulk INSERT, or collect IDs and query once outside the loop", + if loopBody == nil { + return true + } + + // Check for bounded loops (small, known iteration count) + loopBound := getLoopBound(loopNode) + + // Find SQL method calls in the loop body + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + sqlMethods := map[string]bool{ + "Query": true, "QueryRow": true, "Exec": true, + "QueryRowContext": true, "QueryContext": true, "ExecContext": true, + "Get": true, "Select": true, + } + + if !sqlMethods[sel.Sel.Name] { + return true + } + + // Get receiver variable name + receiver := getReceiverName(sel.X) + + // Check if using a prepared statement + if preparedStmts[receiver] { + // Using prepared statement - much lower severity + // This is the idiomatic Go batch pattern + return true // Skip - not a real issue + } + + // Determine severity based on context + severity := SeverityHigh + why := "Each iteration makes a separate database round-trip. With 100 items, that's 100 queries instead of 1. Network latency dominates." + fix := "Use batch operations: SELECT ... 
WHERE id IN (...), bulk INSERT, or collect IDs and query once outside the loop" + + // Writes are worse than reads + if sel.Sel.Name == "Exec" || sel.Sel.Name == "ExecContext" { + severity = SeverityCritical + } + + // Transaction context reduces severity slightly (batched round-trips) + if txVars[receiver] { + if severity == SeverityCritical { + severity = SeverityHigh + } else { + severity = SeverityMedium + } + why = "Each iteration executes separately within the transaction. While batched at commit, individual executions still have overhead." + fix = "Consider using prepared statements with tx.Prepare() for better performance, or batch the operations" + } + + // Small bounded loops are less severe + if loopBound > 0 && loopBound <= 10 { + if severity == SeverityCritical { + severity = SeverityHigh + } else if severity == SeverityHigh { + severity = SeverityMedium + } else { + severity = SeverityLow + } + why = "Loop appears bounded to a small number of iterations. Still incurs round-trip overhead but impact is limited." + } + + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "Database " + sel.Sel.Name + "() called inside loop - N+1 query pattern", + Why: why, + Fix: fix, + }) + + return true }) - } + + return true + }) return issues } +// findPreparedStatements finds variables that hold prepared statements +func findPreparedStatements(file *ast.File) map[string]bool { + stmts := make(map[string]bool) + + ast.Inspect(file, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + // Look for stmt, err := db.Prepare(...) or tx.Prepare(...) 
// getLoopBound estimates a loop's iteration count from its syntax alone.
// It returns -1 when no bound can be read off the AST.
//
// Recognized shapes:
//   - for i := S; i < N; ...   -> N-S   (S defaults to 0 when the init
//     statement is absent or is not a single decimal literal)
//   - for i := S; i <= N; ...  -> N-S+1
//   - for ... range N          -> N (integer literal operand)
//   - for ... range T{...}     -> number of elements; for an array literal
//     with an explicit length, that length; -1 when an array or slice literal
//     uses index keys, since its length then differs from the element count
//
// Only decimal integer literals are understood; hexadecimal, octal and binary
// literals report an unknown bound rather than a misread one. The loop's step
// is not inspected, so the result is a heuristic upper estimate.
func getLoopBound(loopNode ast.Node) int {
	switch stmt := loopNode.(type) {
	case *ast.ForStmt:
		cond, ok := stmt.Cond.(*ast.BinaryExpr)
		if !ok || (cond.Op != token.LSS && cond.Op != token.LEQ) {
			return -1
		}
		limit, ok := decimalIntLiteral(cond.Y)
		if !ok {
			return -1
		}

		start := 0
		if initStmt, ok := stmt.Init.(*ast.AssignStmt); ok && len(initStmt.Rhs) == 1 {
			if v, ok := decimalIntLiteral(initStmt.Rhs[0]); ok {
				start = v
			}
		}

		count := limit - start
		if cond.Op == token.LEQ {
			count++ // the limit itself is also visited
		}
		if count < 0 {
			count = 0
		}
		return count

	case *ast.RangeStmt:
		// Range over an integer literal.
		if n, ok := decimalIntLiteral(stmt.X); ok {
			return n
		}

		lit, ok := stmt.X.(*ast.CompositeLit)
		if !ok {
			return -1
		}
		if arr, ok := lit.Type.(*ast.ArrayType); ok {
			if arr.Len != nil {
				if _, inferred := arr.Len.(*ast.Ellipsis); !inferred {
					// [N]T{...} iterates N times however many elements are listed.
					if n, ok := decimalIntLiteral(arr.Len); ok {
						return n
					}
					return -1
				}
			}
			for _, elt := range lit.Elts {
				if _, keyed := elt.(*ast.KeyValueExpr); keyed {
					// []T{99: x} has length 100 but a single element.
					return -1
				}
			}
		}
		return len(lit.Elts)
	}
	return -1 // Unknown bound
}

// decimalIntLiteral returns the value of expr when it is a decimal integer
// literal (underscore separators allowed). Prefixed and legacy-octal forms are
// rejected. Values saturate at 1<<30 so huge literals cannot overflow.
func decimalIntLiteral(expr ast.Expr) (int, bool) {
	lit, ok := expr.(*ast.BasicLit)
	if !ok || lit.Kind != token.INT || lit.Value == "" {
		return 0, false
	}
	text := lit.Value
	if len(text) > 1 && text[0] == '0' {
		return 0, false // 0x.., 0b.., 0o.. or legacy octal
	}

	const ceiling = 1 << 30
	value := 0
	for i := 0; i < len(text); i++ {
		ch := text[i]
		if ch == '_' {
			continue
		}
		if ch < '0' || ch > '9' {
			return 0, false
		}
		value = value*10 + int(ch-'0')
		if value >= ceiling {
			return ceiling, true
		}
	}
	return value, true
}
+func (r *MissingConnectionPoolRule) Category() string { return "database" } + +func (r *MissingConnectionPoolRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find sql.Open calls + sqlOpenVars := make(map[string]token.Position) + + ast.Inspect(file, func(n ast.Node) bool { + assign, ok := n.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + continue + } + + if ident.Name == "sql" && sel.Sel.Name == "Open" { + if i < len(assign.Lhs) { + if lhsIdent, ok := assign.Lhs[i].(*ast.Ident); ok { + sqlOpenVars[lhsIdent.Name] = fset.Position(assign.Pos()) + } + } + } + } + + return true + }) + + // Check if pool configuration methods are called + configuredDBs := make(map[string]bool) + poolMethods := map[string]bool{ + "SetMaxOpenConns": true, + "SetMaxIdleConns": true, + "SetConnMaxLifetime": true, + "SetConnMaxIdleTime": true, + } + + ast.Inspect(file, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if poolMethods[sel.Sel.Name] { + if ident, ok := sel.X.(*ast.Ident); ok { + configuredDBs[ident.Name] = true + } + } + + return true + }) + + // Report unconfigured database connections + for dbVar, pos := range sqlOpenVars { + if !configuredDBs[dbVar] { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "Database connection '" + dbVar + "' opened without pool configuration", + Why: "Without configuration, the connection pool uses defaults which may not be suitable. 
This can lead to connection exhaustion or resource waste.", + Fix: "Configure the pool:\n " + dbVar + ".SetMaxOpenConns(25)\n " + dbVar + ".SetMaxIdleConns(5)\n " + dbVar + ".SetConnMaxLifetime(5 * time.Minute)", + }) + } + } + + return issues +} + +// UnlimitedConnectionPoolRule detects sql.DB with SetMaxOpenConns(0) +type UnlimitedConnectionPoolRule struct{} + +func (r *UnlimitedConnectionPoolRule) Name() string { return "unlimited-connection-pool" } +func (r *UnlimitedConnectionPoolRule) Category() string { return "database" } + +func (r *UnlimitedConnectionPoolRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + call, ok := n.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if sel.Sel.Name == "SetMaxOpenConns" && len(call.Args) > 0 { + if lit, ok := call.Args[0].(*ast.BasicLit); ok { + if lit.Kind == token.INT && lit.Value == "0" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityHigh, + Line: pos.Line, + Column: pos.Column, + Message: "SetMaxOpenConns(0) allows unlimited connections - potential resource exhaustion", + Why: "Unlimited connections can exhaust database server resources during traffic spikes. 
Most databases have connection limits.", + Fix: "Set a reasonable limit: SetMaxOpenConns(25) - adjust based on your database's max_connections and number of app instances", + }) + } + } + } + + return true + }) + + return issues +} diff --git a/rules/errors.go b/rules/errors.go new file mode 100644 index 0000000..df650e7 --- /dev/null +++ b/rules/errors.go @@ -0,0 +1,167 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("allocation", &ErrorWrapInLoopRule{}) + RegisterRule("allocation", &FmtErrorfInLoopRule{}) +} + +// ErrorWrapInLoopRule detects error wrapping in hot paths +type ErrorWrapInLoopRule struct{} + +func (r *ErrorWrapInLoopRule) Name() string { return "error-wrap-in-loop" } +func (r *ErrorWrapInLoopRule) Category() string { return "allocation" } + +func (r *ErrorWrapInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + // Check for errors.Wrap, errors.Wrapf, fmt.Errorf with %w + if ident.Name == "errors" && (sel.Sel.Name == "Wrap" || sel.Sel.Name == "Wrapf") { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "errors." + sel.Sel.Name + "() in loop - creates new error each iteration", + Why: "Error wrapping allocates a new error struct each time. 
In hot loops, this adds GC pressure.", + Fix: "Consider: (1) Pre-allocate sentinel errors, (2) Only wrap once outside loop, (3) Use error code patterns", + }) + } + + // Check for fmt.Errorf + if ident.Name == "fmt" && sel.Sel.Name == "Errorf" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "fmt.Errorf() in loop - allocates on each iteration", + Why: "fmt.Errorf allocates a new error and formats strings each time. In hot loops, use pre-defined errors.", + Fix: "Define errors at package level: var ErrInvalidItem = errors.New(\"invalid item\"), then use them in the loop", + }) + } + + return true + }) + + return true + }) + + return issues +} + +// FmtErrorfInLoopRule specifically detects fmt.Errorf with %w verb (error wrapping) +type FmtErrorfInLoopRule struct{} + +func (r *FmtErrorfInLoopRule) Name() string { return "fmt-errorf-wrap-loop" } +func (r *FmtErrorfInLoopRule) Category() string { return "allocation" } + +func (r *FmtErrorfInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if ident.Name == "fmt" && sel.Sel.Name == "Errorf" && len(call.Args) > 0 { + // Check if format string contains %w + if lit, ok := call.Args[0].(*ast.BasicLit); ok && lit.Kind == token.STRING { + if containsWrapVerb(lit.Value) { + pos := fset.Position(call.Pos()) + issues = 
// containsWrapVerb reports whether the format string s contains a %w verb.
//
// The string is scanned as a sequence of format directives, so an escaped
// percent sign followed by a 'w' ("%%w") is not mistaken for the verb, and
// directives that carry flags, width, precision or an explicit argument index
// ("%[1]w") are still recognized.
func containsWrapVerb(s string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != '%' {
			continue
		}
		i++ // step past '%'

		// Skip flags, width, precision and argument-index characters.
	modifiers:
		for i < len(s) {
			switch s[i] {
			case '+', '-', '#', ' ', '.', '*', '[', ']',
				'0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
				i++
			default:
				break modifiers
			}
		}

		if i < len(s) && s[i] == 'w' {
			return true
		}
		// s[i] is this directive's verb (or the second '%' of "%%"); the
		// loop increment consumes it so it cannot start a new directive.
	}
	return false
}
return true + } + + // Check for http.MaxBytesReader + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "http" && sel.Sel.Name == "MaxBytesReader" { + hasMaxBytesReader = true + } + } + + // Check for reading request body + if sel.Sel.Name == "ReadAll" || sel.Sel.Name == "Copy" || sel.Sel.Name == "Decode" { + // Check if argument is r.Body + for _, arg := range call.Args { + if argSel, ok := arg.(*ast.SelectorExpr); ok { + if argSel.Sel.Name == "Body" { + pos := fset.Position(call.Pos()) + bodyReadPos = &pos + } + } + } + } + + // Check for json.NewDecoder(r.Body) + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "json" && sel.Sel.Name == "NewDecoder" { + for _, arg := range call.Args { + if argSel, ok := arg.(*ast.SelectorExpr); ok { + if argSel.Sel.Name == "Body" { + pos := fset.Position(call.Pos()) + bodyReadPos = &pos + } + } + } + } + } + + return true + }) + + // If body is read without MaxBytesReader, flag it + if bodyReadPos != nil && !hasMaxBytesReader { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: bodyReadPos.Line, + Column: bodyReadPos.Column, + Message: "Request body read without size limit - potential DoS vulnerability", + Why: "Without http.MaxBytesReader, clients can send arbitrarily large bodies causing OOM. 
This is a denial-of-service vector.", + Fix: "Wrap body with limit: r.Body = http.MaxBytesReader(w, r.Body, maxBytes)", + }) + } + + return true + }) + + return issues +} + +// MissingBodyCloseRule detects HTTP response bodies that aren't closed +type MissingBodyCloseRule struct{} + +func (r *MissingBodyCloseRule) Name() string { return "missing-body-close" } +func (r *MissingBodyCloseRule) Category() string { return "io" } + +func (r *MissingBodyCloseRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Find HTTP client calls that return responses + type respInfo struct { + varName string + pos token.Position + } + var responses []respInfo + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + assign, ok := inner.(*ast.AssignStmt) + if !ok { + return true + } + + for i, rhs := range assign.Rhs { + call, ok := rhs.(*ast.CallExpr) + if !ok { + continue + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + continue + } + + // Check for http.Get, client.Do, etc. 
+ httpMethods := map[string]bool{ + "Get": true, + "Post": true, + "PostForm": true, + "Head": true, + "Do": true, + } + + if httpMethods[sel.Sel.Name] { + if i < len(assign.Lhs) { + if ident, ok := assign.Lhs[i].(*ast.Ident); ok { + responses = append(responses, respInfo{ + varName: ident.Name, + pos: fset.Position(assign.Pos()), + }) + } + } + } + } + + return true + }) + + // Check if response bodies are closed + for _, resp := range responses { + closed := false + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + // Check for resp.Body.Close() or defer resp.Body.Close() + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if sel.Sel.Name == "Close" { + if innerSel, ok := sel.X.(*ast.SelectorExpr); ok { + if innerSel.Sel.Name == "Body" { + if ident, ok := innerSel.X.(*ast.Ident); ok { + if ident.Name == resp.varName { + closed = true + return false + } + } + } + } + } + + return true + }) + + if !closed { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityHigh, + Line: resp.pos.Line, + Column: resp.pos.Column, + Message: "HTTP response body not closed - connection leak", + Why: "Not closing response bodies leaks connections. 
HTTP keep-alive connections stay open, exhausting the connection pool.", + Fix: "Add: defer " + resp.varName + ".Body.Close() (after checking error)", + }) + } + } + + return true + }) + + return issues +} + +// ResponseWriterBufferingRule detects large writes to ResponseWriter without Flush +type ResponseWriterBufferingRule struct{} + +func (r *ResponseWriterBufferingRule) Name() string { return "response-writer-buffering" } +func (r *ResponseWriterBufferingRule) Category() string { return "io" } + +func (r *ResponseWriterBufferingRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + // Check if this looks like an HTTP handler + if !isHTTPHandler(funcDecl) { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Check for large data streaming without flush + var loopWritePos *token.Position + hasFlusher := false + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + // Check for type assertion to Flusher + typeAssert, ok := inner.(*ast.TypeAssertExpr) + if ok { + if sel, ok := typeAssert.Type.(*ast.SelectorExpr); ok { + if sel.Sel.Name == "Flusher" { + hasFlusher = true + } + } + } + + // Look for Write in loops + var loopBody *ast.BlockStmt + switch stmt := inner.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(stmt ast.Node) bool { + call, ok := stmt.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if sel.Sel.Name == "Write" || sel.Sel.Name == "WriteString" { + // Check if receiver is the response writer (typically w) + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "w" { + pos := fset.Position(call.Pos()) + loopWritePos = &pos + } + } + } + + return 
true + }) + + return true + }) + + // If writing in loop without flusher, suggest it + if loopWritePos != nil && !hasFlusher { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: loopWritePos.Line, + Column: loopWritePos.Column, + Message: "Streaming response without Flush - data may buffer until handler returns", + Why: "ResponseWriter buffers data. For streaming/SSE, data won't reach the client until the buffer fills or handler returns.", + Fix: "For streaming: if f, ok := w.(http.Flusher); ok { f.Flush() } after each chunk", + }) + } + + return true + }) + + return issues +} diff --git a/rules/interface.go b/rules/interface.go new file mode 100644 index 0000000..da77be9 --- /dev/null +++ b/rules/interface.go @@ -0,0 +1,295 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("allocation", &InterfaceBoxingInLoopRule{}) + RegisterRule("allocation", &VariadicInterfaceRule{}) + RegisterRule("allocation", &TypeAssertionInLoopRule{}) +} + +// InterfaceBoxingInLoopRule detects interface{} assignments in loops +type InterfaceBoxingInLoopRule struct{} + +func (r *InterfaceBoxingInLoopRule) Name() string { return "interface-boxing-loop" } +func (r *InterfaceBoxingInLoopRule) Category() string { return "allocation" } + +func (r *InterfaceBoxingInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find functions that take interface{} or any parameters + interfaceFuncs := findInterfaceFunctions(file) + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + // Check if calling a function that takes 
interface{} + funcName := getFuncName(call.Fun) + if interfaceFuncs[funcName] && len(call.Args) > 0 { + // Check if passing concrete types (will be boxed) + for _, arg := range call.Args { + if !isInterfaceExpr(arg) { + // Concrete type being boxed + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: "Concrete type boxed to interface{} in loop - causes allocation", + Why: "Converting concrete types to interface{} allocates memory for the interface header. In hot loops, this adds GC pressure.", + Fix: "Consider: (1) Type-specific function overloads, (2) Generics (Go 1.18+), (3) Code generation for hot paths", + }) + break // Only report once per call + } + } + } + + return true + }) + + return true + }) + + return issues +} + +// VariadicInterfaceRule detects slice passed to ...interface{} causing per-element allocation +type VariadicInterfaceRule struct{} + +func (r *VariadicInterfaceRule) Name() string { return "variadic-interface" } +func (r *VariadicInterfaceRule) Category() string { return "allocation" } + +func (r *VariadicInterfaceRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find functions with variadic interface{} parameters + variadicFuncs := map[string]bool{ + "Printf": true, + "Sprintf": true, + "Fprintf": true, + "Errorf": true, + "Fatalf": true, + "Panicf": true, + "Logf": true, + "Debugf": true, + "Infof": true, + "Warnf": true, + } + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + // Get function name + var funcName string + switch 
fn := call.Fun.(type) { + case *ast.Ident: + funcName = fn.Name + case *ast.SelectorExpr: + funcName = fn.Sel.Name + } + + if variadicFuncs[funcName] { + // Check for complex arguments that will allocate + complexArgs := 0 + for _, arg := range call.Args { + // Skip the format string + if lit, ok := arg.(*ast.BasicLit); ok && lit.Kind == token.STRING { + continue + } + // Check for complex expressions + switch arg.(type) { + case *ast.CallExpr: // Function calls allocate return value + complexArgs++ + case *ast.BinaryExpr: // Operations may allocate + complexArgs++ + case *ast.CompositeLit: // Struct/slice literals allocate + complexArgs++ + } + } + + if complexArgs > 2 { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: pos.Line, + Column: pos.Column, + Message: funcName + "() with many arguments in loop - each arg boxes to interface{}", + Why: "Each argument to Printf-style functions is boxed to interface{}, causing allocations. 
Complex expressions allocate twice.", + Fix: "For hot paths: (1) Use structured logging, (2) Pre-format strings, (3) Use log level checks to skip entirely", + }) + } + } + + return true + }) + + return true + }) + + return issues +} + +// TypeAssertionInLoopRule detects type assertions in loops +type TypeAssertionInLoopRule struct{} + +func (r *TypeAssertionInLoopRule) Name() string { return "type-assertion-loop" } +func (r *TypeAssertionInLoopRule) Category() string { return "allocation" } + +func (r *TypeAssertionInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Count type assertions in loop + assertions := 0 + var lastPos token.Position + + ast.Inspect(loopBody, func(inner ast.Node) bool { + switch node := inner.(type) { + case *ast.TypeAssertExpr: + assertions++ + lastPos = fset.Position(node.Pos()) + case *ast.TypeSwitchStmt: + assertions++ + lastPos = fset.Position(node.Pos()) + } + return true + }) + + // Only flag if multiple assertions + if assertions > 2 { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: lastPos.Line, + Column: lastPos.Column, + Message: "Multiple type assertions in loop - consider type-specific code paths", + Why: "Type assertions have a small overhead. 
// Helper functions

// findInterfaceFunctions returns the names of the functions and methods
// declared in file that have at least one parameter whose type is recognized
// by isInterfaceType. Methods are recorded under their bare name, without the
// receiver type.
func findInterfaceFunctions(file *ast.File) map[string]bool {
	funcs := make(map[string]bool)

	// Function declarations only occur at the top level of a file, so walking
	// file.Decls finds all of them without descending into bodies.
	for _, decl := range file.Decls {
		fn, ok := decl.(*ast.FuncDecl)
		if !ok || fn.Type.Params == nil {
			continue
		}

		for _, param := range fn.Type.Params.List {
			if isInterfaceType(param.Type) {
				funcs[fn.Name.Name] = true
				break
			}
		}
	}

	return funcs
}

// isInterfaceType reports whether a parameter type expression is syntactically
// an interface: an interface literal such as interface{}, the identifiers
// "any" or "error", or a variadic parameter of one of those.
func isInterfaceType(expr ast.Expr) bool {
	switch t := expr.(type) {
	case *ast.InterfaceType:
		return true
	case *ast.Ident:
		return t.Name == "any" || t.Name == "error"
	case *ast.Ellipsis:
		return isInterfaceType(t.Elt)
	}
	return false
}

// isInterfaceExpr guesses whether an argument expression is already an
// interface value, in which case passing it boxes nothing.
//
// This is a heuristic - types are unknown without type checking. Identifiers
// named exactly "err", "error" or "any" are assumed to hold interfaces, and
// the literal nil carries no concrete value to box.
func isInterfaceExpr(expr ast.Expr) bool {
	ident, ok := expr.(*ast.Ident)
	if !ok {
		return false
	}
	switch ident.Name {
	case "err", "error", "any", "nil":
		return true
	}
	return false
}

// getFuncName returns the name a call targets: the identifier itself for a
// plain call f(...), the selected name for pkg.F(...) or x.M(...), and the
// empty string for any other callee expression.
func getFuncName(expr ast.Expr) string {
	switch fn := expr.(type) {
	case *ast.Ident:
		return fn.Name
	case *ast.SelectorExpr:
		return fn.Sel.Name
	}
	return ""
}
Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { var issues []Issue - results := FindJSONInLoop(file, fset) - for _, result := range results { - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityMedium, - Line: result.Pos.Line, - Column: result.Pos.Column, - Message: "json." + result.Method + "() inside loop - reflection overhead", - Why: "JSON encoding uses reflection, which is slow. In a loop, this overhead multiplies. Each call also allocates memory.", - Fix: "Consider: (1) Processing in batches, (2) Using code-generated encoders (easyjson, ffjson), (3) Streaming with json.Encoder for large datasets", + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Body == nil { + return true + } + + // Find json.Encoder variables (more efficient than Marshal in loops) + encoders := findJSONEncoders(funcDecl.Body) + decoders := findJSONDecoders(funcDecl.Body) + + // Find loops and check for JSON operations + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + var loopBody *ast.BlockStmt + var loopNode ast.Node + switch stmt := inner.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + loopNode = stmt + case *ast.ForStmt: + loopBody = stmt.Body + loopNode = stmt + default: + return true + } + + if loopBody == nil { + return true + } + + loopBound := getLoopBound(loopNode) + + ast.Inspect(loopBody, func(stmt ast.Node) bool { + call, ok := stmt.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + // Check for json.Marshal/Unmarshal + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "json" && (sel.Sel.Name == "Marshal" || sel.Sel.Name == "Unmarshal") { + severity := SeverityMedium + if loopBound > 0 && loopBound <= 10 { + severity = SeverityLow + } + + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), 
+ Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "json." + sel.Sel.Name + "() inside loop - reflection overhead", + Why: "JSON encoding uses reflection, which is slow. In a loop, this overhead multiplies. Each call also allocates memory.", + Fix: "Consider: (1) Use json.Encoder/json.Decoder which cache reflection, (2) Process in batches, (3) Use code-generated encoders (easyjson, ffjson)", + }) + } + } + + // Check for Encode/Decode on json.Encoder/Decoder - this is actually fine + // We just want to make sure they're not creating new encoders in the loop + if sel.Sel.Name == "Encode" { + receiver := getReceiverName(sel.X) + if !encoders[receiver] { + // Could be creating encoder inside loop + // Check if the receiver is json.NewEncoder call + if isNewEncoderCall(sel.X) { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "json.NewEncoder().Encode() inside loop - create encoder once outside", + Why: "Creating a new encoder for each iteration wastes the reflection caching benefit. Create the encoder once before the loop.", + Fix: "Move encoder creation outside the loop: enc := json.NewEncoder(w); for ... 
// findJSONEncoders finds variables that hold a json.Encoder.
//
// It returns the set of identifier names in body that are bound directly to a
// json.NewEncoder(...) call, either by assignment (enc := json.NewEncoder(w))
// or by a var declaration (var enc = json.NewEncoder(w)). A nil body yields an
// empty, non-nil map.
func findJSONEncoders(body *ast.BlockStmt) map[string]bool {
	return findJSONConstructorVars(body, "NewEncoder")
}

// findJSONDecoders finds variables that hold a json.Decoder.
//
// It is the json.NewDecoder counterpart of findJSONEncoders and follows the
// same rules.
func findJSONDecoders(body *ast.BlockStmt) map[string]bool {
	return findJSONConstructorVars(body, "NewDecoder")
}

// isNewEncoderCall reports whether expr is a call written as json.NewEncoder(...).
func isNewEncoderCall(expr ast.Expr) bool {
	return isJSONConstructorCall(expr, "NewEncoder")
}

// isNewDecoderCall reports whether expr is a call written as json.NewDecoder(...).
func isNewDecoderCall(expr ast.Expr) bool {
	return isJSONConstructorCall(expr, "NewDecoder")
}

// findJSONConstructorVars collects the names of variables in body whose
// initialiser is a json.<constructor>(...) call.
//
// Both short/regular assignments and var declarations are inspected; the
// left-hand side must be a plain identifier to be recorded.
func findJSONConstructorVars(body *ast.BlockStmt, constructor string) map[string]bool {
	vars := make(map[string]bool)
	if body == nil {
		// Nothing to scan; this also avoids walking a nil block.
		return vars
	}

	ast.Inspect(body, func(n ast.Node) bool {
		switch node := n.(type) {
		case *ast.AssignStmt:
			for i, rhs := range node.Rhs {
				if i >= len(node.Lhs) {
					// e.g. a, b := f(): no positional pairing beyond this point.
					break
				}
				if !isJSONConstructorCall(rhs, constructor) {
					continue
				}
				if ident, ok := node.Lhs[i].(*ast.Ident); ok {
					vars[ident.Name] = true
				}
			}
		case *ast.ValueSpec:
			// var enc = json.NewEncoder(w) is a declaration, not an assignment.
			for i, value := range node.Values {
				if i >= len(node.Names) {
					break
				}
				if isJSONConstructorCall(value, constructor) {
					vars[node.Names[i].Name] = true
				}
			}
		}
		return true
	})

	return vars
}

// isJSONConstructorCall reports whether expr is a call whose function is the
// selector json.<constructor>.
//
// The match is purely syntactic: the package identifier must be spelled
// "json", so an aliased import of encoding/json is not recognised.
func isJSONConstructorCall(expr ast.Expr, constructor string) bool {
	call, ok := expr.(*ast.CallExpr)
	if !ok {
		return false
	}
	sel, ok := call.Fun.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	pkg, ok := sel.X.(*ast.Ident)
	if !ok {
		return false
	}
	return pkg.Name == "json" && sel.Sel.Name == constructor
}
is inside a function (not package level) - // This is a simplified check - ideally we'd track scope - pos := fset.Position(compLit.Pos()) - issues = append(issues, Issue{ - Rule: r.Name(), - Category: r.Category(), - Severity: SeverityMedium, - Line: pos.Line, - Column: pos.Column, - Message: "http.Client created - ensure reuse across requests", - Why: "Creating new http.Client for each request wastes connection pooling benefits. Each client maintains its own connection pool and transport.", - Fix: "Create http.Client once at package level or in init, then reuse. Configure Transport for connection pooling.", + // Determine severity based on context + severity := SeverityMedium + message := "http.Client created inside function - ensure reuse across requests" + fix := "Create http.Client once at package level or in init, then reuse. Configure Transport for connection pooling." + + // Check if this is inside a loop - that's worse + if isInsideLoop(funcDecl.Body, compLit) { + severity = SeverityHigh + message = "http.Client created inside loop - significant overhead" + fix = "Move http.Client creation outside the loop. Create once and reuse for all requests." + } + + // Skip if there are package-level clients (user probably knows what they're doing) + if len(packageLevelClients) > 0 { + severity = SeverityLow + } + + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: message, + Why: "Creating new http.Client for each request wastes connection pooling benefits. 
// findPackageLevelHTTPClients returns the names of package-level variables
// that are declared as an http.Client.
//
// A variable qualifies when its declared type is http.Client or *http.Client,
// or when its initialiser is an http.Client composite literal (optionally
// address-taken). Each name is reported once, in declaration order. The match
// is syntactic: the package identifier must be spelled "http".
func findPackageLevelHTTPClients(file *ast.File) []string {
	var clients []string
	if file == nil {
		return clients
	}

	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.VAR {
			continue
		}

		for _, spec := range genDecl.Specs {
			valueSpec, ok := spec.(*ast.ValueSpec)
			if !ok {
				continue
			}

			// An explicit type applies to every name in the spec; accept the
			// pointer form too (var client *http.Client).
			declaredType := valueSpec.Type
			if star, ok := declaredType.(*ast.StarExpr); ok {
				declaredType = star.X
			}
			typed := isHTTPClientSelector(declaredType)

			for i, name := range valueSpec.Names {
				// Without an explicit type, fall back to the initialiser
				// (var client = &http.Client{}).
				literal := i < len(valueSpec.Values) && isHTTPClientLiteral(valueSpec.Values[i])
				if typed || literal {
					clients = append(clients, name.Name)
				}
			}
		}
	}

	return clients
}

// isHTTPClientLiteral reports whether expr is an http.Client composite
// literal, either bare (http.Client{}) or address-taken (&http.Client{}).
func isHTTPClientLiteral(expr ast.Expr) bool {
	// Handle &http.Client{}
	if unary, ok := expr.(*ast.UnaryExpr); ok && unary.Op == token.AND {
		expr = unary.X
	}

	compLit, ok := expr.(*ast.CompositeLit)
	if !ok {
		return false
	}
	return isHTTPClientSelector(compLit.Type)
}

// isHTTPClientSelector reports whether expr is the selector http.Client.
// A nil expr is not a match.
func isHTTPClientSelector(expr ast.Expr) bool {
	sel, ok := expr.(*ast.SelectorExpr)
	if !ok {
		return false
	}
	pkg, ok := sel.X.(*ast.Ident)
	if !ok {
		return false
	}
	return pkg.Name == "http" && sel.Sel.Name == "Client"
}

// isInsideLoop reports whether target occurs in a part of a for or range
// statement within body that is executed on every iteration.
//
// For a range statement that is the loop body; for a for statement it is the
// condition, the post statement and the body. The init statement and the
// range operand are evaluated once and therefore do not count. A nil body or
// nil target yields false.
func isInsideLoop(body *ast.BlockStmt, target ast.Node) bool {
	if body == nil || target == nil {
		return false
	}

	inside := false

	// search marks inside when target is found beneath root. ast.Inspect also
	// invokes the callback with nil after each node's children; target is
	// non-nil here, so that call never matches.
	search := func(root ast.Node) {
		ast.Inspect(root, func(n ast.Node) bool {
			if n == target {
				inside = true
			}
			return !inside
		})
	}

	ast.Inspect(body, func(n ast.Node) bool {
		if inside {
			return false
		}
		switch loop := n.(type) {
		case *ast.RangeStmt:
			if loop.Body != nil {
				search(loop.Body)
			}
		case *ast.ForStmt:
			if loop.Cond != nil {
				search(loop.Cond)
			}
			if !inside && loop.Post != nil {
				search(loop.Post)
			}
			if !inside && loop.Body != nil {
				search(loop.Body)
			}
		}
		return !inside
	})

	return inside
}
ReadAllRule struct{} @@ -110,16 +419,32 @@ func (r *ReadAllRule) Check(file *ast.File, fset *token.FileSet, src []byte) []I } if (ident.Name == "io" || ident.Name == "ioutil") && sel.Sel.Name == "ReadAll" { + // Check context - is this reading from http response body? + severity := SeverityLow + why := "For large files or responses, ReadAll allocates potentially huge buffers. This can cause OOM for large inputs." + fix := "Consider streaming: io.Copy(), bufio.Scanner, or json.Decoder for JSON. Process data in chunks when possible." + + // Check if reading HTTP response (common pattern) + if len(call.Args) > 0 { + if sel, ok := call.Args[0].(*ast.SelectorExpr); ok { + if sel.Sel.Name == "Body" { + severity = SeverityMedium + why = "Reading entire HTTP response body into memory. For large responses, this can cause memory issues." + fix = "Consider: (1) Setting Content-Length limits, (2) Using io.LimitReader, (3) Streaming with json.Decoder for JSON responses" + } + } + } + pos := fset.Position(call.Pos()) issues = append(issues, Issue{ Rule: r.Name(), Category: r.Category(), - Severity: SeverityLow, + Severity: severity, Line: pos.Line, Column: pos.Column, Message: "ReadAll() loads entire content into memory", - Why: "For large files or responses, ReadAll allocates potentially huge buffers. This can cause OOM for large inputs.", - Fix: "Consider streaming: io.Copy(), bufio.Scanner, or json.Decoder for JSON. 
Process data in chunks when possible.", + Why: why, + Fix: fix, }) } diff --git a/rules/memory.go b/rules/memory.go new file mode 100644 index 0000000..e3b124b --- /dev/null +++ b/rules/memory.go @@ -0,0 +1,386 @@ +package rules + +import ( + "go/ast" + "go/token" +) + +func init() { + RegisterRule("memory", &PprofInHotPathRule{}) + RegisterRule("memory", &LargeStructCopyRule{}) + RegisterRule("memory", &EscapeToHeapRule{}) +} + +// PprofInHotPathRule detects pprof calls in hot paths +type PprofInHotPathRule struct{} + +func (r *PprofInHotPathRule) Name() string { return "pprof-in-hot-path" } +func (r *PprofInHotPathRule) Category() string { return "memory" } + +func (r *PprofInHotPathRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + pprofFuncs := map[string]bool{ + "WriteHeapProfile": true, + "StartCPUProfile": true, + "StopCPUProfile": true, + "Lookup": true, + "WriteTo": true, + } + + ast.Inspect(file, func(n ast.Node) bool { + // Check for pprof calls in loops + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + // Check for pprof.X or runtime/pprof calls + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "pprof" && pprofFuncs[sel.Sel.Name] { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityHigh, + Line: pos.Line, + Column: pos.Column, + Message: "pprof." + sel.Sel.Name + "() called in loop - significant overhead", + Why: "Profiling operations are expensive and should not be called repeatedly. 
They're meant for sampling, not continuous collection.", + Fix: "Move profiling outside the loop, or use sampling: if rand.Intn(1000) == 0 { profile() }", + }) + } + } + + return true + }) + + return true + }) + + // Also check for pprof in HTTP handlers (common mistake) + ast.Inspect(file, func(n ast.Node) bool { + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + // Check if this looks like an HTTP handler + if !isHTTPHandler(funcDecl) { + return true + } + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if ident, ok := sel.X.(*ast.Ident); ok { + if ident.Name == "pprof" && pprofFuncs[sel.Sel.Name] { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "pprof." + sel.Sel.Name + "() in HTTP handler - consider using /debug/pprof endpoints", + Why: "Manual pprof calls in handlers add latency. 
Use net/http/pprof endpoints for on-demand profiling.", + Fix: "Import _ \"net/http/pprof\" and use /debug/pprof/* endpoints instead", + }) + } + } + + return true + }) + + return true + }) + + return issues +} + +// LargeStructCopyRule detects passing large structs by value +type LargeStructCopyRule struct{} + +func (r *LargeStructCopyRule) Name() string { return "large-struct-copy" } +func (r *LargeStructCopyRule) Category() string { return "memory" } + +func (r *LargeStructCopyRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + // Find struct definitions and estimate their size + structSizes := estimateStructSizes(file) + + ast.Inspect(file, func(n ast.Node) bool { + // Check function parameters + funcDecl, ok := n.(*ast.FuncDecl) + if !ok { + return true + } + + if funcDecl.Type.Params == nil { + return true + } + + for _, param := range funcDecl.Type.Params.List { + // Skip pointer types + if _, isPtr := param.Type.(*ast.StarExpr); isPtr { + continue + } + + typeName := getTypeName(param.Type) + if size, ok := structSizes[typeName]; ok && size > 64 { + // Large struct passed by value + pos := fset.Position(param.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "Large struct '" + typeName + "' (~" + itoa(size) + " bytes) passed by value", + Why: "Passing large structs by value copies all fields on each call. 
This wastes CPU and memory bandwidth.", + Fix: "Pass by pointer: func f(s *" + typeName + ") instead of func f(s " + typeName + ")", + }) + } + } + + // Check for large struct copies in loops + if funcDecl.Body == nil { + return true + } + + ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := inner.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + // Check range value - copying struct on each iteration + if stmt.Value != nil { + if ident, ok := stmt.Value.(*ast.Ident); ok && ident.Name != "_" { + // This copies the value on each iteration + // We'd need type info to know if it's a large struct + } + } + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Check for assignments that copy large structs + ast.Inspect(loopBody, func(stmt ast.Node) bool { + assign, ok := stmt.(*ast.AssignStmt) + if !ok { + return true + } + + for _, rhs := range assign.Rhs { + typeName := getTypeName(rhs) + if size, ok := structSizes[typeName]; ok && size > 64 { + pos := fset.Position(assign.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + Column: pos.Column, + Message: "Large struct copy in loop - consider using pointer", + Why: "Copying a ~" + itoa(size) + " byte struct on each iteration adds significant overhead.", + Fix: "Use pointer to avoid copy, or access fields directly without intermediate variable", + }) + } + } + + return true + }) + + return true + }) + + return true + }) + + return issues +} + +// EscapeToHeapRule detects patterns that likely cause heap escapes +type EscapeToHeapRule struct{} + +func (r *EscapeToHeapRule) Name() string { return "escape-to-heap" } +func (r *EscapeToHeapRule) Category() string { return "memory" } + +func (r *EscapeToHeapRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n 
// Helper functions

// isHTTPHandler reports whether funcDecl looks like an HTTP handler.
//
// A declaration qualifies when it has at least two parameter fields and one
// of them has a type named ResponseWriter, Request or *Request (package
// qualifiers are ignored, as in getTypeName). The pointer form is listed
// explicitly because getTypeName renders *http.Request as "*Request".
func isHTTPHandler(funcDecl *ast.FuncDecl) bool {
	if funcDecl == nil || funcDecl.Type == nil || funcDecl.Type.Params == nil {
		return false
	}

	params := funcDecl.Type.Params.List
	if len(params) < 2 {
		return false
	}

	// Check for (w http.ResponseWriter, r *http.Request) pattern
	for _, param := range params {
		switch getTypeName(param.Type) {
		case "ResponseWriter", "Request", "*Request":
			return true
		}
	}

	return false
}

// knownTypeSizes holds approximate sizes in bytes (64-bit layout assumed) for
// types whose size can be told from the name alone. Package-qualified types
// are keyed as "pkg.Name".
var knownTypeSizes = map[string]int{
	"string":     16, // ptr + len
	"int":        8,
	"int64":      8,
	"int32":      4,
	"float64":    8,
	"bool":       1,
	"byte":       1,
	"time.Time":  24,
	"sync.Mutex": 8,
}

// estimateStructSizes returns a rough size in bytes for every struct type
// declared at the top level of file, keyed by type name.
//
// The estimate is the sum of the per-field estimates from estimateFieldSize;
// a field declaring several names ("a, b int") counts once per name and an
// embedded field counts once. Padding and alignment are ignored.
func estimateStructSizes(file *ast.File) map[string]int {
	sizes := make(map[string]int)
	if file == nil {
		return sizes
	}

	for _, decl := range file.Decls {
		genDecl, ok := decl.(*ast.GenDecl)
		if !ok || genDecl.Tok != token.TYPE {
			continue
		}

		for _, spec := range genDecl.Specs {
			typeSpec, ok := spec.(*ast.TypeSpec)
			if !ok {
				continue
			}

			structType, ok := typeSpec.Type.(*ast.StructType)
			if !ok {
				continue
			}

			size := 0
			if structType.Fields != nil {
				for _, field := range structType.Fields.List {
					// Multiply by number of names (e.g., "a, b int")
					count := len(field.Names)
					if count == 0 {
						count = 1 // embedded field
					}
					size += estimateFieldSize(field.Type) * count
				}
			}

			sizes[typeSpec.Name.Name] = size
		}
	}

	return sizes
}

// estimateFieldSize returns the approximate size in bytes of a field of the
// given type expression.
//
// Plain and package-qualified names are looked up in knownTypeSizes, a slice
// counts as a three-word header, and a fixed-size array with a decimal
// literal length counts as length times the element estimate. Everything
// else (pointers, maps, unknown names, ...) falls back to one 8-byte word.
func estimateFieldSize(expr ast.Expr) int {
	const wordSize = 8

	switch t := expr.(type) {
	case *ast.Ident:
		if size, ok := knownTypeSizes[t.Name]; ok {
			return size
		}
	case *ast.SelectorExpr:
		if pkg, ok := t.X.(*ast.Ident); ok {
			if size, ok := knownTypeSizes[pkg.Name+"."+t.Sel.Name]; ok {
				return size
			}
		}
	case *ast.ArrayType:
		if t.Len == nil {
			return 3 * wordSize // slice header: ptr + len + cap
		}
		if n, ok := arrayLength(t.Len); ok {
			return n * estimateFieldSize(t.Elt)
		}
	}

	return wordSize // default assumption
}

// arrayLength extracts the value of an array length written as a plain
// decimal integer literal. Any other form (constants, expressions, non-decimal
// literals, implausibly large values) reports false.
func arrayLength(expr ast.Expr) (int, bool) {
	const maxLen = 1 << 24 // keeps length*elementSize well inside int range

	lit, ok := expr.(*ast.BasicLit)
	if !ok || lit.Kind != token.INT || lit.Value == "" {
		return 0, false
	}
	if len(lit.Value) > 1 && lit.Value[0] == '0' {
		// Leading zero means octal/hex/binary notation; not handled here.
		return 0, false
	}

	n := 0
	for _, c := range lit.Value {
		if c == '_' {
			continue // digit separator
		}
		if c < '0' || c > '9' {
			return 0, false
		}
		n = n*10 + int(c-'0')
		if n > maxLen {
			return 0, false
		}
	}
	return n, true
}

// getTypeName renders a type expression as a short name.
//
// Identifiers yield their name, selectors yield only the selected name (the
// package qualifier is dropped), pointers are prefixed with "*", and both
// slices and arrays are prefixed with "[]". Any other expression yields "".
func getTypeName(expr ast.Expr) string {
	switch t := expr.(type) {
	case *ast.Ident:
		return t.Name
	case *ast.SelectorExpr:
		return t.Sel.Name
	case *ast.StarExpr:
		return "*" + getTypeName(t.X)
	case *ast.ArrayType:
		return "[]" + getTypeName(t.Elt)
	}
	return ""
}

// itoa formats n in base 10, with a leading "-" for negative values.
func itoa(n int) string {
	if n == 0 {
		return "0"
	}

	// Work on the magnitude as an unsigned value so that the most negative
	// int can be negated without overflow.
	negative := n < 0
	u := uint(n)
	if negative {
		u = -u
	}

	var buf [20]byte // sign plus up to 19 digits of a 64-bit int
	i := len(buf)
	for u > 0 {
		i--
		buf[i] = byte('0' + u%10)
		u /= 10
	}
	if negative {
		i--
		buf[i] = '-'
	}
	return string(buf[i:])
}
ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if ident.Name == "time" && sel.Sel.Name == "Parse" { + severity := SeverityMedium + if loopBound > 0 && loopBound <= 10 { + severity = SeverityLow + } + + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: severity, + Line: pos.Line, + Column: pos.Column, + Message: "time.Parse() in loop - repeated parsing overhead", + Why: "time.Parse parses the layout string on each call. While not as expensive as regexp.Compile, it still adds up in hot loops.", + Fix: "If parsing the same format, consider: (1) Caching parsed time.Location, (2) Using time.ParseInLocation with cached location", + }) + } + + return true + }) + + return true + }) + + return issues +} + +// TimeLocationInLoopRule detects time.LoadLocation in loops +type TimeLocationInLoopRule struct{} + +func (r *TimeLocationInLoopRule) Name() string { return "time-location-in-loop" } +func (r *TimeLocationInLoopRule) Category() string { return "allocation" } + +func (r *TimeLocationInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + ident, ok := sel.X.(*ast.Ident) + if !ok { + return true + } + + if ident.Name == "time" && sel.Sel.Name == "LoadLocation" { + pos := fset.Position(call.Pos()) + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityMedium, + Line: pos.Line, + 
Column: pos.Column, + Message: "time.LoadLocation() in loop - expensive I/O operation", + Why: "LoadLocation reads timezone data from disk or tzdata. This is slow and should be cached.", + Fix: "Cache the location: var loc, _ = time.LoadLocation(\"America/New_York\") at package level or function start", + }) + } + + return true + }) + + return true + }) + + return issues +} + +// TimeFormatInLoopRule detects time.Time.Format with complex layouts in loops +type TimeFormatInLoopRule struct{} + +func (r *TimeFormatInLoopRule) Name() string { return "time-format-loop" } +func (r *TimeFormatInLoopRule) Category() string { return "io" } + +func (r *TimeFormatInLoopRule) Check(file *ast.File, fset *token.FileSet, src []byte) []Issue { + var issues []Issue + + ast.Inspect(file, func(n ast.Node) bool { + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + // Count Format calls in loop + formatCount := 0 + var lastFormatPos token.Position + + ast.Inspect(loopBody, func(inner ast.Node) bool { + call, ok := inner.(*ast.CallExpr) + if !ok { + return true + } + + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return true + } + + if sel.Sel.Name == "Format" { + formatCount++ + lastFormatPos = fset.Position(call.Pos()) + } + + return true + }) + + // Only flag if multiple format calls in same loop + if formatCount > 1 { + issues = append(issues, Issue{ + Rule: r.Name(), + Category: r.Category(), + Severity: SeverityLow, + Line: lastFormatPos.Line, + Column: lastFormatPos.Column, + Message: "Multiple time.Format() calls in loop - consider caching or batching", + Why: "time.Format allocates strings. 
Multiple calls per iteration multiply allocations.", + Fix: "Consider: (1) Format once with all needed data, (2) Use strconv for simple number conversions, (3) Use a strings.Builder", + }) + } + + return true + }) + + return issues +} diff --git a/testdata/test_new_rules.go b/testdata/test_new_rules.go new file mode 100644 index 0000000..8b9ba89 --- /dev/null +++ b/testdata/test_new_rules.go @@ -0,0 +1,135 @@ +package testdata + +import ( + "context" + "database/sql" + "fmt" + "net/http" + "regexp" + "runtime/pprof" + "time" +) + +// Test context rules +func handleRequest(w http.ResponseWriter, r *http.Request) { + // BAD: context.Background in handler + ctx := context.Background() + doWork(ctx) + + // BAD: missing timeout + ctx2, cancel := context.WithCancel(context.Background()) + _ = ctx2 + _ = cancel + // No defer cancel() - context leak +} + +// Test database rules +func initDB() { + // BAD: no pool configuration + db, _ := sql.Open("postgres", "connection-string") + _ = db +} + +func badPoolConfig() { + db, _ := sql.Open("postgres", "connection-string") + // BAD: unlimited connections + db.SetMaxOpenConns(0) +} + +// Test memory rules +func processItems(items []LargeStruct) { + for _, item := range items { + // BAD: large struct passed by value + processLargeStruct(item) + } +} + +type LargeStruct struct { + Field1 [100]byte + Field2 [100]byte + Field3 [100]byte + Field4 [100]byte + Field5 string + Field6 string + Field7 string + Field8 string + Field9 string + Field10 string +} + +func processLargeStruct(s LargeStruct) {} + +func hotPath() { + // BAD: pprof in hot path + for i := 0; i < 1000; i++ { + pprof.Lookup("heap") + } +} + +// Test time rules +func parseTimesInLoop(dates []string) { + for _, d := range dates { + // BAD: time.Parse in loop + t, _ := time.Parse("2006-01-02", d) + _ = t + } +} + +func loadLocationInLoop(zones []string) { + for _, zone := range zones { + // BAD: time.LoadLocation in loop + loc, _ := time.LoadLocation(zone) + _ = loc + } 
+} + +// Test HTTP rules +func handleUpload(w http.ResponseWriter, r *http.Request) { + // BAD: no MaxBytesReader + body := r.Body + _ = body +} + +func makeRequest() { + // BAD: missing body close + resp, _ := http.Get("http://example.com") + _ = resp + // No defer resp.Body.Close() +} + +// Test cache rules +func validateInLoop(items []string) { + for _, item := range items { + // BAD: regexp.MatchString in loop + matched, _ := regexp.MatchString(`^\d+$`, item) + _ = matched + } +} + +func compileInFunc() { + // BAD: regexp.MustCompile inside function + re := regexp.MustCompile(`\d+`) + _ = re +} + +// Test error wrapping rules +func processWithErrors(items []string) error { + for _, item := range items { + // BAD: fmt.Errorf in loop + err := fmt.Errorf("failed to process %s", item) + _ = err + } + return nil +} + +// Test interface boxing rules +func logInLoop(items []int) { + for _, item := range items { + // BAD: Printf with many args in loop + fmt.Printf("item: %d, squared: %d, cubed: %d, fourth: %d\n", + item, item*item, item*item*item, item*item*item*item) + } +} + +// Helper to avoid unused warnings +func doWork(ctx context.Context) {}