Skip to content

Commit 1431e21

Browse files
cdxkerdensumesh
authored andcommitted
feat(server): skip_search tool
1 parent 196a266 commit 1431e21

File tree

9 files changed

+226
-11
lines changed

9 files changed

+226
-11
lines changed

clients/search-component/src/utils/hooks/chat-context.tsx

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import React, { createContext, useContext, useRef, useState } from "react";
33
import {
44
defaultPriceToolCallOptions,
55
defaultRelevanceToolCallOptions,
6+
defaultSearchToolCallOptions,
67
useModalState,
78
} from "./modal-context";
89
import { Chunk } from "../types";
@@ -267,6 +268,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
267268

268269
const handleReader = async (
269270
reader: ReadableStreamDefaultReader<Uint8Array>,
271+
skipSearch: boolean,
270272
queryId: string | null,
271273
) => {
272274
setIsLoading(true);
@@ -303,7 +305,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
303305
json = null;
304306
}
305307

306-
if (json && props.analytics && !calledAnalytics) {
308+
if (json && props.analytics && !calledAnalytics && !skipSearch) {
307309
calledAnalytics = true;
308310
const ecommerceChunks = (json as unknown as Chunk[]).filter(
309311
(chunk) =>
@@ -368,7 +370,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
368370
{
369371
type: "system",
370372
text: outputBuffer,
371-
additional: json ? json : null,
373+
additional: json && !skipSearch ? json : null,
372374
queryId,
373375
},
374376
]);
@@ -484,6 +486,8 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
484486
}
485487
}
486488

489+
let skipSearch = false;
490+
487491
if (
488492
props.recommendOptions?.filter &&
489493
props.recommendOptions?.queriesToTriggerRecommendations.includes(
@@ -568,6 +572,53 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
568572
}
569573
});
570574

