From ef6d7a48495c3c03f2bb28f7122da8da097e7818 Mon Sep 17 00:00:00 2001 From: Val Alexander Date: Thu, 9 Apr 2026 14:48:07 -0500 Subject: [PATCH] Preserve local draft threads across project switches - Keep composer drafts when remapping or clearing project draft threads - Merge draft-only threads into the sidebar and route views - Fix new-thread flow so switching projects creates a fresh draft thread --- apps/web/src/components/ChatView.browser.tsx | 9 +- apps/web/src/components/ChatView.logic.ts | 38 +------ apps/web/src/components/Sidebar.logic.test.ts | 58 ++++++++++ apps/web/src/components/Sidebar.logic.ts | 25 +++++ apps/web/src/components/Sidebar.tsx | 105 +++++++++++++----- apps/web/src/composerDraftStore.test.ts | 23 ++-- apps/web/src/composerDraftStore.ts | 44 +------- apps/web/src/draftThreads.ts | 34 ++++++ apps/web/src/hooks/useHandleNewThread.ts | 8 +- apps/web/src/routes/_chat.$threadId.tsx | 10 +- 10 files changed, 237 insertions(+), 117 deletions(-) create mode 100644 apps/web/src/draftThreads.ts diff --git a/apps/web/src/components/ChatView.browser.tsx b/apps/web/src/components/ChatView.browser.tsx index 426f266de..84d08360c 100644 --- a/apps/web/src/components/ChatView.browser.tsx +++ b/apps/web/src/components/ChatView.browser.tsx @@ -1776,11 +1776,12 @@ describe("ChatView timeline estimator parity (full app)", () => { await newThreadButton.click(); - await waitForURL( + const secondThreadPath = await waitForURL( mounted.router, - (path) => path === threadPath, - "New-thread should reuse the existing project draft thread.", + (path) => UUID_ROUTE_RE.test(path) && path !== threadPath, + "New-thread should create a fresh draft thread UUID.", ); + const secondThreadId = secondThreadPath.slice(1) as ThreadId; expect(useComposerDraftStore.getState().draftsByThreadId[threadId]).toMatchObject({ model: "gpt-5.4", modelOptions: { @@ -1790,6 +1791,8 @@ describe("ChatView timeline estimator parity (full app)", () => { }, }, }); + expect(useComposerDraftStore.getState().draftThreadsByThreadId[threadId]).toBeDefined(); + expect(useComposerDraftStore.getState().draftThreadsByThreadId[secondThreadId]).toBeDefined(); } finally { await mounted.cleanup(); } diff --git a/apps/web/src/components/ChatView.logic.ts b/apps/web/src/components/ChatView.logic.ts index 5d78f3ad9..23f88d51f 100644 --- a/apps/web/src/components/ChatView.logic.ts +++ b/apps/web/src/components/ChatView.logic.ts @@ -1,7 +1,7 @@ -import { type MessageId, ProjectId, type ThreadId } from "@okcode/contracts"; -import { type ChatMessage, type Thread } from "../types"; +import { type MessageId, ProjectId } from "@okcode/contracts"; +import { type ChatMessage } from "../types"; import { randomUUID } from "~/lib/utils"; -import { type ComposerAttachment, type DraftThreadState } from "../composerDraftStore"; +import { type ComposerAttachment } from "../composerDraftStore"; import { Schema } from "effect"; import { filterTerminalContextsWithText, @@ -9,43 +9,13 @@ import { type TerminalContextDraft, } from "../lib/terminalContext"; import { type PromptEnhancementId } from "../promptEnhancement"; -import { normalizeThreadTitle } from "../threadTitle"; +export { buildLocalDraftThread } from "../draftThreads"; export const LAST_INVOKED_SCRIPT_BY_PROJECT_KEY = "okcode:last-invoked-script-by-project"; const WORKTREE_BRANCH_PREFIX = "okcode"; export const LastInvokedScriptByProjectSchema = Schema.Record(ProjectId, Schema.String); -export function buildLocalDraftThread( - threadId: ThreadId, - draftThread: DraftThreadState, - fallbackModel: string, - error: string | null, -): Thread { - return { - id: threadId, - codexThreadId: null, - projectId: draftThread.projectId, - title: normalizeThreadTitle(draftThread.title), - model: fallbackModel, - runtimeMode: draftThread.runtimeMode, - interactionMode: draftThread.interactionMode, - session: null, - messages: [], - error, - createdAt: draftThread.createdAt, - latestTurn: null, - lastVisitedAt: draftThread.createdAt, - branch: draftThread.branch, - worktreePath: draftThread.worktreePath, - worktreeBaseBranch: null, - ...(draftThread.githubRef ? { githubRef: draftThread.githubRef } : {}), - turnDiffSummaries: [], - activities: [], - proposedPlans: [], - }; -} - export function revokeBlobPreviewUrl(previewUrl: string | undefined): void { if (!previewUrl || typeof URL === "undefined" || !previewUrl.startsWith("blob:")) { return; diff --git a/apps/web/src/components/Sidebar.logic.test.ts b/apps/web/src/components/Sidebar.logic.test.ts index 531d811c7..3c77fd2bd 100644 --- a/apps/web/src/components/Sidebar.logic.test.ts +++ b/apps/web/src/components/Sidebar.logic.test.ts @@ -5,6 +5,7 @@ import { getVisibleThreadsForProject, getProjectSortTimestamp, hasUnseenCompletion, + mergeDraftThreadsIntoSidebarThreads, resolveProjectStatusIndicator, resolveProjectNameTone, resolveSidebarNewThreadEnvMode, @@ -16,6 +17,7 @@ import { sortThreadsForSidebar, } from "./Sidebar.logic"; import { ProjectId, ThreadId } from "@okcode/contracts"; +import type { DraftThreadState } from "../composerDraftStore"; import { DEFAULT_INTERACTION_MODE, DEFAULT_RUNTIME_MODE, @@ -99,6 +101,62 @@ describe("resolveSidebarNewThreadEnvMode", () => { }); }); +describe("mergeDraftThreadsIntoSidebarThreads", () => { + it("includes preserved local drafts alongside server threads for the same project", () => { + const projectId = ProjectId.makeUnsafe("project-1"); + const merged = mergeDraftThreadsIntoSidebarThreads({ + serverThreads: [makeThread({ id: ThreadId.makeUnsafe("thread-server"), projectId })], + draftThreadsByThreadId: { + [ThreadId.makeUnsafe("thread-draft")]: { + projectId, + createdAt: "2026-03-09T11:00:00.000Z", + title: "Draft thread", + runtimeMode: DEFAULT_RUNTIME_MODE, + interactionMode: DEFAULT_INTERACTION_MODE, + branch: null, + worktreePath: null, + envMode: "local", + } satisfies DraftThreadState, + }, + projectModelByProjectId: new Map([[projectId, "gpt-5.4"]]), + }); + + expect(merged.map((thread) => thread.id)).toEqual([ + ThreadId.makeUnsafe("thread-server"), + ThreadId.makeUnsafe("thread-draft"), + ]); + expect(merged[1]).toMatchObject({ + projectId, + title: "Draft thread", + model: "gpt-5.4", + }); + }); + + it("skips draft entries once a server thread with the same id exists", () => { + const projectId = ProjectId.makeUnsafe("project-1"); + const threadId = ThreadId.makeUnsafe("thread-shared"); + const merged = mergeDraftThreadsIntoSidebarThreads({ + serverThreads: [makeThread({ id: threadId, projectId })], + draftThreadsByThreadId: { + [threadId]: { + projectId, + createdAt: "2026-03-09T11:00:00.000Z", + title: "Promoted draft", + runtimeMode: DEFAULT_RUNTIME_MODE, + interactionMode: DEFAULT_INTERACTION_MODE, + branch: null, + worktreePath: null, + envMode: "local", + } satisfies DraftThreadState, + }, + projectModelByProjectId: new Map([[projectId, "gpt-5.4"]]), + }); + + expect(merged).toHaveLength(1); + expect(merged[0]?.id).toBe(threadId); + }); +}); + describe("resolveThreadStatusPill", () => { const baseThread = { interactionMode: "plan" as const, diff --git a/apps/web/src/components/Sidebar.logic.ts b/apps/web/src/components/Sidebar.logic.ts index ac84a36ea..11ddbeccc 100644 --- a/apps/web/src/components/Sidebar.logic.ts +++ b/apps/web/src/components/Sidebar.logic.ts @@ -1,4 +1,7 @@ +import { DEFAULT_MODEL_BY_PROVIDER } from "@okcode/contracts"; import type { SidebarProjectSortOrder, SidebarThreadSortOrder } from "../appSettings"; +import type { DraftThreadState } from "../composerDraftStore"; +import { buildLocalDraftThread } from "../draftThreads"; import type { Thread } from "../types"; import { cn } from "../lib/utils"; import { @@ -283,6 +286,28 @@ export function groupThreadsByProjectId( return threadsByProjectId; } +export function mergeDraftThreadsIntoSidebarThreads(input: { + serverThreads: readonly Thread[]; + draftThreadsByThreadId: Readonly>; + projectModelByProjectId: ReadonlyMap; +}): Thread[] { + const serverThreadIds = new Set(input.serverThreads.map((thread) => thread.id)); + const mergedThreads = [...input.serverThreads]; + + for (const [threadId, draftThread] of Object.entries(input.draftThreadsByThreadId)) { + if (serverThreadIds.has(threadId as Thread["id"])) { + continue; + } + const fallbackModel = + input.projectModelByProjectId.get(draftThread.projectId) ?? DEFAULT_MODEL_BY_PROVIDER.codex; + mergedThreads.push( + buildLocalDraftThread(threadId as Thread["id"], draftThread, fallbackModel, null), + ); + } + + return mergedThreads; +} + function toSortableTimestamp(iso: string | undefined): number | null { if (!iso) return null; const ms = Date.parse(iso); diff --git a/apps/web/src/components/Sidebar.tsx b/apps/web/src/components/Sidebar.tsx index 1c3956aa2..a71a56228 100644 --- a/apps/web/src/components/Sidebar.tsx +++ b/apps/web/src/components/Sidebar.tsx @@ -96,6 +96,7 @@ import { OkCodeMark } from "./OkCodeMark"; import { getVisibleThreadsForProject, isActionableThreadStatus, + mergeDraftThreadsIntoSidebarThreads, resolveProjectNameTone, resolveSidebarNewThreadEnvMode, resolveThreadStatusPill, @@ -292,6 +293,7 @@ function SortableProjectItem({ interface MemoizedThreadRowProps { thread: Thread; + isDraft: boolean; isActive: boolean; isSelected: boolean; prByThreadId: Map; @@ -300,7 +302,7 @@ interface MemoizedThreadRowProps { editingThreadId: ThreadIdType | null; editingThreadTitle: string; bindInputRef: (node: HTMLInputElement | null) => void; - startEditing: (opts: { threadId: ThreadIdType; title: string }) => void; + startEditing: (opts: { threadId: ThreadIdType; title: string; isDraft?: boolean }) => void; setDraftTitle: (title: string) => void; commitEditing: () => Promise | void; cancelEditing: () => void; @@ -322,6 +324,7 @@ interface MemoizedThreadRowProps { const MemoizedThreadRow = memo( function ThreadRow({ thread, + isDraft, isActive, isSelected, prByThreadId, @@ -430,6 +433,7 @@ const MemoizedThreadRow = memo( startEditing({ threadId: thread.id, title: thread.title, + isDraft, }); }} onDraftTitleChange={setDraftTitle} @@ -443,6 +447,7 @@ const MemoizedThreadRow = memo( ); }, (prev, next) => { + if (prev.isDraft !== next.isDraft) return false; if (prev.isActive !== next.isActive) return false; if (prev.isSelected !== next.isSelected) return false; if (prev.thread.title !== next.thread.title) return false; @@ -466,12 +471,16 @@ export default function Sidebar() { const threads = useStore((store) => store.threads); const markThreadUnread = useStore((store) => store.markThreadUnread); const toggleProject = useStore((store) => store.toggleProject); + const setProjectExpanded = useStore((store) => store.setProjectExpanded); const setAllProjectsExpanded = useStore((store) => store.setAllProjectsExpanded); const reorderProjects = useStore((store) => store.reorderProjects); const clearComposerDraftForThread = useComposerDraftStore((store) => store.clearDraftThread); + const clearDraftThread = useComposerDraftStore((store) => store.clearDraftThread); + const draftThreadsByThreadId = useComposerDraftStore((store) => store.draftThreadsByThreadId); const getDraftThreadByProjectId = useComposerDraftStore( (store) => store.getDraftThreadByProjectId, ); + const setDraftThreadTitle = useComposerDraftStore((store) => store.setDraftThreadTitle); const clearTerminalState = useTerminalStateStore((state) => state.clearTerminalState); const clearProjectDraftThreadId = useComposerDraftStore( (store) => store.clearProjectDraftThreadId, @@ -539,7 +548,11 @@ export default function Sidebar() { commitEditing, setDraftTitle, startEditing, - } = useThreadTitleEditor(); + } = useThreadTitleEditor({ + onRenameDraftThread: (threadId, title) => { + setDraftThreadTitle(threadId, title); + }, + }); const { editingProjectId, draftProjectTitle, @@ -557,14 +570,29 @@ export default function Sidebar() { () => new Map(projects.map((project) => [project.id, project.cwd] as const)), [projects], ); + const projectModelById = useMemo( + () => new Map(projects.map((project) => [project.id, project.model] as const)), + [projects], + ); + const sidebarThreads = useMemo( + () => + mergeDraftThreadsIntoSidebarThreads({ + serverThreads: threads, + draftThreadsByThreadId, + projectModelByProjectId: projectModelById, + }), + [draftThreadsByThreadId, projectModelById, threads], + ); const threadById = useMemo( - () => new Map(threads.map((thread) => [thread.id, thread] as const)), - [threads], + () => new Map(sidebarThreads.map((thread) => [thread.id, thread] as const)), + [sidebarThreads], ); + const serverThreadIds = useMemo(() => new Set(threads.map((thread) => thread.id)), [threads]); const activeProjectId = routeThreadId ? (threadById.get(routeThreadId)?.projectId ?? null) : null; + const lastAutoExpandedThreadIdRef = useRef(null); const sortedThreadsByProjectId = useMemo( - () => sortThreadsByProjectIdForSidebar(threads, appSettings.sidebarThreadSortOrder), - [appSettings.sidebarThreadSortOrder, threads], + () => sortThreadsByProjectIdForSidebar(sidebarThreads, appSettings.sidebarThreadSortOrder), + [appSettings.sidebarThreadSortOrder, sidebarThreads], ); const orderedThreadIdsByProjectId = useMemo(() => { const orderedThreadIds = new Map(); @@ -586,14 +614,25 @@ export default function Sidebar() { } return latestThreads; }, [sortedThreadsByProjectId]); + + useEffect(() => { + if (!routeThreadId || !activeProjectId) { + return; + } + if (lastAutoExpandedThreadIdRef.current === routeThreadId) { + return; + } + lastAutoExpandedThreadIdRef.current = routeThreadId; + setProjectExpanded(activeProjectId, true); + }, [activeProjectId, routeThreadId, setProjectExpanded]); const threadGitTargets = useMemo( () => - threads.map((thread) => ({ + sidebarThreads.map((thread) => ({ threadId: thread.id, branch: thread.branch, cwd: thread.worktreePath ?? projectCwdById.get(thread.projectId) ?? null, })), - [projectCwdById, threads], + [projectCwdById, sidebarThreads], ); const threadGitStatusCwds = useMemo( () => [ @@ -789,10 +828,10 @@ export default function Sidebar() { threadId: ThreadId, opts: { deletedThreadIds?: ReadonlySet } = {}, ): Promise => { - const api = readNativeApi(); - if (!api) return; const thread = threadById.get(threadId); if (!thread) return; + const api = readNativeApi(); + const isDraftThread = !serverThreadIds.has(threadId); const threadProject = projectById.get(thread.projectId); // When bulk-deleting, exclude the other threads being deleted so // getOrphanedWorktreePathForThread correctly detects that no surviving @@ -800,14 +839,15 @@ export default function Sidebar() { const deletedIds = opts.deletedThreadIds; const survivingThreads = deletedIds && deletedIds.size > 0 - ? threads.filter((t) => t.id === threadId || !deletedIds.has(t.id)) - : threads; + ? sidebarThreads.filter((t) => t.id === threadId || !deletedIds.has(t.id)) + : sidebarThreads; const orphanedWorktreePath = getOrphanedWorktreePathForThread(survivingThreads, threadId); const displayWorktreePath = orphanedWorktreePath ? formatWorktreePathForDisplay(orphanedWorktreePath) : null; const canDeleteWorktree = orphanedWorktreePath !== null && threadProject !== undefined; const shouldDeleteWorktree = + api && canDeleteWorktree && (await api.dialogs.confirm( [ @@ -818,7 +858,7 @@ export default function Sidebar() { ].join("\n"), )); - if (thread.session && thread.session.status !== "closed") { + if (!isDraftThread && api && thread.session && thread.session.status !== "closed") { await api.orchestration .dispatchCommand({ type: "thread.session.stop", @@ -829,22 +869,28 @@ export default function Sidebar() { .catch(() => undefined); } - try { - await api.terminal.close({ threadId, deleteHistory: true }); - } catch { - // Terminal may already be closed + if (api) { + try { + await api.terminal.close({ threadId, deleteHistory: true }); + } catch { + // Terminal may already be closed + } } const allDeletedIds = deletedIds ?? new Set(); const shouldNavigateToFallback = routeThreadId === threadId; const fallbackThreadId = - threads.find((entry) => entry.id !== threadId && !allDeletedIds.has(entry.id))?.id ?? null; - await api.orchestration.dispatchCommand({ - type: "thread.delete", - commandId: newCommandId(), - threadId, - }); - clearComposerDraftForThread(threadId); + sidebarThreads.find((entry) => entry.id !== threadId && !allDeletedIds.has(entry.id))?.id ?? + null; + if (!isDraftThread) { + if (!api) return; + await api.orchestration.dispatchCommand({ + type: "thread.delete", + commandId: newCommandId(), + threadId, + }); + } + clearDraftThread(threadId); clearProjectDraftThreadById(thread.projectId, thread.id); clearTerminalState(threadId); if (shouldNavigateToFallback) { @@ -885,15 +931,16 @@ export default function Sidebar() { } }, [ - clearComposerDraftForThread, + clearDraftThread, clearProjectDraftThreadById, clearTerminalState, navigate, projectById, removeWorktreeMutation, routeThreadId, + serverThreadIds, + sidebarThreads, threadById, - threads, ], ); @@ -938,12 +985,13 @@ export default function Sidebar() { if (!api) return; const thread = threadById.get(threadId); if (!thread) return; + const isDraftThread = !serverThreadIds.has(threadId); const threadWorkspacePath = thread.worktreePath ?? projectCwdById.get(thread.projectId) ?? null; const clicked = await api.contextMenu.show( [ { id: "rename", label: "Rename thread" }, - { id: "mark-unread", label: "Mark unread" }, + ...(!isDraftThread ? [{ id: "mark-unread", label: "Mark unread" }] : []), { id: "copy-path", label: "Copy Path" }, { id: "copy-thread-id", label: "Copy Thread ID" }, { id: "delete", label: "Delete", destructive: true }, @@ -955,6 +1003,7 @@ export default function Sidebar() { startEditing({ threadId, title: thread.title, + isDraft: isDraftThread, }); return; } @@ -1000,6 +1049,7 @@ export default function Sidebar() { deleteThread, markThreadUnread, projectCwdById, + serverThreadIds, startEditing, threadById, ], @@ -1384,6 +1434,7 @@ export default function Sidebar() { { store.clearProjectDraftThreadById(projectId, threadId); expect(useComposerDraftStore.getState().getDraftThreadByProjectId(projectId)).toBeNull(); - expect(useComposerDraftStore.getState().getDraftThread(threadId)).toBeNull(); - expect(useComposerDraftStore.getState().draftsByThreadId[threadId]).toBeUndefined(); + expect(useComposerDraftStore.getState().getDraftThread(threadId)).toMatchObject({ + projectId, + title: "New thread", + }); + expect(useComposerDraftStore.getState().draftsByThreadId[threadId]?.prompt).toBe("hello"); }); it("clears project draft mapping by project id", () => { @@ -520,11 +523,14 @@ describe("composerDraftStore project draft thread mapping", () => { store.setPrompt(threadId, "hello"); store.clearProjectDraftThreadId(projectId); expect(useComposerDraftStore.getState().getDraftThreadByProjectId(projectId)).toBeNull(); - expect(useComposerDraftStore.getState().getDraftThread(threadId)).toBeNull(); - expect(useComposerDraftStore.getState().draftsByThreadId[threadId]).toBeUndefined(); + expect(useComposerDraftStore.getState().getDraftThread(threadId)).toMatchObject({ + projectId, + title: "New thread", + }); + expect(useComposerDraftStore.getState().draftsByThreadId[threadId]?.prompt).toBe("hello"); }); - it("clears orphaned composer drafts when remapping a project to a new draft thread", () => { + it("preserves older drafts when remapping a project to a new active draft thread", () => { const store = useComposerDraftStore.getState(); store.setProjectDraftThreadId(projectId, threadId); store.setPrompt(threadId, "orphan me"); @@ -534,8 +540,11 @@ describe("composerDraftStore project draft thread mapping", () => { expect(useComposerDraftStore.getState().getDraftThreadByProjectId(projectId)?.threadId).toBe( otherThreadId, ); - expect(useComposerDraftStore.getState().getDraftThread(threadId)).toBeNull(); - expect(useComposerDraftStore.getState().draftsByThreadId[threadId]).toBeUndefined(); + expect(useComposerDraftStore.getState().getDraftThread(threadId)).toMatchObject({ + projectId, + title: "New thread", + }); + expect(useComposerDraftStore.getState().draftsByThreadId[threadId]?.prompt).toBe("orphan me"); }); it("keeps composer drafts when the thread is still mapped by another project", () => { diff --git a/apps/web/src/composerDraftStore.ts b/apps/web/src/composerDraftStore.ts index 0a5b47b61..40b0b1424 100644 --- a/apps/web/src/composerDraftStore.ts +++ b/apps/web/src/composerDraftStore.ts @@ -1156,20 +1156,8 @@ export const useComposerDraftStore = create()( ...state.draftThreadsByThreadId, [threadId]: nextDraftThread, }; - let nextDraftsByThreadId = state.draftsByThreadId; - if ( - previousThreadIdForProject && - previousThreadIdForProject !== threadId && - !Object.values(nextProjectDraftThreadIdByProjectId).includes(previousThreadIdForProject) - ) { - delete nextDraftThreadsByThreadId[previousThreadIdForProject]; - if (state.draftsByThreadId[previousThreadIdForProject] !== undefined) { - nextDraftsByThreadId = { ...state.draftsByThreadId }; - delete nextDraftsByThreadId[previousThreadIdForProject]; - } - } return { - draftsByThreadId: nextDraftsByThreadId, + draftsByThreadId: state.draftsByThreadId, draftThreadsByThreadId: nextDraftThreadsByThreadId, projectDraftThreadIdByProjectId: nextProjectDraftThreadIdByProjectId, }; @@ -1277,20 +1265,9 @@ export const useComposerDraftStore = create()( const { [projectId]: _removed, ...restProjectMappingsRaw } = state.projectDraftThreadIdByProjectId; const restProjectMappings = restProjectMappingsRaw as Record; - const nextDraftThreadsByThreadId: Record = { - ...state.draftThreadsByThreadId, - }; - let nextDraftsByThreadId = state.draftsByThreadId; - if (!Object.values(restProjectMappings).includes(threadId)) { - delete nextDraftThreadsByThreadId[threadId]; - if (state.draftsByThreadId[threadId] !== undefined) { - nextDraftsByThreadId = { ...state.draftsByThreadId }; - delete nextDraftsByThreadId[threadId]; - } - } return { - draftsByThreadId: nextDraftsByThreadId, - draftThreadsByThreadId: nextDraftThreadsByThreadId, + draftsByThreadId: state.draftsByThreadId, + draftThreadsByThreadId: state.draftThreadsByThreadId, projectDraftThreadIdByProjectId: restProjectMappings, }; }); @@ -1306,20 +1283,9 @@ export const useComposerDraftStore = create()( const { [projectId]: _removed, ...restProjectMappingsRaw } = state.projectDraftThreadIdByProjectId; const restProjectMappings = restProjectMappingsRaw as Record; - const nextDraftThreadsByThreadId: Record = { - ...state.draftThreadsByThreadId, - }; - let nextDraftsByThreadId = state.draftsByThreadId; - if (!Object.values(restProjectMappings).includes(threadId)) { - delete nextDraftThreadsByThreadId[threadId]; - if (state.draftsByThreadId[threadId] !== undefined) { - nextDraftsByThreadId = { ...state.draftsByThreadId }; - delete nextDraftsByThreadId[threadId]; - } - } return { - draftsByThreadId: nextDraftsByThreadId, - draftThreadsByThreadId: nextDraftThreadsByThreadId, + draftsByThreadId: state.draftsByThreadId, + draftThreadsByThreadId: state.draftThreadsByThreadId, projectDraftThreadIdByProjectId: restProjectMappings, }; }); diff --git a/apps/web/src/draftThreads.ts b/apps/web/src/draftThreads.ts new file mode 100644 index 000000000..073eec927 --- /dev/null +++ b/apps/web/src/draftThreads.ts @@ -0,0 +1,34 @@ +import type { ThreadId } from "@okcode/contracts"; +import type { DraftThreadState } from "./composerDraftStore"; +import type { Thread } from "./types"; +import { normalizeThreadTitle } from "./threadTitle"; + +export function buildLocalDraftThread( + threadId: ThreadId, + draftThread: DraftThreadState, + fallbackModel: string, + error: string | null, +): Thread { + return { + id: threadId, + codexThreadId: null, + projectId: draftThread.projectId, + title: normalizeThreadTitle(draftThread.title), + model: fallbackModel, + runtimeMode: draftThread.runtimeMode, + interactionMode: draftThread.interactionMode, + session: null, + messages: [], + error, + createdAt: draftThread.createdAt, + latestTurn: null, + lastVisitedAt: draftThread.createdAt, + branch: draftThread.branch, + worktreePath: draftThread.worktreePath, + worktreeBaseBranch: null, + ...(draftThread.githubRef ? { githubRef: draftThread.githubRef } : {}), + turnDiffSummaries: [], + activities: [], + proposedPlans: [], + }; +} diff --git a/apps/web/src/hooks/useHandleNewThread.ts b/apps/web/src/hooks/useHandleNewThread.ts index d7392a9c5..97e8660f4 100644 --- a/apps/web/src/hooks/useHandleNewThread.ts +++ b/apps/web/src/hooks/useHandleNewThread.ts @@ -37,10 +37,10 @@ export function useHandleNewThread() { branch?: string | null; worktreePath?: string | null; envMode?: DraftThreadEnvMode; + reuseExistingDraft?: boolean; }, ): Promise => { const { - clearProjectDraftThreadId, getDraftThread, getDraftThreadByProjectId, setModel, @@ -52,11 +52,12 @@ export function useHandleNewThread() { const hasBranchOption = options?.branch !== undefined; const hasWorktreePathOption = options?.worktreePath !== undefined; const hasEnvModeOption = options?.envMode !== undefined; + const shouldReuseExistingDraft = options?.reuseExistingDraft === true; const storedDraftThread = getDraftThreadByProjectId(projectId); const latestActiveDraftThread: DraftThreadState | null = routeThreadId ? getDraftThread(routeThreadId) : null; - if (storedDraftThread) { + if (shouldReuseExistingDraft && storedDraftThread) { return (async () => { if (hasBranchOption || hasWorktreePathOption || hasEnvModeOption) { setDraftThreadContext(storedDraftThread.threadId, { @@ -76,9 +77,8 @@ export function useHandleNewThread() { })(); } - clearProjectDraftThreadId(projectId); - if ( + shouldReuseExistingDraft && latestActiveDraftThread && routeThreadId && latestActiveDraftThread.projectId === projectId diff --git a/apps/web/src/routes/_chat.$threadId.tsx b/apps/web/src/routes/_chat.$threadId.tsx index e97a07dd2..882be464c 100644 --- a/apps/web/src/routes/_chat.$threadId.tsx +++ b/apps/web/src/routes/_chat.$threadId.tsx @@ -176,6 +176,9 @@ function ChatThreadRouteView() { const draftThreadExists = useComposerDraftStore((store) => Object.hasOwn(store.draftThreadsByThreadId, threadId), ); + const draftThread = useComposerDraftStore( + (store) => store.draftThreadsByThreadId[threadId] ?? null, + ); const routeThreadExists = threadExists || draftThreadExists; const clientMode = useClientMode(); const shouldUseSheet = clientMode === "mobile"; @@ -205,10 +208,11 @@ function ChatThreadRouteView() { // ── Active workspace CWD for file tree ──────────────────────────── const activeWorkspaceCwd = useStore((store) => { const thread = store.threads.find((t) => t.id === threadId); - if (!thread) return null; - const project = store.projects.find((p) => p.id === thread.projectId); + const threadProjectId = thread?.projectId ?? draftThread?.projectId ?? null; + if (!threadProjectId) return null; + const project = store.projects.find((p) => p.id === threadProjectId); if (!project) return null; - return thread.worktreePath ?? project.cwd; + return thread?.worktreePath ?? draftThread?.worktreePath ?? project.cwd; }); // ── Keep-alive flags so lazy content doesn't unmount on tab switch ─