Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions apps/server/src/wsServer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2142,6 +2142,60 @@ describe("WebSocket Server", () => {
);
});

it("stops a pending git action for the initiating websocket", async () => {
const runStackedAction: GitManagerShape["runStackedAction"] = (input, options) =>
Effect.gen(function* () {
if (options?.progressReporter) {
yield* options.progressReporter.publish({
actionId: options.actionId ?? input.actionId,
cwd: input.cwd,
action: input.action,
kind: "phase_started",
phase: "commit",
label: "Committing...",
});
}
return yield* Effect.never;
});
const gitManager: GitManagerShape = {
status: vi.fn(() => Effect.void as any),
resolvePullRequest: vi.fn(() => Effect.void as any),
preparePullRequestThread: vi.fn(() => Effect.void as any),
runStackedAction,
listPullRequests: vi.fn(() => Effect.succeed({ pullRequests: [] })),
};

const { cwd } = makeWorkspaceFixture("test");
server = await createTestServer({ cwd, gitManager });
const addr = server.address();
const port = typeof addr === "object" && addr !== null ? addr.port : 0;

const [ws] = await connectAndAwaitWelcome(port);
connections.push(ws);

const actionResponsePromise = sendRequest(ws, WS_METHODS.gitRunStackedAction, {
actionId: "client-action-stop",
cwd,
action: "commit",
});
await waitForPush(ws, WS_CHANNELS.gitActionProgress);

const stopResponse = await sendRequest(ws, WS_METHODS.gitStopAction, {
cwd,
actionId: "client-action-stop",
});

expect(stopResponse.error).toBeUndefined();
await expect(actionResponsePromise).resolves.toEqual(
expect.objectContaining({
error: expect.objectContaining({
code: "git_action_stopped",
message: "Git action stopped.",
}),
}),
);
});