575+
const skipSearchPromise = retryOperation(async () => {
576+
if (props.type === "ecommerce" && !curGroup && messages.length > 1) {
577+
return await trieveSDK.getToolCallFunctionParams({
578+
user_message_text: `Here's the previous message thread so far: ${messages.map(
579+
(message) => {
580+
if (
581+
message.type === "system" &&
582+
message.additional?.length &&
583+
props.type === "ecommerce"
584+
) {
585+
const chunks = message.additional
586+
.map((chunk) => {
587+
return JSON.stringify({
588+
title: chunk.metadata?.title || "",
589+
description: chunk.chunk_html || "",
590+
price: chunk.num_value
591+
? `${props.defaultCurrency || ""} ${chunk.num_value}`
592+
: "",
593+
link: chunk.link || "",
594+
});
595+
})
596+
.join("\n\n");
597+
return `\n\n${chunks}${message.text}`;
598+
} else {
599+
return `\n\n${message.text}`;
600+
}
601+
},
602+
)} \n\n${props.searchToolCallOptions?.userMessageTextPrefix ?? defaultSearchToolCallOptions.userMessageTextPrefix}: ${questionProp || currentQuestion}.`,
603+
image_url: imageUrl ? imageUrl : null,
604+
audio_input: curAudioBase64 ? curAudioBase64 : null,
605+
tool_function: {
606+
name: "skip_search",
607+
description:
608+
props.searchToolCallOptions?.toolDescription ??
609+
(defaultSearchToolCallOptions.toolDescription as string),
610+
parameters: [
611+
{
612+
name: "skip_search",
613+
parameter_type: "boolean",
614+
description:
615+
"Set to true if the query is asking about products which were shown to them previously in the message thread only incldue if they are referenced by name. Set to false if the query is asking about the general catalog products or for different/other products differing from the ones shown previously. Only set this to true if the query contains a title that was in the previous messages",
616+
},
617+
],
618+
},
619+
});
620+
}})
621+
571622
const imageFiltersPromise = retryOperation(async () => {
572623
if (imageUrl) {
573624
return await trieveSDK.getToolCallFunctionParams({
@@ -643,11 +694,12 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
643694
}
644695
});
645696

646-
const [priceFiltersResp, imageFiltersResp, tagFiltersResp] =
697+
const [priceFiltersResp, imageFiltersResp, tagFiltersResp, skipSearchResp] =
647698
await Promise.all([
648699
priceFiltersPromise,
649700
imageFiltersPromise,
650701
tagFiltersPromise,
702+
skipSearchPromise,
651703
]);
652704

653705
if (transcribedQuery && curAudioBase64) {
@@ -727,6 +779,15 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
727779
}
728780
}
729781

782+
if (skipSearchResp?.parameters) {
783+
const needsSearchParam = (skipSearchResp.parameters as any)[
784+
"skip_search"
785+
];
786+
if (typeof needsSearchParam === "boolean" && needsSearchParam) {
787+
skipSearch = true;
788+
}
789+
}
790+
730791
clearTimeout(toolCallTimeout);
731792
} catch (e) {
732793
console.error("error getting getToolCallFunctionParams", e);
@@ -818,7 +879,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
818879
},
819880
],
820881
};
821-
} else {
882+
} else if (!skipSearch) {
822883
try {
823884
setLoadingText("Searching for relevant products...");
824885
const searchOverGroupsResp = await retryOperation(async () => {
@@ -1088,6 +1149,19 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
10881149
if (createMessageFilters == null) {
10891150
createMessageFilters = filters;
10901151
}
1152+
if (skipSearch) {
1153+
createMessageFilters = {
1154+
must: [
1155+
{
1156+
field: "ids",
1157+
match_any: messages
1158+
.filter((msg) => msg.type == "system")
1159+
.flatMap((msg) => msg.additional ?? [])
1160+
.map((chunk) => chunk.id),
1161+
},
1162+
],
1163+
};
1164+
}
10911165
const systemPromptToUse =
10921166
props.systemPrompt && props.systemPrompt !== ""
10931167
? props.systemPrompt
@@ -1172,7 +1246,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
11721246
]);
11731247
}
11741248

1175-
if (reader) handleReader(reader, queryId);
1249+
if (reader) handleReader(reader, skipSearch, queryId);
11761250

11771251
if (imageUrl) {
11781252
setImageUrl("");

clients/search-component/src/utils/hooks/modal-context.tsx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ export interface RelevanceToolCallOptions {
8383
lowDescription?: string;
8484
}
8585

86+
export interface SearchToolCallOptions {
87+
userMessageTextPrefix?: string;
88+
toolDescription: string;
89+
}
90+
91+
export const defaultSearchToolCallOptions: SearchToolCallOptions = {
92+
userMessageTextPrefix: "Here is the user query:",
93+
toolDescription:
94+
"Call this tool anytime it seems like we need to skip the search step. This tool tells our system that the user is asking about what they were previously shown.",
95+
};
96+
8697
export const defaultPriceToolCallOptions: PriceToolCallOptions = {
8798
toolDescription:
8899
"Only call this function if the query includes details about a price. Decide on which price filters to apply to the available catalog being used within the knowledge base to respond. If the question is slightly like a product name, respond with no filters (all false).",
@@ -163,6 +174,7 @@ export type ModalProps = {
163174
tags?: TagProp[];
164175
relevanceToolCallOptions?: RelevanceToolCallOptions;
165176
priceToolCallOptions?: PriceToolCallOptions;
177+
searchToolCallOptions?: SearchToolCallOptions;
166178
defaultSearchMode?: SearchModes;
167179
usePagefind?: boolean;
168180
type?: ModalTypes;
@@ -245,6 +257,7 @@ const defaultProps = {
245257
},
246258
} as searchOptions,
247259
chatFilters: undefined,
260+
searchToolCallOptions: defaultSearchToolCallOptions,
248261
analytics: true,
249262
chat: true,
250263
suggestedQueries: true,

clients/ts-sdk/openapi.json

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16583,6 +16583,14 @@
1658316583
],
1658416584
"nullable": true
1658516585
},
16586+
"searchToolCallOptions": {
16587+
"allOf": [
16588+
{
16589+
"$ref": "#/components/schemas/SearchToolCallOptions"
16590+
}
16591+
],
16592+
"nullable": true
16593+
},
1658616594
"showFloatingButton": {
1658716595
"type": "boolean",
1658816596
"nullable": true
@@ -20558,6 +20566,23 @@
2055820566
"top_score"
2055920567
]
2056020568
},
20569+
"SearchToolCallOptions": {
20570+
"type": "object",
20571+
"properties": {
20572+
"noSearchRagContext": {
20573+
"type": "string",
20574+
"nullable": true
20575+
},
20576+
"toolDescription": {
20577+
"type": "string",
20578+
"nullable": true
20579+
},
20580+
"userMessageTextPrefix": {
20581+
"type": "string",
20582+
"nullable": true
20583+
}
20584+
}
20585+
},
2056120586
"SearchType": {
2056220587
"type": "string",
2056320588
"enum": [

clients/ts-sdk/src/types.gen.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3088,6 +3088,7 @@ export type PublicPageParameters = {
30883088
searchBar?: (boolean) | null;
30893089
searchOptions?: ((PublicPageSearchOptions) | null);
30903090
searchPageProps?: ((SearchPageProps) | null);
3091+
searchToolCallOptions?: ((SearchToolCallOptions) | null);
30913092
showFloatingButton?: (boolean) | null;
30923093
showFloatingInput?: (boolean) | null;
30933094
showFloatingSearchIcon?: (boolean) | null;
@@ -4032,6 +4033,12 @@ export type SearchRevenueResponse = {
40324033

40334034
export type SearchSortBy = 'created_at' | 'latency' | 'top_score';
40344035

4036+
export type SearchToolCallOptions = {
4037+
noSearchRagContext?: (string) | null;
4038+
toolDescription?: (string) | null;
4039+
userMessageTextPrefix?: (string) | null;
4040+
};
4041+
40354042
export type SearchType = 'search' | 'autocomplete' | 'search_over_groups' | 'search_within_groups';
40364043

40374044
export type SearchTypeCount = {

frontends/dashboard/src/hooks/usePublicPageSettings.tsx

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
defaultOpenGraphMetadata,
2323
defaultPriceToolCallOptions,
2424
defaultRelevanceToolCallOptions,
25+
defaultSearchToolCallOptions,
2526
} from "../pages/dataset/PublicPageSettings";
2627

2728
export type DatasetWithPublicPage = Dataset & {
@@ -93,6 +94,12 @@ export const { use: usePublicPage, provider: PublicPageProvider } =
9394
});
9495
}
9596

97+
if (!extraParams.searchToolCallOptions) {
98+
setExtraParams("searchToolCallOptions", {
99+
...defaultSearchToolCallOptions,
100+
});
101+
}
102+
96103
if (!extraParams.openGraphMetadata) {
97104
setExtraParams("openGraphMetadata", {
98105
...defaultOpenGraphMetadata,

frontends/dashboard/src/pages/dataset/PublicPageSettings.tsx

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ export const defaultRelevanceToolCallOptions: RelevanceToolCallOptions = {
7777
"Not relevant and not a good fit for the given query taking all details of both the query and the product into account",
7878
};
7979

80+
export const defaultSearchToolCallOptions: SearchToolCallOptions = {
81+
userMessageTextPrefix: "Here is the user query:",
82+
toolDescription:
83+
"Call this tool anytime it seems like we need to skip the search step. This tool tells our system that the user is asking about what they were previously shown.",
84+
};
85+
8086
export const defaultPriceToolCallOptions: PriceToolCallOptions = {
8187
toolDescription:
8288
"Only call this function if the query includes details about a price. Decide on which price filters to apply to the available catalog being used within the knowledge base to respond. If the question is slightly like a product name, respond with no filters (all false).",
@@ -649,9 +655,8 @@ const PublicPageControls = () => {
649655
/>
650656
</div>
651657
<MultiStringInput
652-
placeholder={`What is ${
653-
extraParams["brandName"] || "Trieve"
654-
}?...`}
658+
placeholder={`What is ${extraParams["brandName"] || "Trieve"
659+
}?...`}
655660
value={extraParams.defaultSearchQueries || []}
656661
onChange={(e) => {
657662
setExtraParams("defaultSearchQueries", e);
@@ -672,9 +677,8 @@ const PublicPageControls = () => {
672677
/>
673678
</div>
674679
<MultiStringInput
675-
placeholder={`What is ${
676-
extraParams["brandName"] || "Trieve"
677-
}?...`}
680+
placeholder={`What is ${extraParams["brandName"] || "Trieve"
681+
}?...`}
678682
value={extraParams.defaultAiQuestions || []}
679683
onChange={(e) => {
680684
setExtraParams("defaultAiQuestions", e);
@@ -1662,6 +1666,76 @@ const PublicPageControls = () => {
16621666
</div>
16631667
</details>
16641668

1669+
<details class="my-4">
1670+
<summary class="cursor-pointer text-sm font-medium">
1671+
Search Tool Options
1672+
</summary>
1673+
<div class="mt-4 space-y-4">
1674+
<div class="grid grid-cols-2 gap-4">
1675+
<div class="grow">
1676+
<div class="flex items-center gap-1">
1677+
<label class="block" for="">
1678+
Search Tool Description
1679+
</label>
1680+
<Tooltip
1681+
tooltipText="Description of the search tool provided to the model."
1682+
body={
1683+
<FaRegularCircleQuestion class="h-3 w-3 text-black" />
1684+
}
1685+
/>
1686+
</div>
1687+
<textarea
1688+
value={
1689+
extraParams.searchToolCallOptions?.toolDescription ||
1690+
defaultPriceToolCallOptions.toolDescription
1691+
}
1692+
onInput={(e) =>
1693+
setExtraParams(
1694+
"searchToolCallOptions",
1695+
"toolDescription",
1696+
e.currentTarget.value,
1697+
)
1698+
}
1699+
rows="4"
1700+
name="messageToQueryPrompt"
1701+
id="messageToQueryPrompt"
1702+
class="block w-full rounded-md border-[0.5px] border-neutral-300 px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6"
1703+
/>
1704+
</div>
1705+
<div class="grow">
1706+
<div class="flex items-center gap-1">
1707+
<label class="block" for="">
1708+
Search Prompt
1709+
</label>
1710+
<Tooltip
1711+
tooltipText="Prompt to the model to use the search tool."
1712+
body={
1713+
<FaRegularCircleQuestion class="h-3 w-3 text-black" />
1714+
}
1715+
/>
1716+
</div>
1717+
<textarea
1718+
value={
1719+
(extraParams.searchToolCallOptions?.userMessageTextPrefix ||
1720+
defaultSearchToolCallOptions.userMessageTextPrefix) as string
1721+
}
1722+
onInput={(e) =>
1723+
setExtraParams(
1724+
"searchToolCallOptions",
1725+
"userMessageTextPrefix",
1726+
e.currentTarget.value,
1727+
)
1728+
}
1729+
rows="4"
1730+
name="messageToQueryPrompt"
1731+
id="messageToQueryPrompt"
1732+
class="block w-full rounded-md border-[0.5px] border-neutral-300 px-3 py-1.5 shadow-sm placeholder:text-neutral-400 focus:outline-magenta-500 sm:text-sm sm:leading-6"
1733+
/>
1734+
</div>
1735+
</div>
1736+
</div>
1737+
</details>
1738+
16651739
<div class="space-x-1.5 pt-8">
16661740
<button
16671741
class="inline-flex justify-center rounded-md bg-magenta-500 px-3 py-2 text-sm font-semibold text-white shadow-sm hover:bg-magenta-700 focus-visible:outline focus-visible:outline-2 focus-visible:outline-offset-2 focus-visible:outline-magenta-900 disabled:opacity-40"

0 commit comments

Comments
 (0)