Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 92 additions & 8 deletions src/core/checkpoints/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ import * as vscode from "vscode"
import type { ClineApiReqInfo } from "@roo-code/types"
import { TelemetryService } from "@roo-code/telemetry"

import { Task } from "../task/Task"

import { getWorkspacePath } from "../../utils/path"
import { checkGitInstalled } from "../../utils/git"
import { t } from "../../i18n"
Expand All @@ -18,14 +16,97 @@ import { CheckpointServiceOptions, RepoPerTaskCheckpointService } from "../../se

const WARNING_THRESHOLD_MS = 5000

function sendCheckpointInitWarn(task: Task, type?: "WAIT_TIMEOUT" | "INIT_TIMEOUT", timeout?: number) {
export type CheckpointTaskContext = {
taskId: string
cwd?: string
providerRef: WeakRef<any>
clineMessages: any[]
messageManager: {
rewindToTimestamp: (ts: number, options: { includeTargetMessage: boolean }) => Promise<void>
}
say: (...args: any[]) => Promise<any>
combineMessages: (messages: any[]) => any
}

export type CheckpointRuntime = CheckpointTaskContext & {
enableCheckpoints: boolean
checkpointTimeout: number
checkpointService?: RepoPerTaskCheckpointService
checkpointServiceInitializing: boolean
}

export type CheckpointManagerOptions = {
enableCheckpoints: boolean
checkpointTimeout: number
}

export class CheckpointManager implements CheckpointRuntime {
public enableCheckpoints: boolean
public checkpointTimeout: number
public checkpointService?: RepoPerTaskCheckpointService
public checkpointServiceInitializing = false

constructor(
private readonly task: CheckpointTaskContext,
{ enableCheckpoints, checkpointTimeout }: CheckpointManagerOptions,
) {
this.enableCheckpoints = enableCheckpoints
this.checkpointTimeout = checkpointTimeout
}

get taskId() {
return this.task.taskId
}

get cwd() {
return this.task.cwd
}

get providerRef() {
return this.task.providerRef
}

get clineMessages() {
return this.task.clineMessages
}

get messageManager() {
return this.task.messageManager
}

public say(...args: any[]) {
return this.task.say(...args)
}

public combineMessages(messages: any[]) {
return this.task.combineMessages(messages)
}

public getService(options?: { interval?: number }) {
return getCheckpointService(this, options)
}

public save(force = false, suppressMessage = false) {
return checkpointSave(this, force, suppressMessage)
}

public restore(options: CheckpointRestoreOptions) {
return checkpointRestore(this, options)
}

public diff(options: CheckpointDiffOptions) {
return checkpointDiff(this, options)
}
}

function sendCheckpointInitWarn(task: CheckpointRuntime, type?: "WAIT_TIMEOUT" | "INIT_TIMEOUT", timeout?: number) {
task.providerRef.deref()?.postMessageToWebview({
type: "checkpointInitWarning",
checkpointWarning: type && timeout ? { type, timeout } : undefined,
})
}

export async function getCheckpointService(task: Task, { interval = 250 }: { interval?: number } = {}) {
export async function getCheckpointService(task: CheckpointRuntime, { interval = 250 }: { interval?: number } = {}) {
if (!task.enableCheckpoints) {
return undefined
}
Expand Down Expand Up @@ -130,7 +211,7 @@ export async function getCheckpointService(task: Task, { interval = 250 }: { int
}

async function checkGitInstallation(
task: Task,
task: CheckpointRuntime,
service: RepoPerTaskCheckpointService,
log: (message: string) => void,
provider: any,
Expand Down Expand Up @@ -209,7 +290,7 @@ async function checkGitInstallation(
}
}

export async function checkpointSave(task: Task, force = false, suppressMessage = false) {
export async function checkpointSave(task: CheckpointRuntime, force = false, suppressMessage = false) {
const service = await getCheckpointService(task)

if (!service) {
Expand All @@ -235,7 +316,7 @@ export type CheckpointRestoreOptions = {
}

export async function checkpointRestore(
task: Task,
task: CheckpointRuntime,
{ ts, commitHash, mode, operation = "delete" }: CheckpointRestoreOptions,
) {
const service = await getCheckpointService(task)
Expand Down Expand Up @@ -314,7 +395,10 @@ export type CheckpointDiffOptions = {
mode: "from-init" | "checkpoint" | "to-current" | "full"
}

export async function checkpointDiff(task: Task, { ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions) {
export async function checkpointDiff(
task: CheckpointRuntime,
{ ts, previousCommitHash, commitHash, mode }: CheckpointDiffOptions,
) {
const service = await getCheckpointService(task)

if (!service) {
Expand Down
54 changes: 39 additions & 15 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ import { getModelMaxOutputTokens } from "../../shared/api"
// services
import { McpHub } from "../../services/mcp/McpHub"
import { McpServerManager } from "../../services/mcp/McpServerManager"
import { RepoPerTaskCheckpointService } from "../../services/checkpoints"

// integrations
import { DiffViewProvider } from "../../integrations/editor/DiffViewProvider"
Expand Down Expand Up @@ -118,12 +117,9 @@ import {
import { getEnvironmentDetails } from "../environment/getEnvironmentDetails"
import { checkContextWindowExceededError } from "../context/context-management/context-error-handling"
import {
CheckpointManager,
type CheckpointDiffOptions,
type CheckpointRestoreOptions,
getCheckpointService,
checkpointSave,
checkpointRestore,
checkpointDiff,
} from "../checkpoints"
import { processUserContentMentions } from "../mentions/processUserContentMentions"
import { getMessagesSinceLastSummary, summarizeConversation, getEffectiveApiHistory } from "../condense"
Expand Down Expand Up @@ -326,10 +322,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
toolUsage: ToolUsage = {}

// Checkpoints
enableCheckpoints: boolean
checkpointTimeout: number
checkpointService?: RepoPerTaskCheckpointService
checkpointServiceInitializing = false
private readonly checkpointManager: CheckpointManager

// Message Queue Service
public readonly messageQueueService: MessageQueueService
Expand Down Expand Up @@ -491,8 +484,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
this.providerRef = new WeakRef(provider)
this.globalStoragePath = provider.context.globalStorageUri.fsPath
this.diffViewProvider = new DiffViewProvider(this.cwd, this)
this.enableCheckpoints = enableCheckpoints
this.checkpointTimeout = checkpointTimeout
this.checkpointManager = new CheckpointManager(this, { enableCheckpoints, checkpointTimeout })

this.parentTask = parentTask
this.taskNumber = taskNumber
Expand Down Expand Up @@ -2471,7 +2463,7 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {

private async initiateTaskLoop(userContent: Anthropic.Messages.ContentBlockParam[]): Promise<void> {
// Kicks off the checkpoints initialization process in the background.
getCheckpointService(this)
this.checkpointManager.getService()

let nextUserContent = userContent
let includeFileDetails = true
Expand Down Expand Up @@ -4482,8 +4474,40 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {

// Checkpoints

public get enableCheckpoints() {
return this.checkpointManager.enableCheckpoints
}

public set enableCheckpoints(value: boolean) {
this.checkpointManager.enableCheckpoints = value
}

public get checkpointTimeout() {
return this.checkpointManager.checkpointTimeout
}

public set checkpointTimeout(value: number) {
this.checkpointManager.checkpointTimeout = value
}

public get checkpointService(): CheckpointManager["checkpointService"] {
return this.checkpointManager.checkpointService
}

public set checkpointService(value: CheckpointManager["checkpointService"]) {
this.checkpointManager.checkpointService = value
}

public get checkpointServiceInitializing() {
return this.checkpointManager.checkpointServiceInitializing
}

public set checkpointServiceInitializing(value: boolean) {
this.checkpointManager.checkpointServiceInitializing = value
}

public async checkpointSave(force: boolean = false, suppressMessage: boolean = false) {
return checkpointSave(this, force, suppressMessage)
return this.checkpointManager.save(force, suppressMessage)
}

private buildCleanConversationHistory(
Expand Down Expand Up @@ -4628,11 +4652,11 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
return cleanConversationHistory
}
public async checkpointRestore(options: CheckpointRestoreOptions) {
return checkpointRestore(this, options)
return this.checkpointManager.restore(options)
}

public async checkpointDiff(options: CheckpointDiffOptions) {
return checkpointDiff(this, options)
return this.checkpointManager.diff(options)
}

// Metrics
Expand Down
Loading