diff --git a/apps/web/src/composerDraftStore.test.ts b/apps/web/src/composerDraftStore.test.ts index 927a16060..e3fdadfa9 100644 --- a/apps/web/src/composerDraftStore.test.ts +++ b/apps/web/src/composerDraftStore.test.ts @@ -372,6 +372,8 @@ describe("composerDraftStore setModel", () => { draftsByThreadId: {}, draftThreadsByThreadId: {}, projectDraftThreadIdByProjectId: {}, + lastSelectedModel: null, + lastSelectedProvider: null, }); }); @@ -384,6 +386,23 @@ describe("composerDraftStore setModel", () => { "gpt-5.3-codex", ); }); + + it("updates lastSelectedModel when a non-null model is set", () => { + const store = useComposerDraftStore.getState(); + + store.setModel(threadId, "gpt-5.3-codex"); + + expect(useComposerDraftStore.getState().lastSelectedModel).toBe("gpt-5.3-codex"); + }); + + it("does not update lastSelectedModel when model is set to null", () => { + const store = useComposerDraftStore.getState(); + + store.setModel(threadId, "gpt-5.3-codex"); + store.setModel(threadId, null); + + expect(useComposerDraftStore.getState().lastSelectedModel).toBe("gpt-5.3-codex"); + }); }); describe("composerDraftStore setProvider", () => { @@ -394,6 +413,8 @@ describe("composerDraftStore setProvider", () => { draftsByThreadId: {}, draftThreadsByThreadId: {}, projectDraftThreadIdByProjectId: {}, + lastSelectedModel: null, + lastSelectedProvider: null, }); }); @@ -413,6 +434,14 @@ describe("composerDraftStore setProvider", () => { expect(useComposerDraftStore.getState().draftsByThreadId[threadId]).toBeUndefined(); }); + + it("updates lastSelectedProvider when a non-null provider is set", () => { + const store = useComposerDraftStore.getState(); + + store.setProvider(threadId, "codex"); + + expect(useComposerDraftStore.getState().lastSelectedProvider).toBe("codex"); + }); }); describe("composerDraftStore runtime and interaction settings", () => { diff --git a/apps/web/src/composerDraftStore.ts b/apps/web/src/composerDraftStore.ts index 2af920527..e89d18064 100644 --- a/apps/web/src/composerDraftStore.ts +++ b/apps/web/src/composerDraftStore.ts @@ -97,6 +97,8 @@ interface PersistedComposerDraftStoreState { draftsByThreadId: Record; draftThreadsByThreadId: Record; projectDraftThreadIdByProjectId: Record; + lastSelectedModel: string | null; + lastSelectedProvider: ProviderKind | null; } interface ComposerThreadDraftState { @@ -130,6 +132,8 @@ interface ComposerDraftStoreState { draftsByThreadId: Record; draftThreadsByThreadId: Record; projectDraftThreadIdByProjectId: Record; + lastSelectedModel: string | null; + lastSelectedProvider: ProviderKind | null; getDraftThreadByProjectId: (projectId: ProjectId) => ProjectDraftThread | null; getDraftThread: (threadId: ThreadId) => DraftThreadState | null; setProjectDraftThreadId: ( @@ -185,6 +189,8 @@ const EMPTY_PERSISTED_DRAFT_STORE_STATE: PersistedComposerDraftStoreState = { draftsByThreadId: {}, draftThreadsByThreadId: {}, projectDraftThreadIdByProjectId: {}, + lastSelectedModel: null, + lastSelectedProvider: null, }; const EMPTY_IMAGES: ComposerImageAttachment[] = []; @@ -386,7 +392,13 @@ function normalizePersistedComposerDraftState(value: unknown): PersistedComposer } } if (!rawDraftMap || typeof rawDraftMap !== "object") { - return { draftsByThreadId: {}, draftThreadsByThreadId, projectDraftThreadIdByProjectId }; + return { + draftsByThreadId: {}, + draftThreadsByThreadId, + projectDraftThreadIdByProjectId, + lastSelectedModel: null, + lastSelectedProvider: null, + }; } const nextDraftsByThreadId: PersistedComposerDraftStoreState["draftsByThreadId"] = {}; for (const [threadId, draftValue] of Object.entries(rawDraftMap as Record)) { @@ -450,10 +462,17 @@ function normalizePersistedComposerDraftState(value: unknown): PersistedComposer ...(codexFastMode ? { codexFastMode } : {}), }; } + const lastSelectedModel = + typeof candidate.lastSelectedModel === "string" + ? (normalizeModelSlug(candidate.lastSelectedModel) ?? null) + : null; + const lastSelectedProvider = normalizeProviderKind(candidate.lastSelectedProvider); return { draftsByThreadId: nextDraftsByThreadId, draftThreadsByThreadId, projectDraftThreadIdByProjectId, + lastSelectedModel, + lastSelectedProvider, }; } @@ -563,6 +582,8 @@ export const useComposerDraftStore = create()( draftsByThreadId: {}, draftThreadsByThreadId: {}, projectDraftThreadIdByProjectId: {}, + lastSelectedModel: null, + lastSelectedProvider: null, getDraftThreadByProjectId: (projectId) => { if (projectId.length === 0) { return null; @@ -841,7 +862,13 @@ export const useComposerDraftStore = create()( } else { nextDraftsByThreadId[threadId] = nextDraft; } - return { draftsByThreadId: nextDraftsByThreadId }; + const nextState: Partial = { + draftsByThreadId: nextDraftsByThreadId, + }; + if (normalizedProvider !== null) { + nextState.lastSelectedProvider = normalizedProvider; + } + return nextState; }); }, setModel: (threadId, model) => { @@ -868,7 +895,13 @@ export const useComposerDraftStore = create()( } else { nextDraftsByThreadId[threadId] = nextDraft; } - return { draftsByThreadId: nextDraftsByThreadId }; + const nextState: Partial = { + draftsByThreadId: nextDraftsByThreadId, + }; + if (normalizedModel !== null) { + nextState.lastSelectedModel = normalizedModel; + } + return nextState; }); }, setRuntimeMode: (threadId, runtimeMode) => { @@ -1255,6 +1288,8 @@ export const useComposerDraftStore = create()( draftsByThreadId: persistedDraftsByThreadId, draftThreadsByThreadId: state.draftThreadsByThreadId, projectDraftThreadIdByProjectId: state.projectDraftThreadIdByProjectId, + lastSelectedModel: state.lastSelectedModel, + lastSelectedProvider: state.lastSelectedProvider, }; }, merge: (persistedState, currentState) => { @@ -1270,6 +1305,8 @@ export const useComposerDraftStore = create()( draftsByThreadId, draftThreadsByThreadId: normalizedPersisted.draftThreadsByThreadId, projectDraftThreadIdByProjectId: normalizedPersisted.projectDraftThreadIdByProjectId, + lastSelectedModel: normalizedPersisted.lastSelectedModel, + lastSelectedProvider: normalizedPersisted.lastSelectedProvider, }; }, }, diff --git a/apps/web/src/hooks/useHandleNewThread.ts b/apps/web/src/hooks/useHandleNewThread.ts index 35f92d98e..1b9a25193 100644 --- a/apps/web/src/hooks/useHandleNewThread.ts +++ b/apps/web/src/hooks/useHandleNewThread.ts @@ -89,6 +89,9 @@ export function useHandleNewThread() { const threadId = newThreadId(); const createdAt = new Date().toISOString(); return (async () => { + const { lastSelectedModel, lastSelectedProvider, setModel, setProvider } = + useComposerDraftStore.getState(); + setProjectDraftThreadId(projectId, threadId, { createdAt, branch: options?.branch ?? null, @@ -97,6 +100,13 @@ export function useHandleNewThread() { runtimeMode: DEFAULT_RUNTIME_MODE, }); + if (lastSelectedModel) { + setModel(threadId, lastSelectedModel); + } + if (lastSelectedProvider) { + setProvider(threadId, lastSelectedProvider); + } + await navigate({ to: "/$threadId", params: { threadId },