diff --git a/apps/server/src/wsServer.test.ts b/apps/server/src/wsServer.test.ts index 8e99a65aa..54d55ce04 100644 --- a/apps/server/src/wsServer.test.ts +++ b/apps/server/src/wsServer.test.ts @@ -2142,6 +2142,60 @@ describe("WebSocket Server", () => { ); }); + it("stops a pending git action for the initiating websocket", async () => { + const runStackedAction: GitManagerShape["runStackedAction"] = (input, options) => + Effect.gen(function* () { + if (options?.progressReporter) { + yield* options.progressReporter.publish({ + actionId: options.actionId ?? input.actionId, + cwd: input.cwd, + action: input.action, + kind: "phase_started", + phase: "commit", + label: "Committing...", + }); + } + return yield* Effect.never; + }); + const gitManager: GitManagerShape = { + status: vi.fn(() => Effect.void as any), + resolvePullRequest: vi.fn(() => Effect.void as any), + preparePullRequestThread: vi.fn(() => Effect.void as any), + runStackedAction, + listPullRequests: vi.fn(() => Effect.succeed({ pullRequests: [] })), + }; + + const { cwd } = makeWorkspaceFixture("test"); + server = await createTestServer({ cwd, gitManager }); + const addr = server.address(); + const port = typeof addr === "object" && addr !== null ? addr.port : 0; + + const [ws] = await connectAndAwaitWelcome(port); + connections.push(ws); + + const actionResponsePromise = sendRequest(ws, WS_METHODS.gitRunStackedAction, { + actionId: "client-action-stop", + cwd, + action: "commit", + }); + await waitForPush(ws, WS_CHANNELS.gitActionProgress); + + const stopResponse = await sendRequest(ws, WS_METHODS.gitStopAction, { + cwd, + actionId: "client-action-stop", + }); + + expect(stopResponse.error).toBeUndefined(); + await expect(actionResponsePromise).resolves.toEqual( + expect.objectContaining({ + error: expect.objectContaining({ + code: "git_action_stopped", + message: "Git action stopped.", + }), + }), + ); + }); + it("rejects websocket connections without a valid auth token", async () => { const { cwd } = makeWorkspaceFixture("test"); server = await createTestServer({ cwd, authToken: "secret-token" }); diff --git a/apps/server/src/wsServer.ts b/apps/server/src/wsServer.ts index 1be33c80f..8eafdfc4a 100644 --- a/apps/server/src/wsServer.ts +++ b/apps/server/src/wsServer.ts @@ -38,6 +38,7 @@ import { Effect, Exit, FileSystem, + Fiber, Layer, Path, Ref, @@ -331,6 +332,13 @@ class RouteRequestError extends Schema.TaggedErrorClass()("Ro message: Schema.String, }) {} +class GitActionStoppedError extends Schema.TaggedErrorClass()( + "GitActionStoppedError", + { + message: Schema.String, + }, +) {} + export const createServer = Effect.fn(function* (): Effect.fn.Return< http.Server, ServerLifecycleError, @@ -374,6 +382,89 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return< const clients = yield* Ref.make(new Set()); const logger = createLogger("ws"); const readiness = yield* makeServerReadiness; + type ActiveGitRequestKind = "pull" | "stacked_action"; + type ActiveGitRequestHandle = { + readonly kind: ActiveGitRequestKind; + readonly cwd: string; + readonly actionId: string | null; + readonly fiber: Fiber.Fiber; + }; + const activeGitRequests = new WeakMap>(); + + const registerActiveGitRequest = (ws: WebSocket, handle: ActiveGitRequestHandle) => + Effect.sync(() => { + const handles = activeGitRequests.get(ws) ?? new Set(); + handles.add(handle); + activeGitRequests.set(ws, handles); + }); + + const unregisterActiveGitRequest = (ws: WebSocket, handle: ActiveGitRequestHandle) => + Effect.sync(() => { + const handles = activeGitRequests.get(ws); + if (!handles) { + return; + } + handles.delete(handle); + if (handles.size === 0) { + activeGitRequests.delete(ws); + } + }); + + const interruptActiveGitRequests = (ws: WebSocket) => + Effect.gen(function* () { + const handles = Array.from(activeGitRequests.get(ws) ?? []); + activeGitRequests.delete(ws); + for (const handle of handles) { + yield* Fiber.interrupt(handle.fiber).pipe(Effect.ignore); + } + }); + + const stopActiveGitRequest = ( + ws: WebSocket, + input: { cwd: string; actionId?: string | undefined }, + ) => + Effect.gen(function* () { + const handles = Array.from(activeGitRequests.get(ws) ?? []); + const handle = + input.actionId != null + ? handles.find( + (candidate) => candidate.cwd === input.cwd && candidate.actionId === input.actionId, + ) + : handles.find((candidate) => candidate.cwd === input.cwd); + + if (!handle) { + return; + } + + yield* Fiber.interrupt(handle.fiber); + }); + + const runTrackedGitRequest = ( + ws: WebSocket, + meta: { kind: ActiveGitRequestKind; cwd: string; actionId?: string | undefined }, + effect: Effect.Effect, + interruptedMessage: string, + ): Effect.Effect => + Effect.gen(function* () { + const fiber = yield* Effect.forkScoped(effect); + const handle: ActiveGitRequestHandle = { + kind: meta.kind, + cwd: meta.cwd, + actionId: meta.actionId ?? null, + fiber, + }; + yield* registerActiveGitRequest(ws, handle); + const exit = yield* Fiber.await(fiber).pipe( + Effect.ensuring(unregisterActiveGitRequest(ws, handle)), + ); + if (Exit.isSuccess(exit)) { + return exit.value; + } + if (Cause.hasInterruptsOnly(exit.cause)) { + return yield* new GitActionStoppedError({ message: interruptedMessage }); + } + return yield* Effect.failCause(exit.cause as Cause.Cause); + }) as Effect.Effect; function logOutgoingPush(push: WsPushEnvelopeBase, recipients: number) { if (!logWebSocketEvents) return; @@ -1117,24 +1208,40 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return< const body = stripRequestTag(request.body); const snapshot = yield* projectionReadModelQuery.getSnapshot(); const gitEnv = yield* resolveRuntimeEnvironment({ cwd: body.cwd, readModel: snapshot }); - return yield* git - .syncCurrentBranch(body.cwd) - .pipe(Effect.provideService(RuntimeEnv, gitEnv)); + return yield* runTrackedGitRequest( + ws, + { kind: "pull", cwd: body.cwd }, + git.syncCurrentBranch(body.cwd).pipe(Effect.provideService(RuntimeEnv, gitEnv)), + "Git pull stopped.", + ); + } + + case WS_METHODS.gitStopAction: { + const body = stripRequestTag(request.body); + yield* stopActiveGitRequest(ws, body); + return {}; } case WS_METHODS.gitRunStackedAction: { const body = stripRequestTag(request.body); const snapshot = yield* projectionReadModelQuery.getSnapshot(); const gitEnv = yield* resolveRuntimeEnvironment({ cwd: body.cwd, readModel: snapshot }); - return yield* gitManager - .runStackedAction(body, { - actionId: body.actionId, - progressReporter: { - publish: (event) => - pushBus.publishClient(ws, WS_CHANNELS.gitActionProgress, event).pipe(Effect.asVoid), - }, - }) - .pipe(Effect.provideService(RuntimeEnv, gitEnv)); + return yield* runTrackedGitRequest( + ws, + { kind: "stacked_action", cwd: body.cwd, actionId: body.actionId }, + gitManager + .runStackedAction(body, { + actionId: body.actionId, + progressReporter: { + publish: (event) => + pushBus + .publishClient(ws, WS_CHANNELS.gitActionProgress, event) + .pipe(Effect.asVoid), + }, + }) + .pipe(Effect.provideService(RuntimeEnv, gitEnv)), + "Git action stopped.", + ); } case WS_METHODS.gitResolvePullRequest: { @@ -1702,6 +1809,17 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return< }; } + if ( + (request.body._tag === WS_METHODS.gitRunStackedAction || + request.body._tag === WS_METHODS.gitPull) && + Schema.is(GitActionStoppedError)(squashed) + ) { + return { + message: redactSensitiveText(squashed.message), + code: "git_action_stopped", + }; + } + if (squashed instanceof Error) { return { message: redactSensitiveText(squashed.message) }; } @@ -1798,19 +1916,25 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return< ws.on("close", () => { void runPromise( - Ref.update(clients, (clients) => { - clients.delete(ws); - return clients; - }), + Effect.all([ + interruptActiveGitRequests(ws), + Ref.update(clients, (clients) => { + clients.delete(ws); + return clients; + }), + ]).pipe(Effect.asVoid), ); }); ws.on("error", () => { void runPromise( - Ref.update(clients, (clients) => { - clients.delete(ws); - return clients; - }), + Effect.all([ + interruptActiveGitRequests(ws), + Ref.update(clients, (clients) => { + clients.delete(ws); + return clients; + }), + ]).pipe(Effect.asVoid), ); }); }); diff --git a/apps/web/src/components/BranchToolbar.tsx b/apps/web/src/components/BranchToolbar.tsx index d3be9e54f..b6935fec5 100644 --- a/apps/web/src/components/BranchToolbar.tsx +++ b/apps/web/src/components/BranchToolbar.tsx @@ -5,6 +5,7 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"; import { gitPullMutationOptions, + gitStopActionMutationOptions, gitQueryKeys, gitStatusQueryOptions, invalidateGitQueries, @@ -13,6 +14,7 @@ import { newCommandId } from "../lib/utils"; import { readNativeApi } from "../nativeApi"; import { useComposerDraftStore } from "../composerDraftStore"; import { useStore } from "../store"; +import { isWsRequestError } from "../wsTransport"; import { EnvMode, resolveDraftEnvModeAfterBranchChange, @@ -127,6 +129,7 @@ export default function BranchToolbar({ const isDiverged = aheadCount > 0 && behindCount > 0; const needsSync = behindCount > 0 && !hasServerThread; const pullMutation = useMutation(gitPullMutationOptions({ cwd: gitCwd, queryClient })); + const stopPullMutation = useMutation(gitStopActionMutationOptions({ cwd: gitCwd, queryClient })); // Force a fresh git-status fetch when a draft thread mounts so we catch // upstream changes immediately instead of waiting for the next poll cycle. @@ -164,8 +167,11 @@ export default function BranchToolbar({ }) .catch((error) => { toastManager.add({ - type: "error", - title: "Pull failed", + type: isWsRequestError(error) && error.code === "git_action_stopped" ? "info" : "error", + title: + isWsRequestError(error) && error.code === "git_action_stopped" + ? "Pull stopped" + : "Pull failed", description: error instanceof Error ? error.message : "An error occurred.", }); }) @@ -174,6 +180,17 @@ export default function BranchToolbar({ }); }, [pullMutation, queryClient]); + const handleStopPull = useCallback(() => { + if (!pullMutation.isPending || stopPullMutation.isPending) return; + void stopPullMutation.mutateAsync({}).catch((error) => { + toastManager.add({ + type: "error", + title: "Unable to stop pull", + description: error instanceof Error ? error.message : "An error occurred.", + }); + }); + }, [pullMutation.isPending, stopPullMutation]); + if (!activeThreadId || !activeProject) return null; return ( @@ -259,6 +276,16 @@ export default function BranchToolbar({ ) : null} + {pullMutation.isPending ? ( + + ) : null} 0; const isPullRunning = useIsMutating({ mutationKey: gitMutationKeys.pull(gitCwd) }) > 0; const isGitActionRunning = isRunStackedActionRunning || isPullRunning; + const activeGitActionId = activeGitActionProgressRef.current?.actionId; const isDefaultBranch = useMemo(() => { const branchName = gitStatusForActions?.branch; if (!branchName) return false; @@ -830,8 +836,11 @@ export default function GitActionsControl({ gitCwd, activeThreadId }: GitActions return; } toastManager.update(resolvedProgressToastId, { - type: "error", - title: "Action failed", + type: isWsRequestError(err) && err.code === "git_action_stopped" ? "info" : "error", + title: + isWsRequestError(err) && err.code === "git_action_stopped" + ? "Git action stopped" + : "Action failed", description: err instanceof Error ? err.message : "An error occurred.", data: threadToastData, }); @@ -1078,8 +1087,11 @@ export default function GitActionsControl({ gitCwd, activeThreadId }: GitActions }) .catch((err) => { toastManager.update(loadingToastId, { - type: "error", - title: messages?.errorTitle ?? "Pull failed", + type: isWsRequestError(err) && err.code === "git_action_stopped" ? "info" : "error", + title: + isWsRequestError(err) && err.code === "git_action_stopped" + ? "Pull stopped" + : (messages?.errorTitle ?? "Pull failed"), description: err instanceof Error ? err.message : "An error occurred.", data: threadToastData, }); @@ -1088,6 +1100,22 @@ export default function GitActionsControl({ gitCwd, activeThreadId }: GitActions [pullMutation, threadToastData], ); + const stopPendingGitAction = useCallback(() => { + if (!gitCwd || !isGitActionRunning || stopGitActionMutation.isPending) { + return; + } + void stopGitActionMutation + .mutateAsync(activeGitActionId ? { actionId: activeGitActionId } : {}) + .catch((error) => { + toastManager.add({ + type: "error", + title: "Unable to stop git action", + description: error instanceof Error ? error.message : "An error occurred.", + data: threadToastData, + }); + }); + }, [activeGitActionId, gitCwd, isGitActionRunning, stopGitActionMutation, threadToastData]); + const runSyncAction = useCallback(() => { if (!syncAction || syncAction.disabled) { return; @@ -1234,6 +1262,20 @@ export default function GitActionsControl({ gitCwd, activeThreadId }: GitActions ) : ( + {isGitActionRunning ? ( + <> + + + + ) : null} {syncAction ? ( <> {syncActionDisabledReason ? ( diff --git a/apps/web/src/lib/gitReactQuery.ts b/apps/web/src/lib/gitReactQuery.ts index 13e4c0d15..8711289f4 100644 --- a/apps/web/src/lib/gitReactQuery.ts +++ b/apps/web/src/lib/gitReactQuery.ts @@ -21,6 +21,7 @@ export const gitMutationKeys = { checkout: (cwd: string | null) => ["git", "mutation", "checkout", cwd] as const, runStackedAction: (cwd: string | null) => ["git", "mutation", "run-stacked-action", cwd] as const, pull: (cwd: string | null) => ["git", "mutation", "pull", cwd] as const, + stopAction: (cwd: string | null) => ["git", "mutation", "stop-action", cwd] as const, preparePullRequestThread: (cwd: string | null) => ["git", "mutation", "prepare-pull-request-thread", cwd] as const, }; @@ -202,6 +203,26 @@ export function gitPullMutationOptions(input: { cwd: string | null; queryClient: }); } +export function gitStopActionMutationOptions(input: { + cwd: string | null; + queryClient: QueryClient; +}) { + return mutationOptions({ + mutationKey: gitMutationKeys.stopAction(input.cwd), + mutationFn: async ({ actionId }: { actionId?: string } = {}) => { + const api = ensureNativeApi(); + if (!input.cwd) throw new Error("Stopping git actions is unavailable."); + return api.git.stopAction({ + cwd: input.cwd, + ...(actionId ? { actionId } : {}), + }); + }, + onSettled: async () => { + await invalidateGitQueries(input.queryClient); + }, + }); +} + export function gitCreateWorktreeMutationOptions(input: { queryClient: QueryClient }) { return mutationOptions({ mutationFn: async ({ diff --git a/apps/web/src/wsNativeApi.ts b/apps/web/src/wsNativeApi.ts index cc8897eec..c7fde368a 100644 --- a/apps/web/src/wsNativeApi.ts +++ b/apps/web/src/wsNativeApi.ts @@ -285,6 +285,7 @@ export function createWsNativeApi(): NativeApi { cloneRepository: (input) => transport.request(WS_METHODS.gitCloneRepository, input, { timeoutMs: null }), pull: (input) => transport.request(WS_METHODS.gitPull, input), + stopAction: (input) => transport.request(WS_METHODS.gitStopAction, input), status: (input) => transport.request(WS_METHODS.gitStatus, input), runStackedAction: (input) => transport.request(WS_METHODS.gitRunStackedAction, input, { timeoutMs: null }), diff --git a/packages/contracts/src/git.ts b/packages/contracts/src/git.ts index 2730ed0ea..b9d31ae9a 100644 --- a/packages/contracts/src/git.ts +++ b/packages/contracts/src/git.ts @@ -120,6 +120,12 @@ export const GitPullInput = Schema.Struct({ }); export type GitPullInput = typeof GitPullInput.Type; +export const GitStopActionInput = Schema.Struct({ + cwd: TrimmedNonEmptyStringSchema, + actionId: Schema.optional(TrimmedNonEmptyStringSchema), +}); +export type GitStopActionInput = typeof GitStopActionInput.Type; + export const GitRunStackedActionInput = Schema.Struct({ actionId: TrimmedNonEmptyStringSchema, cwd: TrimmedNonEmptyStringSchema, diff --git a/packages/contracts/src/ipc.ts b/packages/contracts/src/ipc.ts index d533f0f64..cb10d73ed 100644 --- a/packages/contracts/src/ipc.ts +++ b/packages/contracts/src/ipc.ts @@ -22,6 +22,7 @@ import type { GitResolvePullRequestResult, GitRunStackedActionInput, GitRunStackedActionResult, + GitStopActionInput, GitWorktreeCleanupCandidate, GitStatusInput, GitStatusResult, @@ -388,6 +389,7 @@ export interface NativeApi { ) => Promise; // Stacked action API pull: (input: GitPullInput) => Promise; + stopAction: (input: GitStopActionInput) => Promise; status: (input: GitStatusInput) => Promise; runStackedAction: (input: GitRunStackedActionInput) => Promise; onActionProgress: (callback: (event: GitActionProgressEvent) => void) => () => void; diff --git a/packages/contracts/src/ws.ts b/packages/contracts/src/ws.ts index 185d03a03..55a7e760f 100644 --- a/packages/contracts/src/ws.ts +++ b/packages/contracts/src/ws.ts @@ -27,6 +27,7 @@ import { GitPruneWorktreesInput, GitRemoveWorktreeInput, GitRunStackedActionInput, + GitStopActionInput, GitStatusInput, } from "./git"; import { @@ -123,6 +124,7 @@ export const WS_METHODS = { // Git methods gitPull: "git.pull", + gitStopAction: "git.stopAction", gitStatus: "git.status", gitRunStackedAction: "git.runStackedAction", gitListBranches: "git.listBranches", @@ -263,6 +265,7 @@ const WebSocketRequestBody = Schema.Union([ // Git methods tagRequestBody(WS_METHODS.gitPull, GitPullInput), + tagRequestBody(WS_METHODS.gitStopAction, GitStopActionInput), tagRequestBody(WS_METHODS.gitStatus, GitStatusInput), tagRequestBody(WS_METHODS.gitRunStackedAction, GitRunStackedActionInput), tagRequestBody(WS_METHODS.gitListBranches, GitListBranchesInput),