Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { useMemo, useRef } from "react";

import type { Segment } from "~/stt/live-segment";
import { getTranscriptTimingSource } from "~/stt/timing";

export function useStableSegments(segments: Segment[]): Segment[] {
const cacheRef = useRef<Map<string, Segment>>(new Map());
Expand Down Expand Up @@ -54,7 +55,8 @@ function segmentsEqual(a: Segment, b: Segment) {
aw.start_ms !== bw.start_ms ||
aw.end_ms !== bw.end_ms ||
aw.channel !== bw.channel ||
aw.is_final !== bw.is_final
aw.is_final !== bw.is_final ||
getTranscriptTimingSource(aw) !== getTranscriptTimingSource(bw)
) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
defaultRenderLabelContext,
SpeakerLabelManager,
} from "~/stt/segment/shared";
import { isTranscriptWordSeekable } from "~/stt/timing";

export function RenderTranscript({
scrollElement,
Expand Down Expand Up @@ -105,7 +106,7 @@ const SegmentsList = memo(

const seekAndPlay = useCallback(
(word: SegmentWord) => {
if (audioExists) {
if (audioExists && isTranscriptWordSeekable(word)) {
seek((offsetMs + word.start_ms) / 1000);
startPlayback();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import type { HighlightSegment } from "./utils";
import { useSearch } from "~/session/components/note-input/search/context";
import { createHighlightSegments } from "~/session/components/note-input/search/matching";
import type { SegmentWord } from "~/stt/live-segment";
import { isTranscriptWordSeekable } from "~/stt/timing";

interface WordSpanProps {
word: SegmentWord;
Expand All @@ -29,18 +30,19 @@ export function WordSpan(props: WordSpanProps) {
highlights.segments,
highlights.isActive,
);
const canSeek = props.audioExists && isTranscriptWordSeekable(props.word);
const className = useMemo(
() =>
cn([
props.audioExists && "cursor-pointer hover:bg-neutral-200/60",
canSeek && "cursor-pointer hover:bg-neutral-200/60",
!props.word.is_final && ["opacity-60", "italic"],
]),
[props.audioExists, props.word.is_final],
[canSeek, props.word.is_final],
);

return (
<span
onClick={() => props.onClickWord(props.word)}
onClick={() => canSeek && props.onClickWord(props.word)}
className={className}
data-word-id={props.word.id}
>
Expand Down
67 changes: 65 additions & 2 deletions apps/desktop/src/store/zustand/listener/batch.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ import type { BatchPersistCallback } from "./transcript";
import { transformWordEntries, type WordEntry } from "./utils";

import { type RuntimeSpeakerHint, type WordLike } from "~/stt/segment";
import {
createTranscriptTimingMetadata,
getValidTimingSource,
type TranscriptTimingSource,
} from "~/stt/timing";

export type BatchPhase = "importing" | "transcribing";
export type BatchTerminalReason = "failed" | "timed_out" | "stopped";
Expand Down Expand Up @@ -293,19 +298,26 @@ function transformBatch(
return;
}

const timingSource = getWordTimingSourceForBatchResponse(
response,
Boolean(alternative.words?.length),
"synthetic_text",
);
const wordEntries = wordEntriesFromTranscript(
alternative.words,
alternative.transcript,
{
channel: channelIndex,
durationSeconds: getBatchDurationSeconds(response),
timingSource,
},
);

const [words, hints] = transformWordEntries(
wordEntries,
alternative.transcript,
channelIndex,
{ timingSource },
);

hints.forEach((hint) => {
Expand Down Expand Up @@ -371,20 +383,27 @@ function mergeBatchPreview(
return preview;
}

const timingSource = getWordTimingSourceForBatchResponse(
response,
Boolean(alternative.words?.length),
"provider_segment_interpolated",
);
const wordEntries = wordEntriesFromTranscript(
alternative.words,
alternative.transcript,
{
channel: channelIndex,
startSeconds: response.start,
durationSeconds: response.duration,
timingSource,
},
);

const [incomingWords, incomingHints] = transformWordEntries(
wordEntries,
alternative.transcript,
channelIndex,
{ timingSource },
);
if (incomingWords.length === 0) {
return preview;
Expand Down Expand Up @@ -472,14 +491,23 @@ function wordEntriesFromTranscript(
channel,
startSeconds = 0,
durationSeconds,
timingSource,
}: {
channel: number;
startSeconds?: number;
durationSeconds?: number;
timingSource: TranscriptTimingSource;
},
): WordEntry[] {
if (entries?.length || !transcript.trim()) {
return entries ?? [];
if (entries?.length) {
return entries.map((entry) => ({
...entry,
metadata: createTranscriptTimingMetadata(timingSource, entry.metadata),
}));
}

if (!transcript.trim()) {
return [];
}

const tokens = transcript.trim().split(/\s+/).filter(Boolean);
Expand All @@ -501,9 +529,44 @@ function wordEntriesFromTranscript(
end: startSeconds + ((index + 1) / tokens.length) * duration,
channel,
speaker: null,
metadata: createTranscriptTimingMetadata(timingSource),
}));
}

/**
 * Decides which timing source label applies to the words of a batch response.
 *
 * @param response - Batch response whose `metadata` may declare a timing source.
 * @param hasProviderWords - Whether the provider supplied its own word list.
 * @param fallbackWithoutWords - Source to report when no provider words exist;
 *   response metadata is not consulted in that case.
 * @returns `fallbackWithoutWords` when there are no provider words; otherwise
 *   the source declared in the response metadata, or `"provider_word"` when
 *   the metadata declares none.
 */
function getWordTimingSourceForBatchResponse(
  response: { metadata?: unknown },
  hasProviderWords: boolean,
  fallbackWithoutWords: TranscriptTimingSource,
): TranscriptTimingSource {
  // Any falsy lookup result falls back to the "provider_word" default.
  return hasProviderWords
    ? getBatchResponseTimingSource(response) || "provider_word"
    : fallbackWithoutWords;
}
Comment thread
cursor[bot] marked this conversation as resolved.

/**
 * Reads an explicitly declared timing source from a batch response's metadata.
 *
 * Two layouts are recognised:
 * - nested: `metadata.timing.source`
 * - flat: `metadata.timing_source`
 *
 * The nested layout takes precedence. The flat key is consulted when the
 * nested object is absent or does not yield a valid source, so that a
 * malformed `timing` object cannot mask a valid `timing_source` value.
 *
 * @param response - Object whose `metadata` field is inspected; any
 *   non-object (or array) metadata is treated as "nothing declared".
 * @returns The validated timing source, or `undefined` when none is declared.
 */
function getBatchResponseTimingSource(response: {
  metadata?: unknown;
}): TranscriptTimingSource | undefined {
  const metadata = response.metadata;
  if (!metadata || typeof metadata !== "object" || Array.isArray(metadata)) {
    return undefined;
  }

  const record = metadata as Record<string, unknown>;
  const timing = record.timing;
  if (timing && typeof timing === "object" && !Array.isArray(timing)) {
    // NOTE(review): assumes getValidTimingSource yields a falsy value for
    // unrecognised input — confirm against ~/stt/timing.
    const nestedSource = getValidTimingSource(
      (timing as Record<string, unknown>).source,
    );
    if (nestedSource) {
      return nestedSource;
    }
  }

  return getValidTimingSource(record.timing_source);
}

function getBatchDurationSeconds(response: BatchResponse): number | undefined {
const metadata = response.metadata;
if (!metadata || typeof metadata !== "object" || Array.isArray(metadata)) {
Expand Down
32 changes: 31 additions & 1 deletion apps/desktop/src/store/zustand/listener/general.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ describe("General Listener Slice", () => {
start_ms: 0,
end_ms: 500,
channel: 0,
metadata: {
timing: {
source: "provider_word",
},
},
},
]);

Expand Down Expand Up @@ -201,6 +206,11 @@ describe("General Listener Slice", () => {
start_ms: 0,
end_ms: 500,
channel: 0,
metadata: {
timing: {
source: "provider_word",
},
},
},
],
[
Expand All @@ -227,7 +237,7 @@ describe("General Listener Slice", () => {

expect(
handleBatchResponse(sessionId, {
metadata: { duration: 2 },
metadata: { duration: 2, timing_source: "provider_word" },
results: {
channels: [
{
Expand All @@ -251,12 +261,22 @@ describe("General Listener Slice", () => {
start_ms: 0,
end_ms: 1000,
channel: 0,
metadata: {
timing: {
source: "synthetic_text",
},
},
},
{
text: " world",
start_ms: 1000,
end_ms: 2000,
channel: 0,
metadata: {
timing: {
source: "synthetic_text",
},
},
},
],
[],
Expand Down Expand Up @@ -312,12 +332,22 @@ describe("General Listener Slice", () => {
start_ms: 4000,
end_ms: 5000,
channel: 1,
metadata: {
timing: {
source: "provider_segment_interpolated",
},
},
},
{
text: " world",
start_ms: 5000,
end_ms: 6000,
channel: 1,
metadata: {
timing: {
source: "provider_segment_interpolated",
},
},
},
],
[],
Expand Down
13 changes: 13 additions & 0 deletions apps/desktop/src/store/zustand/listener/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import type { RuntimeSpeakerHint, WordLike } from "~/stt/segment";
import {
createTranscriptTimingMetadata,
type TranscriptTimingSource,
type TranscriptWordMetadata,
} from "~/stt/timing";

export function fixSpacingForWords(
words: string[],
Expand Down Expand Up @@ -36,12 +41,16 @@ export type WordEntry = {
end: number;
channel?: number;
speaker?: number | null;
metadata?: TranscriptWordMetadata | null;
};

export function transformWordEntries(
wordEntries: WordEntry[] | null | undefined,
transcript: string,
channel: number,
options: {
timingSource?: TranscriptTimingSource;
} = {},
): [WordLike[], RuntimeSpeakerHint[]] {
const words: WordLike[] = [];
const hints: RuntimeSpeakerHint[] = [];
Expand All @@ -61,6 +70,10 @@ export function transformWordEntries(
start_ms: Math.round(word.start * 1000),
end_ms: Math.round(word.end * 1000),
channel: typeof word.channel === "number" ? word.channel : channel,
metadata: createTranscriptTimingMetadata(
options.timingSource ?? "provider_word",
word.metadata,
),
});

if (typeof word.speaker === "number") {
Expand Down
17 changes: 15 additions & 2 deletions apps/desktop/src/stt/live-segment.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import type {
SegmentWord as BoundSegmentWord,
} from "@hypr/plugin-transcription";

import type { TranscriptWordMetadata } from "~/stt/timing";

export enum ChannelProfile {
DirectMic = 0,
RemoteParty = 1,
Expand All @@ -17,6 +19,7 @@ export type WordLike = {
start_ms: number;
end_ms: number;
channel: ChannelProfile;
metadata?: TranscriptWordMetadata | null;
};

export type PartialWord = WordLike;
Expand All @@ -41,8 +44,18 @@ export type RenderLabelContext = {
};

export type SegmentKey = BoundSegmentKey;
export type SegmentWord = BoundSegmentWord;
export type Segment = LiveTranscriptSegment | RenderedTranscriptSegment;
/**
 * A segment word as bound from the transcription plugin, extended with an
 * optional (nullable) `metadata` field of type `TranscriptWordMetadata`.
 */
export type SegmentWord = BoundSegmentWord & {
metadata?: TranscriptWordMetadata | null;
};
/**
 * Re-types the `words` property of a segment type `T` so that its elements
 * are `SegmentWord` (which may carry metadata) instead of the plugin-bound
 * word type. All other properties of `T` are kept as-is.
 */
type SegmentWithWordMetadata<T extends { words: BoundSegmentWord[] }> = Omit<
T,
"words"
> & {
words: SegmentWord[];
};
/**
 * A transcript segment, either live or rendered, whose `words` are typed as
 * `SegmentWord` so each word may carry optional metadata.
 */
export type Segment =
| SegmentWithWordMetadata<LiveTranscriptSegment>
| SegmentWithWordMetadata<RenderedTranscriptSegment>;
export type SegmentChannelProfile = BoundChannelProfile;

export class SpeakerLabelManager {
Expand Down
Loading
Loading