Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 97 additions & 23 deletions fixer/fixer.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"os"
"path/filepath"
"strings"

"github.com/unsaid-dev/goperf/rules"
Expand Down Expand Up @@ -37,6 +37,49 @@ func NewFixer(dryRun, verbose bool) *Fixer {
}
}

// validatePathForWrite ensures the file path is safe for writing
func validatePathForWrite(filename string) error {
// Get current working directory
cwd, err := filepath.Abs(".")
if err != nil {
return fmt.Errorf("failed to get working directory: %w", err)
}

// Get absolute path of target
absPath, err := filepath.Abs(filename)
if err != nil {
return fmt.Errorf("failed to resolve path: %w", err)
}

// Clean the path to resolve any .. components
absPath = filepath.Clean(absPath)

// Check if path is within current working directory
if !strings.HasPrefix(absPath, cwd+string(filepath.Separator)) && absPath != cwd {
return fmt.Errorf("refusing to write to %q: path is outside working directory (security restriction)", filename)
}

// Check for symlinks to prevent TOCTOU attacks
realPath, err := filepath.EvalSymlinks(absPath)
if err == nil && realPath != absPath {
// Path contains symlinks - verify the real path is also within CWD
if !strings.HasPrefix(realPath, cwd+string(filepath.Separator)) && realPath != cwd {
return fmt.Errorf("refusing to write to %q: symlink points outside working directory (security restriction)", filename)
}
}

// Verify it's a regular file (not a device, socket, etc.)
info, err := os.Lstat(absPath)
if err == nil {
// File exists - check it's a regular file
if info.Mode()&os.ModeType != 0 {
return fmt.Errorf("refusing to write to %q: not a regular file", filename)
}
}

return nil
}

// FixIssues attempts to fix the given issues
func (f *Fixer) FixIssues(issues []rules.Issue) []Fix {
var fixes []Fix
Expand All @@ -58,6 +101,14 @@ func (f *Fixer) FixIssues(issues []rules.Issue) []Fix {
func (f *Fixer) fixFile(filename string, issues []rules.Issue) []Fix {
var fixes []Fix

// Validate path before any operations
if err := validatePathForWrite(filename); err != nil {
if f.Verbose {
fmt.Fprintf(os.Stderr, "Skipping %s: %v\n", filename, err)
}
return fixes
}

src, err := os.ReadFile(filename)
if err != nil {
return fixes
Expand All @@ -69,26 +120,19 @@ func (f *Fixer) fixFile(filename string, issues []rules.Issue) []Fix {
return fixes
}

modified := false
lines := strings.Split(string(src), "\n")

for _, issue := range issues {
fix := f.attemptFix(issue, astFile, fset, lines)
if fix != nil {
fixes = append(fixes, *fix)
if !f.DryRun && fix.Fixed != "" {
modified = true
}
}
}

if modified && !f.DryRun {
// Format and write back
var buf bytes.Buffer
if err := format.Node(&buf, fset, astFile); err == nil {
os.WriteFile(filename, buf.Bytes(), 0644)
}
}
// NOTE: Auto-fix currently only generates suggestions.
// Actual AST modification and file writing is not implemented
// because safe AST modification is complex and error-prone.
// The fixes are displayed as suggestions for manual application.

return fixes
}
Expand Down Expand Up @@ -151,13 +195,16 @@ func (f *Fixer) fixUnpreallocatedSlice(issue rules.Issue, file *ast.File, fset *
Applied: false,
}

// Extract slice name from message
// Extract slice name from message safely
msg := issue.Message
start := strings.Index(msg, "'")
end := strings.LastIndex(msg, "'")
if start >= 0 && end > start {
if start >= 0 && end > start && start+1 < len(msg) {
sliceName := msg[start+1 : end]
fix.Fixed = fmt.Sprintf("%s = make([]T, 0, expectedSize) // Preallocate %s", sliceName, sliceName)
// Validate the extracted name looks like an identifier
if isValidIdentifier(sliceName) {
fix.Fixed = fmt.Sprintf("%s = make([]T, 0, expectedSize) // Preallocate %s", sliceName, sliceName)
}
}

return fix
Expand All @@ -167,14 +214,17 @@ func (f *Fixer) fixMissingBodyClose(issue rules.Issue, file *ast.File, fset *tok
line := issue.Line
original := getLine(lines, line)

// Find the variable name from the message
// Find the variable name from the message safely
msg := issue.Message
start := strings.Index(msg, "'")
end := strings.LastIndex(msg, "'")

varName := "resp"
if start >= 0 && end > start {
varName = msg[start+1 : end]
if start >= 0 && end > start && start+1 < len(msg) {
extracted := msg[start+1 : end]
if isValidIdentifier(extracted) {
varName = extracted
}
}

fix := &Fix{
Expand All @@ -193,14 +243,17 @@ func (f *Fixer) fixContextLeak(issue rules.Issue, file *ast.File, fset *token.Fi
line := issue.Line
original := getLine(lines, line)

// Extract cancel function name from message
// Extract cancel function name from message safely
msg := issue.Message
start := strings.Index(msg, "'")
end := strings.LastIndex(msg, "'")

cancelName := "cancel"
if start >= 0 && end > start {
cancelName = msg[start+1 : end]
if start >= 0 && end > start && start+1 < len(msg) {
extracted := msg[start+1 : end]
if isValidIdentifier(extracted) {
cancelName = extracted
}
}

fix := &Fix{
Expand All @@ -215,6 +268,25 @@ func (f *Fixer) fixContextLeak(issue rules.Issue, file *ast.File, fset *token.Fi
return fix
}

// isValidIdentifier checks if a string looks like a valid Go identifier
func isValidIdentifier(s string) bool {
if len(s) == 0 || len(s) > 100 {
return false
}
for i, c := range s {
if i == 0 {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_') {
return false
}
} else {
if !((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_') {
return false
}
}
}
return true
}

func getLine(lines []string, lineNum int) string {
if lineNum > 0 && lineNum <= len(lines) {
return lines[lineNum-1]
Expand All @@ -230,9 +302,11 @@ func PrintFixes(fixes []Fix, dryRun bool) {
}

if dryRun {
fmt.Println("=== DRY RUN: Suggested fixes (no files modified) ===\n")
fmt.Println("=== DRY RUN: Suggested fixes (no files modified) ===")
fmt.Println()
} else {
fmt.Println("=== Applied fixes ===\n")
fmt.Println("=== Suggested fixes ===")
fmt.Println()
}

for _, fix := range fixes {
Expand Down
Loading
Loading