Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 83 additions & 6 deletions apps/desktop/src/store/zustand/listener/batch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import type {
} from "@hypr/plugin-transcription";

import type { BatchPersistCallback } from "./transcript";
import { transformWordEntries } from "./utils";
import { transformWordEntries, type WordEntry } from "./utils";

import { type RuntimeSpeakerHint, type WordLike } from "~/stt/segment";

Expand Down Expand Up @@ -39,7 +39,7 @@ export type BatchState = {
export type BatchActions = {
handleBatchStarted: (sessionId: string, phase?: BatchPhase) => void;
handleBatchCompleted: (sessionId: string) => void;
handleBatchResponse: (sessionId: string, response: BatchResponse) => void;
handleBatchResponse: (sessionId: string, response: BatchResponse) => boolean;
handleBatchResponseStreamed: (
sessionId: string,
event: BatchStreamEvent,
Expand All @@ -57,6 +57,9 @@ export type BatchActions = {
clearBatchPersist: (sessionId: string) => void;
};

/** Error message used to report a batch transcription that yielded no words. */
export const EMPTY_BATCH_TRANSCRIPT_ERROR =
"No speech was detected in the audio.";

export const createBatchSlice = <T extends BatchState>(
set: StoreApi<T>["setState"],
get: StoreApi<T>["getState"],
Expand Down Expand Up @@ -112,7 +115,7 @@ export const createBatchSlice = <T extends BatchState>(

const [words, hints] = transformBatch(response);
if (!words.length) {
return;
return false;
}

persist?.(words, hints, { mode: "replace" });
Expand All @@ -130,6 +133,8 @@ export const createBatchSlice = <T extends BatchState>(
batchPreview: restPreview,
};
});

return true;
},

handleBatchResponseStreamed: (sessionId, event) => {
Expand Down Expand Up @@ -284,13 +289,22 @@ function transformBatch(

response.results.channels.forEach((channel, channelIndex) => {
const alternative = channel.alternatives[0];
if (!alternative || !alternative.words || !alternative.words.length) {
if (!alternative) {
return;
}

const [words, hints] = transformWordEntries(
const wordEntries = wordEntriesFromTranscript(
alternative.words,
alternative.transcript,
{
channel: channelIndex,
durationSeconds: getBatchDurationSeconds(response),
},
);

const [words, hints] = transformWordEntries(
wordEntries,
alternative.transcript,
channelIndex,
);

Expand Down Expand Up @@ -357,9 +371,19 @@ function mergeBatchPreview(
return preview;
}

const [incomingWords, incomingHints] = transformWordEntries(
const wordEntries = wordEntriesFromTranscript(
alternative.words,
alternative.transcript,
{
channel: channelIndex,
startSeconds: response.start,
durationSeconds: response.duration,
},
);

const [incomingWords, incomingHints] = transformWordEntries(
wordEntries,
alternative.transcript,
channelIndex,
);
if (incomingWords.length === 0) {
Expand Down Expand Up @@ -440,3 +464,56 @@ function getBatchStreamPercentage(event: BatchStreamEvent): number {
return 0;
}
}

/**
 * Returns word-level entries for a transcript, synthesizing them when only
 * plain text is available.
 *
 * When `entries` is non-empty it is returned untouched. When it is empty or
 * missing and `transcript` has non-whitespace text, the text is split on
 * whitespace and each token is given an equal slice of the time window.
 *
 * @param entries - Word entries supplied with the transcript, if any.
 * @param transcript - Full transcript text used as the fallback source.
 * @param options - `channel` is stamped on every synthesized entry.
 *   `startSeconds` offsets the timestamps (defaults to 0; non-finite values
 *   are treated as 0). `durationSeconds` is the window to spread tokens over;
 *   missing, non-finite, or non-positive values fall back to an estimate of
 *   0.4 seconds per token.
 * @returns The original entries, or synthesized entries with evenly spaced
 *   `start`/`end` times and `speaker` set to `null`.
 */
function wordEntriesFromTranscript(
  entries: WordEntry[] | null | undefined,
  transcript: string,
  {
    channel,
    startSeconds = 0,
    durationSeconds,
  }: {
    channel: number;
    startSeconds?: number;
    durationSeconds?: number;
  },
): WordEntry[] {
  const text = transcript.trim();
  if (entries?.length || !text) {
    return entries ?? [];
  }

  // `text` is trimmed and non-empty, so splitting on whitespace runs can
  // never produce an empty token.
  const tokens = text.split(/\s+/);
  const count = tokens.length;

  // Only a finite, positive duration is trusted; anything else (including a
  // negative value) uses the per-token estimate instead of silently
  // collapsing to the minimum floor.
  const duration = Math.max(
    typeof durationSeconds === "number" &&
      Number.isFinite(durationSeconds) &&
      durationSeconds > 0
      ? durationSeconds
      : count * 0.4,
    // Floor of 0.05 seconds per token.
    count * 0.05,
  );

  // A non-finite offset would turn every timestamp into NaN/Infinity.
  const offset = Number.isFinite(startSeconds) ? startSeconds : 0;

  return tokens.map((token, index) => ({
    word: token,
    punctuated_word: token,
    start: offset + (index / count) * duration,
    end: offset + ((index + 1) / count) * duration,
    channel,
    speaker: null,
  }));
}

/**
 * Reads the audio duration, in seconds, from a batch response's metadata.
 *
 * @param response - Batch response whose `metadata` is inspected.
 * @returns `metadata.duration` when metadata is a non-array object and the
 *   value is a finite number greater than zero; otherwise `undefined`.
 */
function getBatchDurationSeconds(response: BatchResponse): number | undefined {
  const { metadata } = response;

  // Metadata must be a plain (non-array) object to carry a `duration` field.
  const isObjectLike =
    Boolean(metadata) && typeof metadata === "object" && !Array.isArray(metadata);
  if (!isObjectLike) {
    return undefined;
  }

  const candidate = (metadata as Record<string, unknown>).duration;
  if (typeof candidate !== "number") {
    return undefined;
  }
  if (!Number.isFinite(candidate) || candidate <= 0) {
    return undefined;
  }

  return candidate;
}
86 changes: 86 additions & 0 deletions apps/desktop/src/store/zustand/listener/general-batch.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { beforeEach, describe, expect, test, vi } from "vitest";

import { EMPTY_BATCH_TRANSCRIPT_ERROR } from "./batch";
import { runBatchSession } from "./general-batch";

const { listenMock, startTranscriptionMock } = vi.hoisted(() => ({
Expand Down Expand Up @@ -213,6 +214,91 @@ describe("runBatchSession", () => {
expect(handleBatchFailed).not.toHaveBeenCalled();
});

// Verifies that runBatchSession rejects with EMPTY_BATCH_TRANSCRIPT_ERROR when
// the store's handleBatchResponse returns false for a "completed" event.
test("rejects completed responses that have no transcribed words", async () => {
const handleBatchStarted = vi.fn();
// Stubbed to return false, the value the session treats as "no words handled".
const handleBatchResponse = vi.fn(() => false);
const handleBatchCompleted = vi.fn();
const clearBatchPersist = vi.fn();
const clearBatchSession = vi.fn();
const handleBatchResponseStreamed = vi.fn();
const handleBatchFailed = vi.fn();
const handleBatchStopped = vi.fn();
const updateBatchProgress = vi.fn();
const setBatchPersist = vi.fn();

// Holds the event callback handed to listenMock so the test can fire events itself.
let handler:
| ((event: {
payload: {
type: string;
session_id: string;
response?: unknown;
mode?: "direct" | "streamed";
};
}) => void)
| undefined;

// Capture the listener; the returned vi.fn() is presumably the unlisten handle.
listenMock.mockImplementation(async (cb) => {
handler = cb;
return vi.fn();
});

// When transcription starts, queue a "completed" event whose response has no channels.
startTranscriptionMock.mockImplementation(async () => {
queueMicrotask(() => {
handler?.({
payload: {
type: "completed",
session_id: "session-1",
mode: "direct",
response: {
metadata: null,
results: { channels: [] },
},
},
});
});

return {
status: "ok",
data: null,
};
});

// The session promise must reject with the empty-transcript message.
await expect(
runBatchSession(
() => ({
batch: {},
batchPreview: {},
batchPersist: {},
handleBatchStarted,
handleBatchResponse,
handleBatchCompleted,
clearBatchPersist,
clearBatchSession,
handleBatchResponseStreamed,
handleBatchFailed,
handleBatchStopped,
updateBatchProgress,
setBatchPersist,
}),
"session-1",
{
session_id: "session-1",
provider: "hyprnote",
file_path: "/tmp/session.wav",
base_url: "",
api_key: "",
},
),
).rejects.toThrow(EMPTY_BATCH_TRANSCRIPT_ERROR);

// The failure is recorded on the store and the persist callback is cleared,
// while clearBatchSession is expected not to be called.
expect(handleBatchFailed).toHaveBeenCalledWith(
"session-1",
EMPTY_BATCH_TRANSCRIPT_ERROR,
);
expect(clearBatchPersist).toHaveBeenCalledWith("session-1");
expect(clearBatchSession).not.toHaveBeenCalled();
});

test("rejects when the transcription is stopped", async () => {
const handleBatchStarted = vi.fn();
const handleBatchResponse = vi.fn();
Expand Down
16 changes: 13 additions & 3 deletions apps/desktop/src/store/zustand/listener/general-batch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@ import {
events as transcriptionEvents,
} from "@hypr/plugin-transcription";

import type { BatchActions, BatchState } from "./batch";
import {
EMPTY_BATCH_TRANSCRIPT_ERROR,
type BatchActions,
type BatchState,
} from "./batch";

type BatchStore = BatchActions & BatchState;

Expand Down Expand Up @@ -39,6 +43,7 @@ export const runBatchSession = async <T extends BatchStore>(
response: Parameters<BatchStore["handleBatchResponse"]>[1];
},
resolve: () => void,
reject: (reason?: unknown) => void,
) => {
if (settled) {
return;
Expand All @@ -47,15 +52,19 @@ export const runBatchSession = async <T extends BatchStore>(
settled = true;

try {
get().handleBatchResponse(sessionId, output.response);
const handled = get().handleBatchResponse(sessionId, output.response);
if (handled === false) {
throw new Error(EMPTY_BATCH_TRANSCRIPT_ERROR);
}
cleanup();
} catch (error) {
console.error("[runBatch] error handling batch response", error);
const errorMessage =
error instanceof Error ? error.message : String(error);
get().handleBatchFailed(sessionId, errorMessage);
cleanup(false);
throw error;
reject(error);
return;
}

resolve();
Expand Down Expand Up @@ -121,6 +130,7 @@ export const runBatchSession = async <T extends BatchStore>(
response: payload.response,
},
resolve,
reject,
);
return;
}
Expand Down
Loading
Loading