|
| 1 | +package agent |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + "errors" |
| 7 | + "fmt" |
| 8 | + "io" |
| 9 | + "os/exec" |
| 10 | + "strings" |
| 11 | + "time" |
| 12 | + |
| 13 | + acp "github.com/coder/acp-go-sdk" |
| 14 | + "github.com/roborev-dev/roborev/internal/config" |
| 15 | +) |
| 16 | + |
| 17 | +// Security error for path traversal attempts |
| 18 | +var ErrPathTraversal = errors.New("path traversal attempt detected") |
| 19 | + |
| 20 | +const ( |
| 21 | + defaultACPName = "acp" |
| 22 | + defaultACPCommand = "acp-agent" |
| 23 | + defaultACPReadOnlyMode = "plan" |
| 24 | + defaultACPAutoApproveMode = "auto-approve" |
| 25 | + defaultACPTimeoutSeconds = 600 |
| 26 | + maxACPTextFileBytes = 10_000_000 |
| 27 | +) |
| 28 | + |
| 29 | +// ACPAgent runs code reviews using the Agent Client Protocol via acp-go-sdk |
| 30 | +type ACPAgent struct { |
| 31 | + agentName string // Agent name (from configuration) |
| 32 | + Command string // ACP agent command (configured via TOML) |
| 33 | + Args []string // Additional arguments for the agent |
| 34 | + Model string // Model to use |
| 35 | + Mode string // Mode to use |
| 36 | + ReadOnlyMode string |
| 37 | + AutoApproveMode string |
| 38 | + Reasoning ReasoningLevel // Reasoning level |
| 39 | + Agentic bool // Agentic mode |
| 40 | + Timeout time.Duration // Command timeout |
| 41 | + SessionID string // Current ACP session ID |
| 42 | + repoRoot string // Repository root for path validation |
| 43 | +} |
| 44 | + |
| 45 | +func NewACPAgent(command string) *ACPAgent { |
| 46 | + if command == "" { |
| 47 | + command = defaultACPCommand |
| 48 | + } |
| 49 | + |
| 50 | + return &ACPAgent{ |
| 51 | + agentName: defaultACPName, |
| 52 | + Command: command, |
| 53 | + Args: []string{}, |
| 54 | + Model: "", |
| 55 | + Mode: defaultACPReadOnlyMode, |
| 56 | + ReadOnlyMode: defaultACPReadOnlyMode, |
| 57 | + AutoApproveMode: defaultACPAutoApproveMode, |
| 58 | + Timeout: time.Duration(defaultACPTimeoutSeconds) * time.Second, |
| 59 | + Reasoning: ReasoningStandard, |
| 60 | + SessionID: "", // Initialize with empty session ID |
| 61 | + } |
| 62 | +} |
| 63 | + |
| 64 | +func NewACPAgentFromConfig(config *config.ACPAgentConfig) *ACPAgent { |
| 65 | + if config == nil { |
| 66 | + return NewACPAgent("") |
| 67 | + } |
| 68 | + |
| 69 | + agent := NewACPAgent(config.Command) |
| 70 | + if agentName := strings.TrimSpace(config.Name); agentName != "" { |
| 71 | + agent.agentName = agentName |
| 72 | + } |
| 73 | + if len(config.Args) > 0 { |
| 74 | + agent.Args = append([]string(nil), config.Args...) |
| 75 | + } |
| 76 | + if model := strings.TrimSpace(config.Model); model != "" { |
| 77 | + agent.Model = model |
| 78 | + } |
| 79 | + if readOnlyMode := strings.TrimSpace(config.ReadOnlyMode); readOnlyMode != "" { |
| 80 | + agent.ReadOnlyMode = readOnlyMode |
| 81 | + } |
| 82 | + if autoApproveMode := strings.TrimSpace(config.AutoApproveMode); autoApproveMode != "" { |
| 83 | + agent.AutoApproveMode = autoApproveMode |
| 84 | + } |
| 85 | + if config.DisableModeNegotiation { |
| 86 | + agent.Mode = "" |
| 87 | + } else if mode := strings.TrimSpace(config.Mode); mode != "" { |
| 88 | + agent.Mode = mode |
| 89 | + } else { |
| 90 | + agent.Mode = agent.ReadOnlyMode |
| 91 | + } |
| 92 | + |
| 93 | + timeout := time.Duration(defaultACPTimeoutSeconds) * time.Second |
| 94 | + if config.Timeout > 0 { |
| 95 | + timeout = time.Duration(config.Timeout) * time.Second |
| 96 | + } |
| 97 | + agent.Timeout = timeout |
| 98 | + |
| 99 | + return agent |
| 100 | +} |
| 101 | + |
| 102 | +func (a *ACPAgent) Name() string { |
| 103 | + return a.agentName |
| 104 | +} |
| 105 | + |
| 106 | +func (a *ACPAgent) CommandName() string { |
| 107 | + return a.Command |
| 108 | +} |
| 109 | + |
| 110 | +func (a *ACPAgent) CommandLine() string { |
| 111 | + return a.Command + " " + strings.Join(a.Args, " ") |
| 112 | +} |
| 113 | + |
| 114 | +func (a *ACPAgent) WithReasoning(level ReasoningLevel) Agent { |
| 115 | + return &ACPAgent{ |
| 116 | + agentName: a.agentName, |
| 117 | + Command: a.Command, |
| 118 | + Args: a.Args, |
| 119 | + Model: a.Model, |
| 120 | + ReadOnlyMode: a.ReadOnlyMode, |
| 121 | + AutoApproveMode: a.AutoApproveMode, |
| 122 | + Mode: a.Mode, |
| 123 | + Reasoning: level, // Use the provided level parameter |
| 124 | + Agentic: a.Agentic, // Preserve Agentic field |
| 125 | + Timeout: a.Timeout, |
| 126 | + SessionID: a.SessionID, // Preserve SessionID |
| 127 | + } |
| 128 | +} |
| 129 | + |
| 130 | +func (a *ACPAgent) WithAgentic(agentic bool) Agent { |
| 131 | + |
| 132 | + // Set the appropriate mode based on agentic flag |
| 133 | + mode := a.ReadOnlyMode |
| 134 | + if agentic && a.AutoApproveMode != "" { |
| 135 | + mode = a.AutoApproveMode |
| 136 | + } |
| 137 | + if strings.TrimSpace(a.Mode) == "" { |
| 138 | + mode = "" |
| 139 | + } |
| 140 | + |
| 141 | + return &ACPAgent{ |
| 142 | + agentName: a.agentName, |
| 143 | + Command: a.Command, |
| 144 | + Args: a.Args, |
| 145 | + Model: a.Model, |
| 146 | + ReadOnlyMode: a.ReadOnlyMode, |
| 147 | + AutoApproveMode: a.AutoApproveMode, |
| 148 | + Mode: mode, |
| 149 | + Reasoning: a.Reasoning, |
| 150 | + Agentic: agentic, |
| 151 | + Timeout: a.Timeout, |
| 152 | + SessionID: a.SessionID, // Preserve SessionID |
| 153 | + } |
| 154 | +} |
| 155 | + |
| 156 | +func (a *ACPAgent) WithModel(model string) Agent { |
| 157 | + if model == "" { |
| 158 | + return a |
| 159 | + } |
| 160 | + |
| 161 | + return &ACPAgent{ |
| 162 | + agentName: a.agentName, |
| 163 | + Command: a.Command, |
| 164 | + Args: a.Args, |
| 165 | + Model: model, |
| 166 | + ReadOnlyMode: a.ReadOnlyMode, |
| 167 | + AutoApproveMode: a.AutoApproveMode, |
| 168 | + Mode: a.Mode, |
| 169 | + Reasoning: a.Reasoning, |
| 170 | + Agentic: a.Agentic, // Preserve Agentic field |
| 171 | + Timeout: a.Timeout, |
| 172 | + SessionID: a.SessionID, // Preserve SessionID |
| 173 | + } |
| 174 | +} |
| 175 | + |
| 176 | +// Review implements the main review functionality using ACP SDK |
| 177 | +func (a *ACPAgent) Review(ctx context.Context, repoPath, commitSHA, prompt string, output io.Writer) (string, error) { |
| 178 | + // Set timeout context |
| 179 | + var cancel context.CancelFunc |
| 180 | + var err error |
| 181 | + ctx, cancel = context.WithTimeout(ctx, a.Timeout) |
| 182 | + defer cancel() |
| 183 | + |
| 184 | + // Build the command with arguments |
| 185 | + cmd := exec.CommandContext(ctx, a.Command, a.Args...) |
| 186 | + |
| 187 | + // Set up stdio pipes for communication with the agent |
| 188 | + var stdinPipe io.WriteCloser |
| 189 | + var stdoutPipe io.ReadCloser |
| 190 | + var pipesCleanup func() error |
| 191 | + |
| 192 | + // Initialize pipes with proper cleanup |
| 193 | + pipeInit := func() error { |
| 194 | + var err error |
| 195 | + stdinPipe, err = cmd.StdinPipe() |
| 196 | + if err != nil { |
| 197 | + return fmt.Errorf("failed to create stdin pipe: %w", err) |
| 198 | + } |
| 199 | + stdoutPipe, err = cmd.StdoutPipe() |
| 200 | + if err != nil { |
| 201 | + _ = stdinPipe.Close() |
| 202 | + return fmt.Errorf("failed to create stdout pipe: %w", err) |
| 203 | + } |
| 204 | + |
| 205 | + // Set up cleanup function that will be called in reverse order |
| 206 | + pipesCleanup = func() error { |
| 207 | + var pipeErrors []error |
| 208 | + if closeErr := stdoutPipe.Close(); closeErr != nil { |
| 209 | + pipeErrors = append(pipeErrors, closeErr) |
| 210 | + } |
| 211 | + if closeErr := stdinPipe.Close(); closeErr != nil { |
| 212 | + pipeErrors = append(pipeErrors, closeErr) |
| 213 | + } |
| 214 | + if len(pipeErrors) > 0 { |
| 215 | + return fmt.Errorf("pipe cleanup errors: %v", pipeErrors) |
| 216 | + } |
| 217 | + return nil |
| 218 | + } |
| 219 | + return nil |
| 220 | + } |
| 221 | + |
| 222 | + if err := pipeInit(); err != nil { |
| 223 | + return "", err |
| 224 | + } |
| 225 | + |
| 226 | + // Start the agent process |
| 227 | + if err := cmd.Start(); err != nil { |
| 228 | + _ = pipesCleanup() |
| 229 | + return "", fmt.Errorf("failed to start ACP agent: %w", err) |
| 230 | + } |
| 231 | + |
| 232 | + // Defer cleanup in proper order: terminals -> pipes -> process |
| 233 | + // Create a client that handles agent responses |
| 234 | + client := &acpClient{ |
| 235 | + agent: a, |
| 236 | + output: output, |
| 237 | + result: &bytes.Buffer{}, |
| 238 | + sessionID: "", |
| 239 | + repoRoot: repoPath, |
| 240 | + terminals: make(map[string]*acpTerminal), |
| 241 | + nextTerminalID: 1, |
| 242 | + } |
| 243 | + |
| 244 | + // Deferred cleanup to ensure no orphaned terminal processes |
| 245 | + defer func() { |
| 246 | + // Cancel all active terminals first |
| 247 | + client.terminalsMutex.Lock() |
| 248 | + for _, terminal := range client.terminals { |
| 249 | + terminal.cancel() |
| 250 | + } |
| 251 | + client.terminalsMutex.Unlock() |
| 252 | + |
| 253 | + // Then clean up pipes |
| 254 | + if pipesCleanup != nil { |
| 255 | + _ = pipesCleanup() |
| 256 | + } |
| 257 | + |
| 258 | + // Finally clean up process resources |
| 259 | + if cmd.Process != nil { |
| 260 | + if cmd.ProcessState == nil || !cmd.ProcessState.Exited() { |
| 261 | + _ = cmd.Process.Kill() |
| 262 | + } |
| 263 | + _ = cmd.Wait() |
| 264 | + } |
| 265 | + }() |
| 266 | + |
| 267 | + // Create the ACP connection |
| 268 | + conn := acp.NewClientSideConnection(client, stdinPipe, stdoutPipe) |
| 269 | + |
| 270 | + _, err = conn.Initialize(ctx, acp.InitializeRequest{ |
| 271 | + ProtocolVersion: acp.ProtocolVersionNumber, |
| 272 | + ClientCapabilities: acp.ClientCapabilities{ |
| 273 | + Fs: acp.FileSystemCapability{ |
| 274 | + ReadTextFile: true, |
| 275 | + WriteTextFile: true, |
| 276 | + }, |
| 277 | + Terminal: true, |
| 278 | + }, |
| 279 | + }) |
| 280 | + if err != nil { |
| 281 | + // Check process state to provide better error context |
| 282 | + if cmd.ProcessState != nil && cmd.ProcessState.Exited() { |
| 283 | + return "", fmt.Errorf("failed to initialize ACP connection (agent exited with code %d): %w", |
| 284 | + cmd.ProcessState.ExitCode(), err) |
| 285 | + } |
| 286 | + return "", fmt.Errorf("failed to initialize ACP connection: %w", err) |
| 287 | + } |
| 288 | + |
| 289 | + // Create a new session |
| 290 | + sessionResp, err := conn.NewSession(ctx, acp.NewSessionRequest{ |
| 291 | + Cwd: repoPath, |
| 292 | + McpServers: []acp.McpServer{}, |
| 293 | + }) |
| 294 | + if err != nil { |
| 295 | + return "", fmt.Errorf("failed to create session: %w", err) |
| 296 | + } |
| 297 | + |
| 298 | + // Store the session ID for request-scoped validation. |
| 299 | + client.sessionID = string(sessionResp.SessionId) |
| 300 | + |
| 301 | + if a.Mode != "" { |
| 302 | + if err := validateConfiguredMode(a.Mode, sessionResp.Modes); err != nil { |
| 303 | + return "", err |
| 304 | + } |
| 305 | + |
| 306 | + _, err = conn.SetSessionMode(ctx, acp.SetSessionModeRequest{SessionId: sessionResp.SessionId, ModeId: acp.SessionModeId(a.Mode)}) |
| 307 | + if err != nil { |
| 308 | + return "", fmt.Errorf("failed to set session mode: %w", err) |
| 309 | + } |
| 310 | + } |
| 311 | + |
| 312 | + if a.Model != "" { |
| 313 | + if err := validateConfiguredModel(a.Model, sessionResp.Models); err != nil { |
| 314 | + return "", err |
| 315 | + } |
| 316 | + |
| 317 | + _, err = conn.SetSessionModel(ctx, acp.SetSessionModelRequest{SessionId: sessionResp.SessionId, ModelId: acp.ModelId(a.Model)}) |
| 318 | + if err != nil { |
| 319 | + return "", fmt.Errorf("failed to set session model: %w", err) |
| 320 | + } |
| 321 | + } |
| 322 | + |
| 323 | + // Send the prompt request |
| 324 | + promptRequest := acp.PromptRequest{ |
| 325 | + SessionId: sessionResp.SessionId, |
| 326 | + Prompt: []acp.ContentBlock{ |
| 327 | + acp.TextBlock(fmt.Sprintf("Review the code changes in commit %s.\n\nRepository: %s\n\nPrompt: %s", |
| 328 | + commitSHA, repoPath, prompt)), |
| 329 | + }, |
| 330 | + } |
| 331 | + |
| 332 | + promptResponse, err := conn.Prompt(ctx, promptRequest) |
| 333 | + if err != nil { |
| 334 | + return "", fmt.Errorf("failed to send prompt: %w", err) |
| 335 | + } |
| 336 | + |
| 337 | + // Wait for the agent to finish processing |
| 338 | + if promptResponse.StopReason != acp.StopReasonEndTurn { |
| 339 | + return "", fmt.Errorf("agent did not complete processing: %s", promptResponse.StopReason) |
| 340 | + } |
| 341 | + |
| 342 | + return client.resultString(), nil |
| 343 | +} |
0 commit comments