Skip to content

Commit 07442bf

Browse files
vid277cdxker
authored andcommitted
feature: use transcribe audio route in search component
1 parent 7eede7c commit 07442bf

File tree

10 files changed

+278
-125
lines changed

10 files changed

+278
-125
lines changed

clients/search-component/src/TrieveModal/Chat/UserMessage.tsx

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ export const UserMessage = ({
1212
message: Message;
1313
idx: number;
1414
}) => {
15-
const { props } = useModalState();
15+
const { props, transcribedQuery } = useModalState();
1616

1717
return (
1818
<motion.div
@@ -31,15 +31,18 @@ export const UserMessage = ({
3131
{message.imageUrl && (
3232
<ImagePreview isUploading={false} imageUrl={message.imageUrl} />
3333
)}
34-
{message.text === "Loading..." ? (
34+
{message.text === "Loading..." && !transcribedQuery ? (
3535
<span className={`user-text ${props.type}`}>
3636
<LoadingIcon className="loading" />
3737
</span>
3838
) : null}
39-
{message.text != "" &&
40-
message.text != "Loading..." &&
41-
message.text != props.defaultImageQuestion ? (
42-
<span className={`user-text ${props.type}`}> {message.text}</span>
39+
{(message.text !== "" &&
40+
message.text !== "Loading..." &&
41+
message.text !== props.defaultImageQuestion) ||
42+
transcribedQuery ? (
43+
<span className={`user-text ${props.type}`}>
44+
{transcribedQuery || message.text}
45+
</span>
4346
) : null}
4447
</div>
4548
</div>

clients/search-component/src/TrieveModal/Search/UploadAudio.tsx

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,16 @@ import { StopSquareIcon, MicIcon } from "../icons";
55
import { motion } from "motion/react";
66

77
export const UploadAudio = () => {
8-
const { props, mode, setAudioBase64, isRecording, setIsRecording } =
9-
useModalState();
8+
const {
9+
props,
10+
mode,
11+
setAudioBase64,
12+
isRecording,
13+
setIsRecording,
14+
trieveSDK,
15+
setTranscribedQuery,
16+
setQuery,
17+
} = useModalState();
1018

1119
const [mediaRecorder, setMediaRecorder] = useState<MediaRecorder | null>(
1220
null,
@@ -31,15 +39,25 @@ export const UploadAudio = () => {
3139

3240
recorder.onstop = () => {
3341
stream.getTracks().forEach((track) => track.stop());
34-
setAudioBase64("");
3542
const audioBlob = new Blob(audioChunks, {
3643
type: isFirefox ? "audio/webm" : "audio/mp4",
3744
});
3845
const reader = new FileReader();
3946
reader.readAsDataURL(audioBlob);
40-
reader.onloadend = () => {
47+
reader.onloadend = async () => {
4148
let base64data = reader.result as string;
4249
base64data = base64data?.split(",")[1];
50+
51+
const transcribedAudio = await trieveSDK.transcribeAudio(
52+
{
53+
base64_audio: base64data,
54+
},
55+
new AbortController().signal,
56+
);
57+
58+
setTranscribedQuery(transcribedAudio);
59+
60+
setQuery(transcribedAudio);
4361
setAudioBase64(base64data);
4462
};
4563
};

clients/search-component/src/utils/hooks/chat-context.tsx

Lines changed: 84 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
122122
currentGroup,
123123
props,
124124
abTreatment,
125+
transcribedQuery,
126+
setTranscribedQuery,
125127
} = useModalState();
126128
const [currentQuestion, setCurrentQuestion] = useState(query);
127129
const [currentTopic, setCurrentTopic] = useState("");
@@ -402,9 +404,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
402404
let curAudioBase64 = audioBase64;
403405
let questionProp = question;
404406
const curGroup = group || currentGroup;
405-
let transcribedQuery: string | null = null;
406407

407-
// This only works w/ shopify rn
408408
const recommendOptions = props.recommendOptions;
409409
if (
410410
recommendOptions &&
@@ -428,7 +428,6 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
428428
}
429429
}
430430

431-
// Use group search
432431
let filters: ChunkFilter | null = {
433432
must: null,
434433
must_not: null,
@@ -543,7 +542,6 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
543542
return await trieveSDK.getToolCallFunctionParams({
544543
user_message_text: questionProp || currentQuestion,
545544
image_url: localImageUrl ? localImageUrl : null,
546-
audio_input: curAudioBase64 ? curAudioBase64 : null,
547545
tool_function: {
548546
name: "get_price_filters",
549547
description:
@@ -603,7 +601,6 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
603601
},
604602
)} \n\n${props.searchToolCallOptions?.userMessageTextPrefix ?? defaultSearchToolCallOptions.userMessageTextPrefix}: ${questionProp || currentQuestion}.`,
605603
image_url: localImageUrl ? localImageUrl : null,
606-
audio_input: curAudioBase64 ? curAudioBase64 : null,
607604
tool_function: {
608605
name: "skip_search",
609606
description:
@@ -620,7 +617,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
620617
},
621618
});
622619
}
623-
})
620+
});
624621

625622
const imageFiltersPromise = retryOperation(async () => {
626623
if (localImageUrl) {
@@ -659,16 +656,15 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
659656
user_message_text:
660657
questionProp || currentQuestion
661658
? `Get filters from the following messages: ${messages
662-
.slice(0, -1)
663-
.filter((message) => {
664-
return message.type == "user";
665-
})
666-
.map(
667-
(message) => `\n\n${message.text}`,
668-
)} \n\n ${questionProp || currentQuestion}`
659+
.slice(0, -1)
660+
.filter((message) => {
661+
return message.type == "user";
662+
})
663+
.map(
664+
(message) => `\n\n${message.text}`,
665+
)} \n\n ${questionProp || currentQuestion}`
669666
: null,
670667
image_url: localImageUrl ? localImageUrl : null,
671-
audio_input: curAudioBase64 ? curAudioBase64 : null,
672668
tool_function: {
673669
name: "get_filters",
674670
description:
@@ -684,11 +680,6 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
684680
},
685681
},
686682
chatMessageAbortController.current.signal,
687-
(headers: Record<string, string>) => {
688-
if (headers["x-tr-query"] && curAudioBase64) {
689-
transcribedQuery = headers["x-tr-query"];
690-
}
691-
},
692683
);
693684
} else {
694685
return {
@@ -697,13 +688,17 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
697688
}
698689
});
699690

700-
const [priceFiltersResp, imageFiltersResp, tagFiltersResp, skipSearchResp] =
701-
await Promise.all([
702-
priceFiltersPromise,
703-
imageFiltersPromise,
704-
tagFiltersPromise,
705-
skipSearchPromise,
706-
]);
691+
const [
692+
priceFiltersResp,
693+
imageFiltersResp,
694+
tagFiltersResp,
695+
skipSearchResp,
696+
] = await Promise.all([
697+
priceFiltersPromise,
698+
imageFiltersPromise,
699+
tagFiltersPromise,
700+
skipSearchPromise,
701+
]);
707702

708703
if (transcribedQuery && curAudioBase64) {
709704
questionProp = transcribedQuery;
@@ -918,13 +913,13 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
918913
: null;
919914
const imageUrls = props.relevanceToolCallOptions?.includeImages
920915
? (
921-
(firstChunk?.image_urls?.filter(
922-
(stringOrNull): stringOrNull is string =>
923-
Boolean(stringOrNull),
924-
) ||
925-
[]) ??
926-
[]
927-
).splice(0, 1)
916+
(firstChunk?.image_urls?.filter(
917+
(stringOrNull): stringOrNull is string =>
918+
Boolean(stringOrNull),
919+
) ||
920+
[]) ??
921+
[]
922+
).splice(0, 1)
928923
: undefined;
929924
const jsonOfFirstChunk = {
930925
title: (firstChunk?.metadata as any)?.title ?? "",
@@ -1116,47 +1111,47 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
11161111

11171112
if (referenceImageUrls.length > 0 || curGroup) {
11181113
if (referenceImageUrls.length == 0 && curGroup) {
1119-
const fulltextSearchPromise = trieveSDK.searchInGroup(
1120-
{
1121-
query: questionProp || currentQuestion,
1122-
search_type: "fulltext",
1123-
page_size: 10,
1124-
group_id: curGroup.id,
1125-
user_id: fingerprint,
1126-
},
1127-
searchAbortController.current.signal,
1128-
);
1114+
const fulltextSearchPromise = trieveSDK.searchInGroup(
1115+
{
1116+
query: questionProp || currentQuestion,
1117+
search_type: "fulltext",
1118+
page_size: 10,
1119+
group_id: curGroup.id,
1120+
user_id: fingerprint,
1121+
},
1122+
searchAbortController.current.signal,
1123+
);
11291124

1130-
const chunksInGroupPromise = trieveSDK.getChunksInGroup(
1131-
{
1132-
groupId: curGroup.id,
1133-
page: 1,
1134-
},
1135-
searchAbortController.current.signal,
1136-
);
1125+
const chunksInGroupPromise = trieveSDK.getChunksInGroup(
1126+
{
1127+
groupId: curGroup.id,
1128+
page: 1,
1129+
},
1130+
searchAbortController.current.signal,
1131+
);
11371132

1138-
const [fulltextSearchResp, chunksInGroupResp] = await Promise.all([
1139-
fulltextSearchPromise,
1140-
chunksInGroupPromise,
1141-
]);
1133+
const [fulltextSearchResp, chunksInGroupResp] = await Promise.all([
1134+
fulltextSearchPromise,
1135+
chunksInGroupPromise,
1136+
]);
11421137

1143-
const chunkIds = fulltextSearchResp.chunks.map(
1144-
(score_chunk) => score_chunk.chunk.id,
1145-
);
1146-
1147-
chunksInGroupResp.chunks.filter((chunk) => chunkIds.includes(chunk.id));
1148-
1149-
const topChunk = chunksInGroupResp.chunks[0];
1150-
1151-
if (topChunk) {
1152-
topChunk.image_urls?.forEach((url) => {
1153-
if (url) {
1154-
referenceImageUrls.push(url);
1155-
}
1156-
});
1157-
}
1138+
const chunkIds = fulltextSearchResp.chunks.map(
1139+
(score_chunk) => score_chunk.chunk.id,
1140+
);
1141+
1142+
chunksInGroupResp.chunks.filter((chunk) => chunkIds.includes(chunk.id));
11581143

1159-
referenceImageUrls.slice(0, 3);
1144+
const topChunk = chunksInGroupResp.chunks[0];
1145+
1146+
if (topChunk) {
1147+
topChunk.image_urls?.forEach((url) => {
1148+
if (url) {
1149+
referenceImageUrls.push(url);
1150+
}
1151+
});
1152+
}
1153+
1154+
referenceImageUrls.slice(0, 3);
11601155
}
11611156

11621157
if (await handleImageEdit()) {
@@ -1200,23 +1195,23 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
12001195
if (skipSearch) {
12011196
createMessageFilters = props.useGroupSearch
12021197
? {
1203-
must: [
1204-
{
1205-
field: "group_ids",
1206-
match_any: groupIdsInChat,
1207-
},
1208-
],
1209-
}
1198+
must: [
1199+
{
1200+
field: "group_ids",
1201+
match_any: groupIdsInChat,
1202+
},
1203+
],
1204+
}
12101205
: {
1211-
must: [
1212-
{
1213-
field: "ids",
1214-
match_any: messages
1215-
.flatMap((m) => m.additional ?? [])
1216-
.map((chunk) => chunk.id),
1217-
},
1218-
],
1219-
};
1206+
must: [
1207+
{
1208+
field: "ids",
1209+
match_any: messages
1210+
.flatMap((m) => m.additional ?? [])
1211+
.map((chunk) => chunk.id),
1212+
},
1213+
],
1214+
};
12201215
}
12211216
const systemPromptToUse =
12221217
props.systemPrompt && props.systemPrompt !== ""
@@ -1229,10 +1224,6 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
12291224
{
12301225
topic_id: id || currentTopic,
12311226
new_message_content: questionProp || currentQuestion,
1232-
audio_input:
1233-
curAudioBase64 && curAudioBase64?.length > 0
1234-
? curAudioBase64
1235-
: undefined,
12361227
image_urls: imageUrl ? [imageUrl] : [],
12371228
llm_options: {
12381229
completion_first: false,
@@ -1258,11 +1249,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
12581249
only_include_docs_used: false,
12591250
},
12601251
chatMessageAbortController.current.signal,
1261-
(headers: Record<string, string>) => {
1262-
if (headers["x-tr-query"] && curAudioBase64) {
1263-
transcribedQuery = headers["x-tr-query"];
1264-
}
1265-
},
1252+
undefined,
12661253
props.overrideFetch ?? false,
12671254
);
12681255
reader = createMessageResp.reader;
@@ -1309,6 +1296,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
13091296
}
13101297
if (audioBase64) {
13111298
setAudioBase64("");
1299+
setTranscribedQuery("");
13121300
}
13131301
};
13141302

@@ -1364,7 +1352,7 @@ function ChatProvider({ children }: { children: React.ReactNode }) {
13641352
setImageUrl(imageUrl);
13651353
}
13661354

1367-
const questionProp = question;
1355+
const questionProp = transcribedQuery || question;
13681356
setIsDoneReading(false);
13691357
setCurrentQuestion("");
13701358

0 commit comments

Comments
 (0)