Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/react-headless/src/adapters/_defaultStorage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ export function createDefaultInMemoryStorage(): ChatStorage {
async getMessages(threadId: string) {
return messagesByThread.get(threadId) ?? [];
},
async cacheMessages(threadId: string, messages: Message[]) {
messagesByThread.set(threadId, messages);
},
async updateThread(thread: Thread) {
threads = threads.map((t) => (t.id === thread.id ? thread : t));
return thread;
Expand Down
2 changes: 2 additions & 0 deletions packages/react-headless/src/adapters/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ export interface ThreadStorage {
listThreads(cursor?: string): Promise<{ threads: Thread[]; nextCursor?: string }>;
createThread(firstMessage: UserMessage): Promise<Thread>;
getMessages(threadId: string): Promise<Message[]>;
/** Optional — cache current messages for a thread. Not all storages support it. */
cacheMessages?(threadId: string, messages: Message[]): Promise<void>;
updateThread(thread: Thread): Promise<Thread>;
deleteThread(id: string): Promise<void>;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export function makeStore(overrides: MakeStoreOverrides = {}) {
}),
getMessages: vi.fn().mockResolvedValue([]),
updateThread: vi.fn(async (t) => t),
cacheMessages: vi.fn().mockResolvedValue(undefined),
deleteThread: vi.fn().mockResolvedValue(undefined),
...threadOverrides,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,41 @@ describe("createChatStore", () => {
expect(store.getState().messages).toEqual([]);
expect(store.getState().threadError).toBeNull();
});

it("persists current messages via cacheMessages before clearing", () => {
const cacheMessages = vi.fn().mockResolvedValue(undefined);
const store = makeStore({ cacheMessages });

const msgs = [makeMessage("m1"), makeMessage("m2", "assistant")];
store.setState({
selectedThreadId: "t1",
messages: msgs,
});

store.getState().switchToNewThread();

expect(cacheMessages).toHaveBeenCalledWith("t1", msgs);
});

it("does not call cacheMessages when no thread is selected", () => {
const cacheMessages = vi.fn().mockResolvedValue(undefined);
const store = makeStore({ cacheMessages });

store.setState({ messages: [makeMessage("m1")] });
store.getState().switchToNewThread();

expect(cacheMessages).not.toHaveBeenCalled();
});

it("does not call cacheMessages when messages are empty", () => {
const cacheMessages = vi.fn().mockResolvedValue(undefined);
const store = makeStore({ cacheMessages });

store.setState({ selectedThreadId: "t1", messages: [] });
store.getState().switchToNewThread();

expect(cacheMessages).not.toHaveBeenCalled();
});
});

describe("createThread", () => {
Expand Down Expand Up @@ -324,6 +359,24 @@ describe("createChatStore", () => {
expect(store.getState().selectedThreadId).toBe("t-auto");
});

it("persists messages via cacheMessages after streaming completes", async () => {
const cacheMessages = vi.fn().mockResolvedValue(undefined);
const send = vi.fn().mockResolvedValue(new Response("", { status: 200 }));

const store = makeStore({
cacheMessages,
send,
streamProtocol: { parse: async function* () {} },
});
store.setState({ selectedThreadId: "t1" });

await store.getState().processMessage({ role: "user", content: "hello" });

const finalMessages = store.getState().messages;
expect(finalMessages.length).toBeGreaterThan(0);
expect(cacheMessages).toHaveBeenCalledWith("t1", finalMessages);
});

it("no-ops when already running", async () => {
const send = vi.fn().mockResolvedValue(new Response("", { status: 200 }));

Expand Down
7 changes: 7 additions & 0 deletions packages/react-headless/src/store/createChatStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ export const createChatStore = (config: CreateChatStoreConfig) => {

switchToNewThread: () => {
get().cancelMessage();
const { selectedThreadId, messages } = get();
if (selectedThreadId && messages.length > 0) {
threadStorage.cacheMessages?.(selectedThreadId, messages).catch(() => {});
}
set({
selectedThreadId: null,
messages: [],
Expand Down Expand Up @@ -201,6 +205,9 @@ export const createChatStore = (config: CreateChatStoreConfig) => {
}),
adapter: llm.streamProtocol,
});

// Persist messages after successful streaming so they survive thread switches.
await threadStorage.cacheMessages?.(threadId, get().messages);
} catch (e) {
if (!abortController.signal.aborted) {
set({ threadError: e instanceof Error ? e : new Error(String(e)) });
Expand Down
Loading