it("rejects websocket connections without a valid auth token", async () => {
const { cwd } = makeWorkspaceFixture("test");
server = await createTestServer({ cwd, authToken: "secret-token" });
Expand Down
164 changes: 144 additions & 20 deletions apps/server/src/wsServer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import {
Effect,
Exit,
FileSystem,
Fiber,
Layer,
Path,
Ref,
Expand Down Expand Up @@ -331,6 +332,13 @@ class RouteRequestError extends Schema.TaggedErrorClass<RouteRequestError>()("Ro
message: Schema.String,
}) {}

class GitActionStoppedError extends Schema.TaggedErrorClass<GitActionStoppedError>()(
"GitActionStoppedError",
{
message: Schema.String,
},
) {}

export const createServer = Effect.fn(function* (): Effect.fn.Return<
http.Server,
ServerLifecycleError,
Expand Down Expand Up @@ -374,6 +382,89 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return<
const clients = yield* Ref.make(new Set<WebSocket>());
const logger = createLogger("ws");
const readiness = yield* makeServerReadiness;
type ActiveGitRequestKind = "pull" | "stacked_action";
type ActiveGitRequestHandle = {
readonly kind: ActiveGitRequestKind;
readonly cwd: string;
readonly actionId: string | null;
readonly fiber: Fiber.Fiber<unknown, unknown>;
};
const activeGitRequests = new WeakMap<WebSocket, Set<ActiveGitRequestHandle>>();

const registerActiveGitRequest = (ws: WebSocket, handle: ActiveGitRequestHandle) =>
Effect.sync(() => {
const handles = activeGitRequests.get(ws) ?? new Set<ActiveGitRequestHandle>();
handles.add(handle);
activeGitRequests.set(ws, handles);
});

const unregisterActiveGitRequest = (ws: WebSocket, handle: ActiveGitRequestHandle) =>
Effect.sync(() => {
const handles = activeGitRequests.get(ws);
if (!handles) {
return;
}
handles.delete(handle);
if (handles.size === 0) {
activeGitRequests.delete(ws);
}
});

const interruptActiveGitRequests = (ws: WebSocket) =>
Effect.gen(function* () {
const handles = Array.from(activeGitRequests.get(ws) ?? []);
activeGitRequests.delete(ws);
for (const handle of handles) {
yield* Fiber.interrupt(handle.fiber).pipe(Effect.ignore);
}
});

const stopActiveGitRequest = (
ws: WebSocket,
input: { cwd: string; actionId?: string | undefined },
) =>
Effect.gen(function* () {
const handles = Array.from(activeGitRequests.get(ws) ?? []);
const handle =
input.actionId != null
? handles.find(
(candidate) => candidate.cwd === input.cwd && candidate.actionId === input.actionId,
)
: handles.find((candidate) => candidate.cwd === input.cwd);

if (!handle) {
return;
}

yield* Fiber.interrupt(handle.fiber);
});

const runTrackedGitRequest = <A, E>(
ws: WebSocket,
meta: { kind: ActiveGitRequestKind; cwd: string; actionId?: string | undefined },
effect: Effect.Effect<A, E, never>,
interruptedMessage: string,
): Effect.Effect<A, E | GitActionStoppedError> =>
Effect.gen(function* () {
const fiber = yield* Effect.forkScoped(effect);
const handle: ActiveGitRequestHandle = {
kind: meta.kind,
cwd: meta.cwd,
actionId: meta.actionId ?? null,
fiber,
};
yield* registerActiveGitRequest(ws, handle);
const exit = yield* Fiber.await(fiber).pipe(
Effect.ensuring(unregisterActiveGitRequest(ws, handle)),
);
if (Exit.isSuccess(exit)) {
return exit.value;
}
if (Cause.hasInterruptsOnly(exit.cause)) {
return yield* new GitActionStoppedError({ message: interruptedMessage });
}
return yield* Effect.failCause(exit.cause as Cause.Cause<E>);
}) as Effect.Effect<A, E | GitActionStoppedError, never>;

function logOutgoingPush(push: WsPushEnvelopeBase, recipients: number) {
if (!logWebSocketEvents) return;
Expand Down Expand Up @@ -1117,24 +1208,40 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return<
const body = stripRequestTag(request.body);
const snapshot = yield* projectionReadModelQuery.getSnapshot();
const gitEnv = yield* resolveRuntimeEnvironment({ cwd: body.cwd, readModel: snapshot });
return yield* git
.syncCurrentBranch(body.cwd)
.pipe(Effect.provideService(RuntimeEnv, gitEnv));
return yield* runTrackedGitRequest(
ws,
{ kind: "pull", cwd: body.cwd },
git.syncCurrentBranch(body.cwd).pipe(Effect.provideService(RuntimeEnv, gitEnv)),
"Git pull stopped.",
);
}

case WS_METHODS.gitStopAction: {
const body = stripRequestTag(request.body);
yield* stopActiveGitRequest(ws, body);
return {};
}

case WS_METHODS.gitRunStackedAction: {
const body = stripRequestTag(request.body);
const snapshot = yield* projectionReadModelQuery.getSnapshot();
const gitEnv = yield* resolveRuntimeEnvironment({ cwd: body.cwd, readModel: snapshot });
return yield* gitManager
.runStackedAction(body, {
actionId: body.actionId,
progressReporter: {
publish: (event) =>
pushBus.publishClient(ws, WS_CHANNELS.gitActionProgress, event).pipe(Effect.asVoid),
},
})
.pipe(Effect.provideService(RuntimeEnv, gitEnv));
return yield* runTrackedGitRequest(
ws,
{ kind: "stacked_action", cwd: body.cwd, actionId: body.actionId },
gitManager
.runStackedAction(body, {
actionId: body.actionId,
progressReporter: {
publish: (event) =>
pushBus
.publishClient(ws, WS_CHANNELS.gitActionProgress, event)
.pipe(Effect.asVoid),
},
})
.pipe(Effect.provideService(RuntimeEnv, gitEnv)),
"Git action stopped.",
);
}

case WS_METHODS.gitResolvePullRequest: {
Expand Down Expand Up @@ -1702,6 +1809,17 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return<
};
}

if (
(request.body._tag === WS_METHODS.gitRunStackedAction ||
request.body._tag === WS_METHODS.gitPull) &&
Schema.is(GitActionStoppedError)(squashed)
) {
return {
message: redactSensitiveText(squashed.message),
code: "git_action_stopped",
};
}

if (squashed instanceof Error) {
return { message: redactSensitiveText(squashed.message) };
}
Expand Down Expand Up @@ -1798,19 +1916,25 @@ export const createServer = Effect.fn(function* (): Effect.fn.Return<

ws.on("close", () => {
void runPromise(
Ref.update(clients, (clients) => {
clients.delete(ws);
return clients;
}),
Effect.all([
interruptActiveGitRequests(ws),
Ref.update(clients, (clients) => {
clients.delete(ws);
return clients;
}),
]).pipe(Effect.asVoid),
);
});

ws.on("error", () => {
void runPromise(
Ref.update(clients, (clients) => {
clients.delete(ws);
return clients;
}),
Effect.all([
interruptActiveGitRequests(ws),
Ref.update(clients, (clients) => {
clients.delete(ws);
return clients;
}),
]).pipe(Effect.asVoid),
);
});
});
Expand Down
31 changes: 29 additions & 2 deletions apps/web/src/components/BranchToolbar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";

import {
gitPullMutationOptions,
gitStopActionMutationOptions,
gitQueryKeys,
gitStatusQueryOptions,
invalidateGitQueries,
Expand All @@ -13,6 +14,7 @@ import { newCommandId } from "../lib/utils";
import { readNativeApi } from "../nativeApi";
import { useComposerDraftStore } from "../composerDraftStore";
import { useStore } from "../store";
import { isWsRequestError } from "../wsTransport";
import {
EnvMode,
resolveDraftEnvModeAfterBranchChange,
Expand Down Expand Up @@ -127,6 +129,7 @@ export default function BranchToolbar({
const isDiverged = aheadCount > 0 && behindCount > 0;
const needsSync = behindCount > 0 && !hasServerThread;
const pullMutation = useMutation(gitPullMutationOptions({ cwd: gitCwd, queryClient }));
const stopPullMutation = useMutation(gitStopActionMutationOptions({ cwd: gitCwd, queryClient }));

// Force a fresh git-status fetch when a draft thread mounts so we catch
// upstream changes immediately instead of waiting for the next poll cycle.
Expand Down Expand Up @@ -164,8 +167,11 @@ export default function BranchToolbar({
})
.catch((error) => {
toastManager.add({
type: "error",
title: "Pull failed",
type: isWsRequestError(error) && error.code === "git_action_stopped" ? "info" : "error",
title:
isWsRequestError(error) && error.code === "git_action_stopped"
? "Pull stopped"
: "Pull failed",
description: error instanceof Error ? error.message : "An error occurred.",
});
})
Expand All @@ -174,6 +180,17 @@ export default function BranchToolbar({
});
}, [pullMutation, queryClient]);

const handleStopPull = useCallback(() => {
if (!pullMutation.isPending || stopPullMutation.isPending) return;
void stopPullMutation.mutateAsync({}).catch((error) => {
toastManager.add({
type: "error",
title: "Unable to stop pull",
description: error instanceof Error ? error.message : "An error occurred.",
});
});
}, [pullMutation.isPending, stopPullMutation]);

if (!activeThreadId || !activeProject) return null;

return (
Expand Down Expand Up @@ -259,6 +276,16 @@ export default function BranchToolbar({
</TooltipPopup>
</Tooltip>
) : null}
{pullMutation.isPending ? (
<Button
variant="destructive-outline"
size="xs"
disabled={stopPullMutation.isPending}
onClick={handleStopPull}
>
Stop
</Button>
) : null}
<BranchToolbarBranchSelector
activeProjectCwd={activeProject.cwd}
activeThreadBranch={activeThreadBranch}
Expand Down
Loading
Loading