diff --git a/src/core/checkpoints/index.ts b/src/core/checkpoints/index.ts index 26a137b939..0f99344e74 100644 --- a/src/core/checkpoints/index.ts +++ b/src/core/checkpoints/index.ts @@ -4,8 +4,6 @@ import * as vscode from "vscode" import type { ClineApiReqInfo } from "@roo-code/types" import { TelemetryService } from "@roo-code/telemetry" -import { Task } from "../task/Task" - import { getWorkspacePath } from "../../utils/path" import { checkGitInstalled } from "../../utils/git" import { t } from "../../i18n" @@ -18,14 +16,97 @@ import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../se const WARNING_THRESHOLD_MS = 5000 -function sendCheckpointInitWarn(task: Task, type?: "WAIT_TIMEOUT" | "INIT_TIMEOUT", timeout?: number) { +export type CheckpointTaskContext = { + taskId: string + cwd?: string + providerRef: WeakRef + clineMessages: any[] + messageManager: { + rewindToTimestamp: (ts: number, options: { includeTargetMessage: boolean }) => Promise + } + say: (...args: any[]) => Promise + combineMessages: (messages: any[]) => any +} + +export type CheckpointRuntime = CheckpointTaskContext & { + enableCheckpoints: boolean + checkpointTimeout: number + checkpointService?: RepoPerTaskCheckpointService + checkpointServiceInitializing: boolean +} + +export type CheckpointManagerOptions = { + enableCheckpoints: boolean + checkpointTimeout: number +} + +export class CheckpointManager implements CheckpointRuntime { + public enableCheckpoints: boolean + public checkpointTimeout: number + public checkpointService?: RepoPerTaskCheckpointService + public checkpointServiceInitializing = false + + constructor( + private readonly task: CheckpointTaskContext, + { enableCheckpoints, checkpointTimeout }: CheckpointManagerOptions, + ) { + this.enableCheckpoints = enableCheckpoints + this.checkpointTimeout = checkpointTimeout + } + + get taskId() { + return this.task.taskId + } + + get cwd() { + return this.task.cwd + } + + get providerRef() { + return this.task.providerRef + } + + get clineMessages() { + return this.task.clineMessages + } + + get messageManager() { + return this.task.messageManager + } + + public say(...args: any[]) { + return this.task.say(...args) + } + + public combineMessages(messages: any[]) { + return this.task.combineMessages(messages) + } + + public getService(options?: { interval?: number }) { + return getCheckpointService(this, options) + } + + public save(force = false, suppressMessage = false) { + return checkpointSave(this, force, suppressMessage) + } + + public restore(options: CheckpointRestoreOptions) { + return checkpointRestore(this, options) + } + + public diff(options: CheckpointDiffOptions) { + return checkpointDiff(this, options) + } +} + +function sendCheckpointInitWarn(task: CheckpointRuntime, type?: "WAIT_TIMEOUT" | "INIT_TIMEOUT", timeout?: number) { task.providerRef.deref()?.postMessageToWebview({ type: "checkpointInitWarning", checkpointWarning: type && timeout ? { type, timeout } : undefined, }) } -export async function getCheckpointService(task: Task, { interval = 250 }: { interval?: number } = {}) { +export async function getCheckpointService(task: CheckpointRuntime, { interval = 250 }: { interval?: number } = {}) { if (!task.enableCheckpoints) { return undefined } @@ -130,7 +211,7 @@ export async function getCheckpointService(task: Task, { interval = 250 }: { int } async function checkGitInstallation( - task: Task, + task: CheckpointRuntime, service: RepoPerTaskCheckpointService, log: (message: string) => void, provider: any, @@ -209,7 +290,7 @@ async function checkGitInstallation( } } -export async function checkpointSave(task: Task, force = false, suppressMessage = false) { +export async function checkpointSave(task: CheckpointRuntime, force = false, suppressMessage = false) { const service = await getCheckpointService(task) if (!service) { @@ -235,7 +316,7 @@ export type CheckpointRestoreOptions = { } export async function checkpointRestore( - task: Task, + task: CheckpointRuntime, { ts, commitHash, mode, operation = "delete" }: CheckpointRestoreOptions, ) { const service = await getCheckpointService(task) @@ -314,7 +395,10 @@ export type CheckpointDiffOptions = { mode: "from-init" | "checkpoint" | "to-current" | "full" } -export async function checkpointDiff(task: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) { +export async function checkpointDiff( + task: CheckpointRuntime, + { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions, +) { const service = await getCheckpointService(task) if (!service) { diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 26b7295729..3ae49d5358 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -76,7 +76,6 @@ import { getModelMaxOutputTokens } from "../../shared/api" // services import { McpHub } from "../../services/mcp/McpHub" import { McpServerManager } from "../../services/mcp/McpServerManager" -import { RepoPerTaskCheckpointService } from "../../services/checkpoints" // integrations import { DiffViewProvider } from "../../integrations/editor/DiffViewProvider" @@ -118,12 +117,9 @@ import { import { getEnvironmentDetails } from "../environment/getEnvironmentDetails" import { checkContextWindowExceededError } from "../context/context-management/context-error-handling" import { + CheckpointManager, type CheckpointDiffOptions, type CheckpointRestoreOptions, - getCheckpointService, - checkpointSave, - checkpointRestore, - checkpointDiff, } from "../checkpoints" import { processUserContentMentions } from "../mentions/processUserContentMentions" import { getMessagesSinceLastSummary, summarizeConversation, getEffectiveApiHistory } from "../condense" @@ -326,10 +322,7 @@ export class Task extends EventEmitter implements TaskLike { toolUsage: ToolUsage = {} // Checkpoints - enableCheckpoints: boolean - checkpointTimeout: number - checkpointService?: RepoPerTaskCheckpointService - checkpointServiceInitializing = false + private readonly checkpointManager: CheckpointManager // Message Queue Service public readonly messageQueueService: MessageQueueService @@ -491,8 +484,7 @@ export class Task extends EventEmitter implements TaskLike { this.providerRef = new WeakRef(provider) this.globalStoragePath = provider.context.globalStorageUri.fsPath this.diffViewProvider = new DiffViewProvider(this.cwd, this) - this.enableCheckpoints = enableCheckpoints - this.checkpointTimeout = checkpointTimeout + this.checkpointManager = new CheckpointManager(this, { enableCheckpoints, checkpointTimeout }) this.parentTask = parentTask this.taskNumber = taskNumber @@ -2471,7 +2463,7 @@ export class Task extends EventEmitter implements TaskLike { private async initiateTaskLoop(userContent: Anthropic.Messages.ContentBlockParam[]): Promise { // Kicks off the checkpoints initialization process in the background. - getCheckpointService(this) + this.checkpointManager.getService() let nextUserContent = userContent let includeFileDetails = true @@ -4482,8 +4474,40 @@ export class Task extends EventEmitter implements TaskLike { // Checkpoints + public get enableCheckpoints() { + return this.checkpointManager.enableCheckpoints + } + + public set enableCheckpoints(value: boolean) { + this.checkpointManager.enableCheckpoints = value + } + + public get checkpointTimeout() { + return this.checkpointManager.checkpointTimeout + } + + public set checkpointTimeout(value: number) { + this.checkpointManager.checkpointTimeout = value + } + + public get checkpointService(): CheckpointManager["checkpointService"] { + return this.checkpointManager.checkpointService + } + + public set checkpointService(value: CheckpointManager["checkpointService"]) { + this.checkpointManager.checkpointService = value + } + + public get checkpointServiceInitializing() { + return this.checkpointManager.checkpointServiceInitializing + } + + public set checkpointServiceInitializing(value: boolean) { + this.checkpointManager.checkpointServiceInitializing = value + } + public async checkpointSave(force: boolean = false, suppressMessage: boolean = false) { - return checkpointSave(this, force, suppressMessage) + return this.checkpointManager.save(force, suppressMessage) } private buildCleanConversationHistory( @@ -4628,11 +4652,11 @@ export class Task extends EventEmitter implements TaskLike { return cleanConversationHistory } public async checkpointRestore(options: CheckpointRestoreOptions) { - return checkpointRestore(this, options) + return this.checkpointManager.restore(options) } public async checkpointDiff(options: CheckpointDiffOptions) { - return checkpointDiff(this, options) + return this.checkpointManager.diff(options) } // Metrics