Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions packages/gitbook/src/components/Search/SearchAskAnswer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,18 @@ function AnswerBody(props: { query: string; answer: AskAnswerResult }) {
<>
<div data-testid="search-ask-answer" className="animate-fade-in-slow text-tint-strong">
{answer.body ?? t(language, 'search_ask_no_answer')}
{answer.sources.length > 0 ? (
{answer.error ? (
<div className="mt-4 text-sm text-tint">{t(language, 'search_ask_error')}</div>
) : null}
{!answer.error && answer.sources.length > 0 ? (
// @TODO: Add responseId once search uses new AI endpoint
<AIResponseFeedback query={query} className="-ml-1 mt-2" responseId="" />
) : null}
{answer.followupQuestions.length > 0 ? (
{!answer.error && answer.followupQuestions.length > 0 ? (
<AnswerFollowupQuestions followupQuestions={answer.followupQuestions} />
) : null}
</div>
{answer.sources.length > 0 ? (
{!answer.error && answer.sources.length > 0 ? (
<AnswerSources
sources={answer.sources}
language={language}
Expand Down
207 changes: 100 additions & 107 deletions packages/gitbook/src/components/Search/server-actions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import type { GitBookSiteContext } from '@/lib/context';
import { resolvePageId } from '@/lib/pages';
import { fetchServerActionSiteContext, getServerActionBaseContext } from '@/lib/server-actions';
import {
fetchServerActionSiteContext,
getServerActionBaseContext,
runStreamableServerAction,
} from '@/lib/server-actions';
import { findSiteSpaceBy } from '@/lib/sites';
import { filterOutNullable } from '@/lib/typescript';
import type {
Expand All @@ -11,7 +15,6 @@ import type {
SearchAIAnswer,
SearchAIRecommendedQuestionStream,
} from '@gitbook/api';
import { createStreamableValue } from 'ai/rsc';
import type * as React from 'react';

import { throwIfDataError } from '@/lib/data';
Expand All @@ -32,6 +35,8 @@ export interface AskAnswerResult {
body?: React.ReactNode;
followupQuestions: string[];
sources: AskAnswerSource[];
/** Set when an error occurred mid-stream, after partial content was already delivered. */
error?: boolean;
}

/**
Expand All @@ -45,91 +50,86 @@ export async function streamAskQuestion({
question: string;
}) {
return traceErrorOnly('Search.streamAskQuestion', async () => {
const responseStream = createStreamableValue<AskAnswerResult | undefined>();

(async () => {
const context = await fetchServerActionSiteContext(
await getServerActionBaseContext({ isEmbeddable: asEmbeddable })
);

const apiClient = await context.dataFetcher.api();
return runStreamableServerAction<AskAnswerResult | undefined>({
onError: (_, lastValue) => ({
...(lastValue ?? { followupQuestions: [], sources: [] }),
error: true,
}),
run: async (push) => {
const context = await fetchServerActionSiteContext(
await getServerActionBaseContext({ isEmbeddable: asEmbeddable })
);

const stream = apiClient.orgs.streamAskInSite(
context.organizationId,
context.site.id,
{
question,
context: {
siteSpaceId: context.siteSpace.id,
},
scope: {
mode: 'default',
currentSiteSpace: context.siteSpace.id,
const apiClient = await context.dataFetcher.api();

const stream = apiClient.orgs.streamAskInSite(
context.organizationId,
context.site.id,
{
question,
context: {
siteSpaceId: context.siteSpace.id,
},
scope: {
mode: 'default',
currentSiteSpace: context.siteSpace.id,
},
},
},
{ format: 'document' }
);

const spacePromises = new Map<string, Promise<Revision>>();
for await (const chunk of stream) {
const answer = chunk.answer;

// Register the space of each page source into the promise queue.
const spaces = answer.sources
.map((source) => {
if (source.type !== 'page') {
return null;
}

if (!spacePromises.has(source.space)) {
spacePromises.set(
source.space,
throwIfDataError(
context.dataFetcher.getRevision({
spaceId: source.space,
revisionId: source.revision,
})
)
);
}

return source.space;
})
.filter(filterOutNullable);

// Get the pages for all spaces referenced by this answer.
const pages = await Promise.all(
spaces.map(async (space) => {
const revision = await spacePromises.get(space);
return { space, pages: revision?.pages };
})
).then((results) => {
return results.reduce((map, result) => {
if (result.pages) {
map.set(result.space, result.pages);
}
return map;
}, new Map<string, RevisionPage[]>());
});
responseStream.update(
await transformAnswer(context, {
answer: chunk.answer,
asEmbeddable: Boolean(asEmbeddable),
spacePages: pages,
})
{ format: 'document' }
);
}
})()
.then(() => {
responseStream.done();
})
.catch((error) => {
responseStream.error(error);
});

return {
stream: responseStream.value,
};
const spacePromises = new Map<string, Promise<Revision>>();
for await (const chunk of stream) {
const answer = chunk.answer;

// Register the space of each page source into the promise queue.
const spaces = answer.sources
.map((source) => {
if (source.type !== 'page') {
return null;
}

if (!spacePromises.has(source.space)) {
spacePromises.set(
source.space,
throwIfDataError(
context.dataFetcher.getRevision({
spaceId: source.space,
revisionId: source.revision,
})
)
);
}

return source.space;
})
.filter(filterOutNullable);

// Get the pages for all spaces referenced by this answer.
const pages = await Promise.all(
spaces.map(async (space) => {
const revision = await spacePromises.get(space);
return { space, pages: revision?.pages };
})
).then((results) => {
return results.reduce((map, result) => {
if (result.pages) {
map.set(result.space, result.pages);
}
return map;
}, new Map<string, RevisionPage[]>());
});

push(
await transformAnswer(context, {
answer: chunk.answer,
asEmbeddable: Boolean(asEmbeddable),
spacePages: pages,
})
);
}
},
});
});
}

Expand All @@ -142,32 +142,25 @@ export async function streamRecommendedQuestions(args: { siteSpaceId?: string })
const siteURLData = await getSiteURLDataFromMiddleware();
const context = await getServerActionBaseContext();

const responseStream = createStreamableValue<
SearchAIRecommendedQuestionStream | undefined
>();
return runStreamableServerAction<SearchAIRecommendedQuestionStream | undefined>({
// On mid-stream error, pass the last value through to stop cleanly without a throw.
// On pre-stream error, fail() is called so the existing silent catch in the client handles it.
onError: (_, lastValue) => lastValue,
run: async (push) => {
const apiClient = await context.dataFetcher.api();
const apiStream = apiClient.orgs.streamRecommendedQuestionsInSite(
siteURLData.organization,
siteURLData.site,
{
siteSpaceId: args.siteSpaceId,
}
);

(async () => {
const apiClient = await context.dataFetcher.api();
const apiStream = apiClient.orgs.streamRecommendedQuestionsInSite(
siteURLData.organization,
siteURLData.site,
{
siteSpaceId: args.siteSpaceId,
for await (const chunk of apiStream) {
push(chunk);
}
);

for await (const chunk of apiStream) {
responseStream.update(chunk);
}
})()
.then(() => {
responseStream.done();
})
.catch((error) => {
responseStream.error(error);
});

return { stream: responseStream.value };
},
});
});
}

Expand Down
43 changes: 43 additions & 0 deletions packages/gitbook/src/lib/graceful-stream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/**
* Creates a mid-stream error handler that gracefully downgrades errors to values
* when the stream has already started delivering content.
*
* - If the stream has started: calls `update(onError(error, lastValue))` (when non-null)
* then `done()`, preserving partial content on the client.
* - If the stream has not started: calls `fail(error)`, propagating the error normally
* so the client receives a full error state.
*/
export function createMidStreamErrorHandler<T>(
onError: (error: unknown, lastValue: T | undefined) => T | undefined
): {
track: (value: T) => void;
handleError: (
error: unknown,
callbacks: {
update: (value: T) => void;
done: () => void;
fail: (error: unknown) => void;
}
) => void;
} {
let hasStarted = false;
let lastValue: T | undefined;

return {
track(value) {
hasStarted = true;
lastValue = value;
},
handleError(error, { update, done, fail }) {
if (hasStarted) {
const errorValue = onError(error, lastValue);
if (errorValue !== undefined) {
update(errorValue);
}
done();
} else {
fail(error);
}
},
};
}
41 changes: 41 additions & 0 deletions packages/gitbook/src/lib/server-actions.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import { createStreamableValue } from 'ai/rsc';
import type { StreamableValue } from 'ai/rsc';
import { type GitBookBaseContext, fetchSiteContextByURLLookup, getBaseContext } from './context';
import { getEmbeddableLinker } from './embeddable-linker';
import { createMidStreamErrorHandler } from './graceful-stream';
import {
getSiteURLDataFromMiddleware,
getSiteURLFromMiddleware,
Expand Down Expand Up @@ -39,3 +42,41 @@ export async function fetchServerActionSiteContext(baseContext: GitBookBaseConte
const siteURLData = await getSiteURLDataFromMiddleware();
return fetchSiteContextByURLLookup(baseContext, siteURLData);
}

/**
* Run a server action that streams values to the client using `createStreamableValue`.
*
* When an error occurs after the stream has started delivering content, it is
* converted into a final value via `onError` and the stream closes cleanly —
* preserving partial content on the client.
*
* When an error occurs before any value has been pushed, it is propagated
* normally so the client receives a full error state.
*/
export function runStreamableServerAction<T>({
onError,
run,
}: {
onError: (error: unknown, lastValue: T | undefined) => T | undefined;
run: (push: (value: T) => void) => Promise<void>;
}): { stream: StreamableValue<T> } {
const responseStream = createStreamableValue<T>();
const errorHandler = createMidStreamErrorHandler<T>(onError);

run((value) => {
errorHandler.track(value);
responseStream.update(value);
})
.then(() => {
responseStream.done();
})
.catch((error) => {
errorHandler.handleError(error, {
update: (value) => responseStream.update(value),
done: () => responseStream.done(),
fail: (err) => responseStream.error(err),
});
});

return { stream: responseStream.value };
}
Loading