diff --git a/fixer/fixer.go b/fixer/fixer.go index b25daa0..71e588c 100644 --- a/fixer/fixer.go +++ b/fixer/fixer.go @@ -4,10 +4,10 @@ import ( "bytes" "fmt" "go/ast" - "go/format" "go/parser" "go/token" "os" + "path/filepath" "strings" "github.com/unsaid-dev/goperf/rules" @@ -37,6 +37,49 @@ func NewFixer(dryRun, verbose bool) *Fixer { } } +// validatePathForWrite ensures the file path is safe for writing +func validatePathForWrite(filename string) error { + // Get current working directory + cwd, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + + // Get absolute path of target + absPath, err := filepath.Abs(filename) + if err != nil { + return fmt.Errorf("failed to resolve path: %w", err) + } + + // Clean the path to resolve any .. components + absPath = filepath.Clean(absPath) + + // Check if path is within current working directory + if !strings.HasPrefix(absPath, cwd+string(filepath.Separator)) && absPath != cwd { + return fmt.Errorf("refusing to write to %q: path is outside working directory (security restriction)", filename) + } + + // Check for symlinks to prevent TOCTOU attacks + realPath, err := filepath.EvalSymlinks(absPath) + if err == nil && realPath != absPath { + // Path contains symlinks - verify the real path is also within CWD + if !strings.HasPrefix(realPath, cwd+string(filepath.Separator)) && realPath != cwd { + return fmt.Errorf("refusing to write to %q: symlink points outside working directory (security restriction)", filename) + } + } + + // Verify it's a regular file (not a device, socket, etc.) + info, err := os.Lstat(absPath) + if err == nil { + // File exists - check it's a regular file + if info.Mode()&os.ModeType != 0 { + return fmt.Errorf("refusing to write to %q: not a regular file", filename) + } + } + + return nil +} + // FixIssues attempts to fix the given issues func (f *Fixer) FixIssues(issues []rules.Issue) []Fix { var fixes []Fix @@ -58,6 +101,14 @@ func (f *Fixer) FixIssues(issues []rules.Issue) []Fix { func (f *Fixer) fixFile(filename string, issues []rules.Issue) []Fix { var fixes []Fix + // Validate path before any operations + if err := validatePathForWrite(filename); err != nil { + if f.Verbose { + fmt.Fprintf(os.Stderr, "Skipping %s: %v\n", filename, err) + } + return fixes + } + src, err := os.ReadFile(filename) if err != nil { return fixes @@ -69,26 +120,19 @@ func (f *Fixer) fixFile(filename string, issues []rules.Issue) []Fix { return fixes } - modified := false lines := strings.Split(string(src), "\n") for _, issue := range issues { fix := f.attemptFix(issue, astFile, fset, lines) if fix != nil { fixes = append(fixes, *fix) - if !f.DryRun && fix.Fixed != "" { - modified = true - } } } - if modified && !f.DryRun { - // Format and write back - var buf bytes.Buffer - if err := format.Node(&buf, fset, astFile); err == nil { - os.WriteFile(filename, buf.Bytes(), 0644) - } - } + // NOTE: Auto-fix currently only generates suggestions. + // Actual AST modification and file writing is not implemented + // because safe AST modification is complex and error-prone. + // The fixes are displayed as suggestions for manual application. return fixes } @@ -151,13 +195,16 @@ func (f *Fixer) fixUnpreallocatedSlice(issue rules.Issue, file *ast.File, fset * Applied: false, } - // Extract slice name from message + // Extract slice name from message safely msg := issue.Message start := strings.Index(msg, "'") end := strings.LastIndex(msg, "'") - if start >= 0 && end > start { + if start >= 0 && end > start && start+1 < len(msg) { sliceName := msg[start+1 : end] - fix.Fixed = fmt.Sprintf("%s = make([]T, 0, expectedSize) // Preallocate %s", sliceName, sliceName) + // Validate the extracted name looks like an identifier + if isValidIdentifier(sliceName) { + fix.Fixed = fmt.Sprintf("%s = make([]T, 0, expectedSize) // Preallocate %s", sliceName, sliceName) + } } return fix @@ -167,14 +214,17 @@ func (f *Fixer) fixMissingBodyClose(issue rules.Issue, file *ast.File, fset *tok line := issue.Line original := getLine(lines, line) - // Find the variable name from the message + // Find the variable name from the message safely msg := issue.Message start := strings.Index(msg, "'") end := strings.LastIndex(msg, "'") varName := "resp" - if start >= 0 && end > start { - varName = msg[start+1 : end] + if start >= 0 && end > start && start+1 < len(msg) { + extracted := msg[start+1 : end] + if isValidIdentifier(extracted) { + varName = extracted + } } fix := &Fix{ @@ -193,14 +243,17 @@ func (f *Fixer) fixContextLeak(issue rules.Issue, file *ast.File, fset *token.Fi line := issue.Line original := getLine(lines, line) - // Extract cancel function name from message + // Extract cancel function name from message safely msg := issue.Message start := strings.Index(msg, "'") end := strings.LastIndex(msg, "'") cancelName := "cancel" - if start >= 0 && end > start { - cancelName = msg[start+1 : end] + if start >= 0 && end > start && start+1 < len(msg) { + extracted := msg[start+1 : end] + if isValidIdentifier(extracted) { + cancelName = extracted + } } fix := &Fix{ @@ -215,6 +268,25 @@ func (f *Fixer) fixContextLeak(issue rules.Issue, file *ast.File, fset *token.Fi return fix } +// isValidIdentifier checks if a string looks like a valid Go identifier +func isValidIdentifier(s string) bool { + if len(s) == 0 || len(s) > 100 { + return false + } + for i, c := range s { + if i == 0 { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') { + return false + } + } else { + if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') { + return false + } + } + } + return true +} + func getLine(lines []string, lineNum int) string { if lineNum > 0 && lineNum <= len(lines) { return lines[lineNum-1] @@ -230,9 +302,11 @@ func PrintFixes(fixes []Fix, dryRun bool) { } if dryRun { - fmt.Println("=== DRY RUN: Suggested fixes (no files modified) ===\n") + fmt.Println("=== DRY RUN: Suggested fixes (no files modified) ===") + fmt.Println() } else { - fmt.Println("=== Applied fixes ===\n") + fmt.Println("=== Suggested fixes ===") + fmt.Println() } for _, fix := range fixes { diff --git a/fixer/fixer_test.go b/fixer/fixer_test.go new file mode 100644 index 0000000..aadc19c --- /dev/null +++ b/fixer/fixer_test.go @@ -0,0 +1,193 @@ +package fixer + +import ( + "os" + "path/filepath" + "testing" +) + +func TestIsValidIdentifier(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"cancel", true}, + {"Cancel", true}, + {"_cancel", true}, + {"cancel1", true}, + {"my_cancel_func", true}, + {"", false}, + {"1cancel", false}, + {"-cancel", false}, + {"cancel-func", false}, + {"cancel func", false}, + {"cancel()", false}, + {string(make([]byte, 101)), false}, // too long + } + + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + got := isValidIdentifier(tt.input) + if got != tt.want { + t.Errorf("isValidIdentifier(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestGetLine(t *testing.T) { + lines := []string{"line1", "line2", "line3"} + + tests := []struct { + lineNum int + want string + }{ + {1, "line1"}, + {2, "line2"}, + {3, "line3"}, + {0, ""}, // out of bounds + {4, ""}, // out of bounds + {-1, ""}, // negative + } + + for _, tt := range tests { + got := getLine(lines, tt.lineNum) + if got != tt.want { + t.Errorf("getLine(lines, %d) = %q, want %q", tt.lineNum, got, tt.want) + } + } +} + +func TestValidatePathForWrite(t *testing.T) { + // Get current working directory + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("failed to get cwd: %v", err) + } + + // Create a temp file in current directory for testing + tmpFile, err := os.CreateTemp(cwd, "test-*.go") + if err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + tmpFile.Close() + defer os.Remove(tmpFile.Name()) + + tests := []struct { + name string + path string + wantErr bool + }{ + { + name: "valid file in cwd", + path: tmpFile.Name(), + wantErr: false, + }, + { + name: "path outside cwd", + path: "/etc/passwd", + wantErr: true, + }, + { + name: "path traversal attempt", + path: filepath.Join(cwd, "..", "..", "etc", "passwd"), + wantErr: true, + }, + { + name: "relative path outside", + path: "../../../etc/passwd", + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePathForWrite(tt.path) + if (err != nil) != tt.wantErr { + t.Errorf("validatePathForWrite(%q) error = %v, wantErr %v", tt.path, err, tt.wantErr) + } + }) + } +} + +func TestNewFixer(t *testing.T) { + f := NewFixer(true, true) + if f == nil { + t.Fatal("NewFixer returned nil") + } + if !f.DryRun { + t.Error("DryRun should be true") + } + if !f.Verbose { + t.Error("Verbose should be true") + } + + f2 := NewFixer(false, false) + if f2.DryRun { + t.Error("DryRun should be false") + } + if f2.Verbose { + t.Error("Verbose should be false") + } +} + +func TestGenerateDiff(t *testing.T) { + fixes := []Fix{ + { + File: "test.go", + Line: 10, + Original: "old code", + Fixed: "new code", + Rule: "test-rule", + }, + } + + diff := GenerateDiff(fixes) + if diff == "" { + t.Error("GenerateDiff returned empty string") + } + + // Check diff format + expectedContains := []string{ + "--- a/test.go", + "+++ b/test.go", + "-old code", + "+new code", + } + + for _, expected := range expectedContains { + if !contains(diff, expected) { + t.Errorf("diff missing %q", expected) + } + } +} + +func TestGenerateDiff_EmptyFix(t *testing.T) { + fixes := []Fix{ + { + File: "test.go", + Line: 10, + Original: "old code", + Fixed: "", // No fix available + Rule: "test-rule", + }, + } + + diff := GenerateDiff(fixes) + if diff != "" { + t.Error("GenerateDiff should return empty string for fixes without Fixed content") + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/main.go b/main.go index 95b0cca..010fc9a 100644 --- a/main.go +++ b/main.go @@ -25,6 +25,13 @@ var ( dryRunFlag = flag.Bool("dry-run", false, "Show fixes without applying them (use with --fix)") ) +// Resource limits to prevent DoS +const ( + MaxFilesPerScan = 10000 + MaxFileSizeBytes = 10 * 1024 * 1024 // 10MB + MaxDirectoryDepth = 50 +) + var version = "0.1.0" func main() { @@ -194,19 +201,79 @@ func parseSeverity(s string) rules.Severity { } } +// validatePath ensures the path is safe (no traversal attacks, within allowed scope) +func validatePath(path string) error { + // Get current working directory + cwd, err := filepath.Abs(".") + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + + // Get absolute path of target + absPath, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("failed to resolve path: %w", err) + } + + // Clean the path to resolve any .. components + absPath = filepath.Clean(absPath) + + // Check if path is within current working directory + if !strings.HasPrefix(absPath, cwd+string(filepath.Separator)) && absPath != cwd { + return fmt.Errorf("path %q is outside working directory (security restriction)", path) + } + + // Check for symlinks to prevent TOCTOU attacks + realPath, err := filepath.EvalSymlinks(absPath) + if err == nil && realPath != absPath { + // Path contains symlinks - verify the real path is also within CWD + if !strings.HasPrefix(realPath, cwd+string(filepath.Separator)) && realPath != cwd { + return fmt.Errorf("symlink %q points outside working directory (security restriction)", path) + } + } + + return nil +} + func collectGoFiles(pattern string, ignorePaths []string) ([]string, error) { var files []string + fileCount := 0 + + // Validate the base pattern first + basePath := pattern + if strings.HasSuffix(pattern, "/...") { + basePath = strings.TrimSuffix(pattern, "/...") + if basePath == "" { + basePath = "." + } + } + + if err := validatePath(basePath); err != nil { + return nil, err + } // Handle ./... pattern if strings.HasSuffix(pattern, "/...") { root := strings.TrimSuffix(pattern, "/...") - if root == "." { + if root == "." || root == "" { root = "." } + + currentDepth := 0 + rootDepth := strings.Count(filepath.Clean(root), string(filepath.Separator)) + err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error { if err != nil { return err } + + // Check directory depth limit + pathDepth := strings.Count(filepath.Clean(path), string(filepath.Separator)) + currentDepth = pathDepth - rootDepth + if currentDepth > MaxDirectoryDepth { + return filepath.SkipDir + } + if info.IsDir() { // Skip vendor, testdata, etc. base := filepath.Base(path) @@ -220,8 +287,21 @@ func collectGoFiles(pattern string, ignorePaths []string) ([]string, error) { } return nil } + + // Check file count limit + if fileCount >= MaxFilesPerScan { + return fmt.Errorf("exceeded maximum file limit (%d files)", MaxFilesPerScan) + } + + // Check file size limit + if info.Size() > MaxFileSizeBytes { + // Skip oversized files but continue + return nil + } + if strings.HasSuffix(path, ".go") && !strings.HasSuffix(path, "_test.go") { files = append(files, path) + fileCount++ } return nil }) @@ -240,11 +320,27 @@ func collectGoFiles(pattern string, ignorePaths []string) ([]string, error) { return nil, err } for _, entry := range entries { + if fileCount >= MaxFilesPerScan { + return nil, fmt.Errorf("exceeded maximum file limit (%d files)", MaxFilesPerScan) + } + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".go") && !strings.HasSuffix(entry.Name(), "_test.go") { - files = append(files, filepath.Join(pattern, entry.Name())) + filePath := filepath.Join(pattern, entry.Name()) + + // Check file size + if fileInfo, err := entry.Info(); err == nil && fileInfo.Size() > MaxFileSizeBytes { + continue // Skip oversized files + } + + files = append(files, filePath) + fileCount++ } } } else if strings.HasSuffix(pattern, ".go") { + // Check file size + if info.Size() > MaxFileSizeBytes { + return nil, fmt.Errorf("file %q exceeds maximum size limit (%d bytes)", pattern, MaxFileSizeBytes) + } files = append(files, pattern) } diff --git a/rules/context.go b/rules/context.go index d9e2e20..76eda02 100644 --- a/rules/context.go +++ b/rules/context.go @@ -229,17 +229,24 @@ func (r *ContextLeakRule) Check(file *ast.File, fset *token.FileSet, src []byte) }) // Check if cancel functions are called (either directly or via defer) + // Only track calls to identifiers that are in our cancelFuncs map calledCancels := make(map[string]bool) ast.Inspect(funcDecl.Body, func(inner ast.Node) bool { switch node := inner.(type) { case *ast.CallExpr: if ident, ok := node.Fun.(*ast.Ident); ok { - calledCancels[ident.Name] = true + // Only mark as called if it's a known cancel function + if _, isCancel := cancelFuncs[ident.Name]; isCancel { + calledCancels[ident.Name] = true + } } case *ast.DeferStmt: if call, ok := node.Call.Fun.(*ast.Ident); ok { - calledCancels[call.Name] = true + // Only mark as called if it's a known cancel function + if _, isCancel := cancelFuncs[call.Name]; isCancel { + calledCancels[call.Name] = true + } } } return true diff --git a/rules/detection_test.go b/rules/detection_test.go new file mode 100644 index 0000000..50d08eb --- /dev/null +++ b/rules/detection_test.go @@ -0,0 +1,523 @@ +package rules + +import ( + "go/ast" + "go/parser" + "go/token" + "testing" +) + +func TestContextBackgroundInHandler(t *testing.T) { + tests := []struct { + name string + code string + wantHits int + }{ + { + name: "context.Background in HTTP handler", + code: `package main + +import ( + "context" + "net/http" +) + +func handler(w http.ResponseWriter, r *http.Request) { + ctx := context.Background() // should flag + _ = ctx +} +`, + wantHits: 1, + }, + { + name: "context.TODO in HTTP handler", + code: `package main + +import ( + "context" + "net/http" +) + +func handler(w http.ResponseWriter, r *http.Request) { + ctx := context.TODO() // should flag + _ = ctx +} +`, + wantHits: 1, + }, + { + name: "context.Background in non-handler", + code: `package main + +import "context" + +func regularFunc() { + ctx := context.Background() // should NOT flag - not a handler + _ = ctx +} +`, + wantHits: 0, + }, + { + name: "r.Context() in handler - correct pattern", + code: `package main + +import "net/http" + +func handler(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() // correct - should NOT flag + _ = ctx +} +`, + wantHits: 0, + }, + } + + rule := &ContextBackgroundInHandlerRule{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", tt.code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + issues := rule.Check(f, fset, []byte(tt.code)) + if len(issues) != tt.wantHits { + t.Errorf("got %d issues, want %d", len(issues), tt.wantHits) + for _, issue := range issues { + t.Logf(" issue: %s at line %d", issue.Message, issue.Line) + } + } + }) + } +} + +func TestContextLeak(t *testing.T) { + tests := []struct { + name string + code string + wantHits int + }{ + { + name: "uncalled cancel function", + code: `package main + +import "context" + +func foo() { + ctx, cancel := context.WithCancel(context.Background()) + _ = ctx + // cancel is never called - should flag + _ = cancel +} +`, + wantHits: 1, + }, + { + name: "cancel called via defer", + code: `package main + +import "context" + +func foo() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _ = ctx +} +`, + wantHits: 0, + }, + { + name: "cancel called directly", + code: `package main + +import "context" + +func foo() { + ctx, cancel := context.WithCancel(context.Background()) + _ = ctx + cancel() +} +`, + wantHits: 0, + }, + { + name: "cancel assigned to underscore - ignored", + code: `package main + +import "context" + +func foo() { + ctx, _ := context.WithCancel(context.Background()) + _ = ctx +} +`, + wantHits: 0, + }, + { + name: "WithTimeout uncalled", + code: `package main + +import ( + "context" + "time" +) + +func foo() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _ = ctx + // cancel is never called - should flag + _ = cancel +} +`, + wantHits: 1, + }, + { + name: "function with same name as cancel - no false positive", + code: `package main + +import "context" + +func cancel() {} + +func foo() { + ctx, ctxCancel := context.WithCancel(context.Background()) + defer ctxCancel() // calls the correct cancel + _ = ctx + cancel() // different function - should not affect detection +} +`, + wantHits: 0, + }, + } + + rule := &ContextLeakRule{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", tt.code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + issues := rule.Check(f, fset, []byte(tt.code)) + if len(issues) != tt.wantHits { + t.Errorf("got %d issues, want %d", len(issues), tt.wantHits) + for _, issue := range issues { + t.Logf(" issue: %s at line %d", issue.Message, issue.Line) + } + } + }) + } +} + +func TestIsHTTPHandler(t *testing.T) { + tests := []struct { + name string + code string + want bool + }{ + { + name: "standard http handler", + code: `package main +import "net/http" +func handler(w http.ResponseWriter, r *http.Request) {} +`, + want: true, + }, + { + name: "echo handler", + code: `package main +import "github.com/labstack/echo/v4" +func handler(c echo.Context) error { return nil } +`, + want: true, + }, + { + name: "gin handler", + code: `package main +import "github.com/gin-gonic/gin" +func handler(c *gin.Context) {} +`, + want: true, + }, + { + name: "fiber handler", + code: `package main +import "github.com/gofiber/fiber/v2" +func handler(c *fiber.Ctx) error { return nil } +`, + want: true, + }, + { + name: "regular function", + code: `package main +func regularFunc(a int, b string) {} +`, + want: false, + }, + { + name: "function with context param (not http)", + code: `package main +import "context" +func regularFunc(ctx context.Context) {} +`, + want: true, // This will match due to "Context" suffix + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", tt.code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + var funcDecl *ast.FuncDecl + for _, decl := range f.Decls { + if fd, ok := decl.(*ast.FuncDecl); ok { + funcDecl = fd + break + } + } + + if funcDecl == nil { + t.Fatal("no function found") + } + + got := isHTTPHandler(funcDecl) + if got != tt.want { + t.Errorf("isHTTPHandler() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEstimateStructSizes(t *testing.T) { + code := `package main + +type SmallStruct struct { + a int + b int +} + +type LargeStruct struct { + a, b, c, d, e, f, g, h int64 + name string + data []byte +} + +type EmbeddedStruct struct { + SmallStruct + extra int +} +` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + sizes := estimateStructSizes(f) + + if size, ok := sizes["SmallStruct"]; !ok { + t.Error("SmallStruct not found") + } else if size != 16 { // 2 * 8 bytes + t.Errorf("SmallStruct size = %d, want 16", size) + } + + if size, ok := sizes["LargeStruct"]; !ok { + t.Error("LargeStruct not found") + } else if size < 64 { // 8*8 + 16 + 8 = 88 bytes minimum + t.Errorf("LargeStruct size = %d, want >= 64", size) + } +} + +func TestGetTypeName(t *testing.T) { + tests := []struct { + code string + want string + }{ + {`var x int`, "int"}, + {`var x string`, "string"}, + {`var x *int`, "*int"}, + {`var x []string`, "[]string"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + fset := token.NewFileSet() + code := "package main\n" + tt.code + f, err := parser.ParseFile(fset, "test.go", code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + for _, decl := range f.Decls { + if genDecl, ok := decl.(*ast.GenDecl); ok { + for _, spec := range genDecl.Specs { + if valueSpec, ok := spec.(*ast.ValueSpec); ok { + got := getTypeName(valueSpec.Type) + if got != tt.want { + t.Errorf("getTypeName() = %q, want %q", got, tt.want) + } + } + } + } + } + }) + } +} + +func TestItoa(t *testing.T) { + tests := []struct { + input int + want string + }{ + {0, "0"}, + {1, "1"}, + {42, "42"}, + {-1, "-1"}, + {100, "100"}, + {999, "999"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + got := itoa(tt.input) + if got != tt.want { + t.Errorf("itoa(%d) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestPprofInHotPath(t *testing.T) { + tests := []struct { + name string + code string + wantHits int + }{ + { + name: "pprof in loop", + code: `package main + +import "runtime/pprof" + +func foo() { + for i := 0; i < 100; i++ { + pprof.WriteHeapProfile(nil) // should flag + } +} +`, + wantHits: 1, + }, + { + name: "pprof outside loop", + code: `package main + +import "runtime/pprof" + +func foo() { + pprof.WriteHeapProfile(nil) // should NOT flag +} +`, + wantHits: 0, + }, + { + name: "pprof in HTTP handler", + code: `package main + +import ( + "net/http" + "runtime/pprof" +) + +func handler(w http.ResponseWriter, r *http.Request) { + pprof.WriteHeapProfile(nil) // should flag - in handler +} +`, + wantHits: 1, + }, + } + + rule := &PprofInHotPathRule{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", tt.code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + issues := rule.Check(f, fset, []byte(tt.code)) + if len(issues) != tt.wantHits { + t.Errorf("got %d issues, want %d", len(issues), tt.wantHits) + for _, issue := range issues { + t.Logf(" issue: %s at line %d", issue.Message, issue.Line) + } + } + }) + } +} + +func TestMissingContextTimeout(t *testing.T) { + tests := []struct { + name string + code string + wantHits int + }{ + { + name: "http.NewRequest without context", + code: `package main + +import "net/http" + +func foo() { + req, _ := http.NewRequest("GET", "http://example.com", nil) // should flag + _ = req +} +`, + wantHits: 1, + }, + { + name: "http.NewRequestWithContext - correct", + code: `package main + +import ( + "context" + "net/http" +) + +func foo() { + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, "GET", "http://example.com", nil) + _ = req +} +`, + wantHits: 0, + }, + } + + rule := &MissingContextTimeoutRule{} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "test.go", tt.code, 0) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + issues := rule.Check(f, fset, []byte(tt.code)) + if len(issues) != tt.wantHits { + t.Errorf("got %d issues, want %d", len(issues), tt.wantHits) + for _, issue := range issues { + t.Logf(" issue: %s at line %d", issue.Message, issue.Line) + } + } + }) + } +} diff --git a/rules/helpers.go b/rules/helpers.go new file mode 100644 index 0000000..29fe6f9 --- /dev/null +++ b/rules/helpers.go @@ -0,0 +1,196 @@ +package rules + +import ( + "go/ast" + "go/token" + "strings" +) + +// isHTTPHandler checks if a function declaration is an HTTP handler +// by looking for (w http.ResponseWriter, r *http.Request) or similar patterns +func isHTTPHandler(funcDecl *ast.FuncDecl) bool { + if funcDecl.Type.Params == nil || len(funcDecl.Type.Params.List) < 1 { + return false + } + + params := funcDecl.Type.Params.List + + // Check for standard library pattern: (w http.ResponseWriter, r *http.Request) + for _, param := range params { + typeName := getTypeName(param.Type) + if typeName == "ResponseWriter" || typeName == "Request" { + return true + } + } + + // Check for common framework patterns: + // - Echo: func(c echo.Context) + // - Gin: func(c *gin.Context) + // - Chi: func(w http.ResponseWriter, r *http.Request) + // - Fiber: func(c *fiber.Ctx) + for _, param := range params { + typeName := getTypeName(param.Type) + if strings.HasSuffix(typeName, "Context") || strings.HasSuffix(typeName, "Ctx") { + return true + } + } + + return false +} + +// estimateStructSizes estimates the memory size of struct types in a file +func estimateStructSizes(file *ast.File) map[string]int { + sizes := make(map[string]int) + + // Common known types with their typical sizes (64-bit) + knownSizes := map[string]int{ + "string": 16, // ptr + len + "int": 8, + "int64": 8, + "int32": 4, + "int16": 2, + "int8": 1, + "uint": 8, + "uint64": 8, + "uint32": 4, + "uint16": 2, + "uint8": 1, + "float64": 8, + "float32": 4, + "bool": 1, + "byte": 1, + "rune": 4, + "time.Time": 24, + "sync.Mutex": 8, + } + + for _, decl := range file.Decls { + genDecl, ok := decl.(*ast.GenDecl) + if !ok || genDecl.Tok != token.TYPE { + continue + } + + for _, spec := range genDecl.Specs { + typeSpec, ok := spec.(*ast.TypeSpec) + if !ok { + continue + } + + structType, ok := typeSpec.Type.(*ast.StructType) + if !ok { + continue + } + + size := 0 + if structType.Fields != nil { + for _, field := range structType.Fields.List { + fieldSize := 8 // default assumption + fieldTypeName := getTypeName(field.Type) + if known, ok := knownSizes[fieldTypeName]; ok { + fieldSize = known + } + // Multiply by number of names (e.g., "a, b int") + count := len(field.Names) + if count == 0 { + count = 1 // embedded field + } + size += fieldSize * count + } + } + + sizes[typeSpec.Name.Name] = size + } + } + + return sizes +} + +// getTypeName extracts the type name from an AST expression +func getTypeName(expr ast.Expr) string { + switch t := expr.(type) { + case *ast.Ident: + return t.Name + case *ast.SelectorExpr: + return t.Sel.Name + case *ast.StarExpr: + return "*" + getTypeName(t.X) + case *ast.ArrayType: + return "[]" + getTypeName(t.Elt) + } + return "" +} + +// itoa converts an integer to a string (simple implementation without importing strconv) +func itoa(n int) string { + if n == 0 { + return "0" + } + if n < 0 { + return "-" + itoa(-n) + } + s := "" + for n > 0 { + s = string(rune('0'+n%10)) + s + n /= 10 + } + return s +} + +// isInLoop checks if a position is within a loop body +func isInLoop(file *ast.File, fset *token.FileSet, pos token.Pos) bool { + targetLine := fset.Position(pos).Line + inLoop := false + + ast.Inspect(file, func(n ast.Node) bool { + if n == nil { + return true + } + + var loopBody *ast.BlockStmt + switch stmt := n.(type) { + case *ast.RangeStmt: + loopBody = stmt.Body + case *ast.ForStmt: + loopBody = stmt.Body + default: + return true + } + + if loopBody == nil { + return true + } + + startLine := fset.Position(loopBody.Lbrace).Line + endLine := fset.Position(loopBody.Rbrace).Line + + if targetLine >= startLine && targetLine <= endLine { + inLoop = true + return false + } + + return true + }) + + return inLoop +} + +// findFunctionContaining finds the function declaration containing a position +func findFunctionContaining(file *ast.File, fset *token.FileSet, pos token.Pos) *ast.FuncDecl { + targetLine := fset.Position(pos).Line + + for _, decl := range file.Decls { + funcDecl, ok := decl.(*ast.FuncDecl) + if !ok || funcDecl.Body == nil { + continue + } + + startLine := fset.Position(funcDecl.Body.Lbrace).Line + endLine := fset.Position(funcDecl.Body.Rbrace).Line + + if targetLine >= startLine && targetLine <= endLine { + return funcDecl + } + } + + return nil +} diff --git a/rules/memory.go b/rules/memory.go index e3b124b..59d2ddd 100644 --- a/rules/memory.go +++ b/rules/memory.go @@ -282,105 +282,4 @@ func (r *EscapeToHeapRule) Check(file *ast.File, fset *token.FileSet, src []byte return issues } -// Helper functions - -func isHTTPHandler(funcDecl *ast.FuncDecl) bool { - if funcDecl.Type.Params == nil || len(funcDecl.Type.Params.List) < 2 { - return false - } - - params := funcDecl.Type.Params.List - - // Check for (w http.ResponseWriter, r *http.Request) pattern - for _, param := range params { - typeName := getTypeName(param.Type) - if typeName == "ResponseWriter" || typeName == "Request" { - return true - } - } - - return false -} - -func estimateStructSizes(file *ast.File) map[string]int { - sizes := make(map[string]int) - - // Common known types - knownSizes := map[string]int{ - "string": 16, // ptr + len - "int": 8, - "int64": 8, - "int32": 4, - "float64": 8, - "bool": 1, - "byte": 1, - "time.Time": 24, - "sync.Mutex": 8, - } - - for _, decl := range file.Decls { - genDecl, ok := decl.(*ast.GenDecl) - if !ok || genDecl.Tok != token.TYPE { - continue - } - - for _, spec := range genDecl.Specs { - typeSpec, ok := spec.(*ast.TypeSpec) - if !ok { - continue - } - - structType, ok := typeSpec.Type.(*ast.StructType) - if !ok { - continue - } - - size := 0 - if structType.Fields != nil { - for _, field := range structType.Fields.List { - fieldSize := 8 // default assumption - fieldTypeName := getTypeName(field.Type) - if known, ok := knownSizes[fieldTypeName]; ok { - fieldSize = known - } - // Multiply by number of names (e.g., "a, b int") - count := len(field.Names) - if count == 0 { - count = 1 // embedded field - } - size += fieldSize * count - } - } - - sizes[typeSpec.Name.Name] = size - } - } - - return sizes -} - -func getTypeName(expr ast.Expr) string { - switch t := expr.(type) { - case *ast.Ident: - return t.Name - case *ast.SelectorExpr: - return t.Sel.Name - case *ast.StarExpr: - return "*" + getTypeName(t.X) - case *ast.ArrayType: - return "[]" + getTypeName(t.Elt) - } - return "" -} - -func itoa(n int) string { - if n == 0 { - return "0" - } - s := "" - for n > 0 { - s = string(rune('0'+n%10)) + s - n /= 10 - } - return s -} +// Helper functions are now in helpers.go