diff --git a/apps/desktop/src/session/components/note-input/transcript/renderer/segment-hooks.ts b/apps/desktop/src/session/components/note-input/transcript/renderer/segment-hooks.ts index 589c1722c1..cf18dd7ae3 100644 --- a/apps/desktop/src/session/components/note-input/transcript/renderer/segment-hooks.ts +++ b/apps/desktop/src/session/components/note-input/transcript/renderer/segment-hooks.ts @@ -1,6 +1,7 @@ import { useMemo, useRef } from "react"; import type { Segment } from "~/stt/live-segment"; +import { getTranscriptTimingSource } from "~/stt/timing"; export function useStableSegments(segments: Segment[]): Segment[] { const cacheRef = useRef>(new Map()); @@ -54,7 +55,8 @@ function segmentsEqual(a: Segment, b: Segment) { aw.start_ms !== bw.start_ms || aw.end_ms !== bw.end_ms || aw.channel !== bw.channel || - aw.is_final !== bw.is_final + aw.is_final !== bw.is_final || + getTranscriptTimingSource(aw) !== getTranscriptTimingSource(bw) ) { return false; } diff --git a/apps/desktop/src/session/components/note-input/transcript/renderer/transcript.tsx b/apps/desktop/src/session/components/note-input/transcript/renderer/transcript.tsx index 0a01274c90..162f1ccdd3 100644 --- a/apps/desktop/src/session/components/note-input/transcript/renderer/transcript.tsx +++ b/apps/desktop/src/session/components/note-input/transcript/renderer/transcript.tsx @@ -23,6 +23,7 @@ import { defaultRenderLabelContext, SpeakerLabelManager, } from "~/stt/segment/shared"; +import { isTranscriptWordSeekable } from "~/stt/timing"; export function RenderTranscript({ scrollElement, @@ -105,7 +106,7 @@ const SegmentsList = memo( const seekAndPlay = useCallback( (word: SegmentWord) => { - if (audioExists) { + if (audioExists && isTranscriptWordSeekable(word)) { seek((offsetMs + word.start_ms) / 1000); startPlayback(); } diff --git a/apps/desktop/src/session/components/note-input/transcript/renderer/word-span.tsx b/apps/desktop/src/session/components/note-input/transcript/renderer/word-span.tsx index 7e2491b433..49b18ddbf3 100644 --- a/apps/desktop/src/session/components/note-input/transcript/renderer/word-span.tsx +++ b/apps/desktop/src/session/components/note-input/transcript/renderer/word-span.tsx @@ -7,6 +7,7 @@ import type { HighlightSegment } from "./utils"; import { useSearch } from "~/session/components/note-input/search/context"; import { createHighlightSegments } from "~/session/components/note-input/search/matching"; import type { SegmentWord } from "~/stt/live-segment"; +import { isTranscriptWordSeekable } from "~/stt/timing"; interface WordSpanProps { word: SegmentWord; @@ -29,18 +30,19 @@ export function WordSpan(props: WordSpanProps) { highlights.segments, highlights.isActive, ); + const canSeek = props.audioExists && isTranscriptWordSeekable(props.word); const className = useMemo( () => cn([ - props.audioExists && "cursor-pointer hover:bg-neutral-200/60", + canSeek && "cursor-pointer hover:bg-neutral-200/60", !props.word.is_final && ["opacity-60", "italic"], ]), - [props.audioExists, props.word.is_final], + [canSeek, props.word.is_final], ); return ( props.onClickWord(props.word)} + onClick={() => canSeek && props.onClickWord(props.word)} className={className} data-word-id={props.word.id} > diff --git a/apps/desktop/src/store/zustand/listener/batch.ts b/apps/desktop/src/store/zustand/listener/batch.ts index c417b8ffde..87dc56654c 100644 --- a/apps/desktop/src/store/zustand/listener/batch.ts +++ b/apps/desktop/src/store/zustand/listener/batch.ts @@ -10,6 +10,11 @@ import type { BatchPersistCallback } from "./transcript"; import { transformWordEntries, type WordEntry } from "./utils"; import { type RuntimeSpeakerHint, type WordLike } from "~/stt/segment"; +import { + createTranscriptTimingMetadata, + getValidTimingSource, + type TranscriptTimingSource, +} from "~/stt/timing"; export type BatchPhase = "importing" | "transcribing"; export type BatchTerminalReason = "failed" | "timed_out" | "stopped"; @@ -293,12 +298,18 @@ function transformBatch( return; } + const timingSource = getWordTimingSourceForBatchResponse( + response, + Boolean(alternative.words?.length), + "synthetic_text", + ); const wordEntries = wordEntriesFromTranscript( alternative.words, alternative.transcript, { channel: channelIndex, durationSeconds: getBatchDurationSeconds(response), + timingSource, }, ); @@ -306,6 +317,7 @@ function transformBatch( wordEntries, alternative.transcript, channelIndex, + { timingSource }, ); hints.forEach((hint) => { @@ -371,6 +383,11 @@ function mergeBatchPreview( return preview; } + const timingSource = getWordTimingSourceForBatchResponse( + response, + Boolean(alternative.words?.length), + "provider_segment_interpolated", + ); const wordEntries = wordEntriesFromTranscript( alternative.words, alternative.transcript, @@ -378,6 +395,7 @@ function mergeBatchPreview( channel: channelIndex, startSeconds: response.start, durationSeconds: response.duration, + timingSource, }, ); @@ -385,6 +403,7 @@ function mergeBatchPreview( wordEntries, alternative.transcript, channelIndex, + { timingSource }, ); if (incomingWords.length === 0) { return preview; @@ -472,14 +491,23 @@ function wordEntriesFromTranscript( channel, startSeconds = 0, durationSeconds, + timingSource, }: { channel: number; startSeconds?: number; durationSeconds?: number; + timingSource: TranscriptTimingSource; }, ): WordEntry[] { - if (entries?.length || !transcript.trim()) { - return entries ?? []; + if (entries?.length) { + return entries.map((entry) => ({ + ...entry, + metadata: createTranscriptTimingMetadata(timingSource, entry.metadata), + })); + } + + if (!transcript.trim()) { + return []; } const tokens = transcript.trim().split(/\s+/).filter(Boolean); @@ -501,9 +529,44 @@ function wordEntriesFromTranscript( end: startSeconds + ((index + 1) / tokens.length) * duration, channel, speaker: null, + metadata: createTranscriptTimingMetadata(timingSource), })); } +function getWordTimingSourceForBatchResponse( + response: { metadata?: unknown }, + hasProviderWords: boolean, + fallbackWithoutWords: TranscriptTimingSource, +): TranscriptTimingSource { + if (!hasProviderWords) { + return fallbackWithoutWords; + } + + const explicitSource = getBatchResponseTimingSource(response); + if (explicitSource) { + return explicitSource; + } + + return "provider_word"; +} + +function getBatchResponseTimingSource(response: { + metadata?: unknown; +}): TranscriptTimingSource | undefined { + const metadata = response.metadata; + if (!metadata || typeof metadata !== "object" || Array.isArray(metadata)) { + return undefined; + } + + const record = metadata as Record; + const timing = record.timing; + if (timing && typeof timing === "object" && !Array.isArray(timing)) { + return getValidTimingSource((timing as Record).source); + } + + return getValidTimingSource(record.timing_source); +} + function getBatchDurationSeconds(response: BatchResponse): number | undefined { const metadata = response.metadata; if (!metadata || typeof metadata !== "object" || Array.isArray(metadata)) { diff --git a/apps/desktop/src/store/zustand/listener/general.test.ts b/apps/desktop/src/store/zustand/listener/general.test.ts index 3abfde6ea0..0561f00c04 100644 --- a/apps/desktop/src/store/zustand/listener/general.test.ts +++ b/apps/desktop/src/store/zustand/listener/general.test.ts @@ -136,6 +136,11 @@ describe("General Listener Slice", () => { start_ms: 0, end_ms: 500, channel: 0, + metadata: { + timing: { + source: "provider_word", + }, + }, }, ]); @@ -201,6 +206,11 @@ describe("General Listener Slice", () => { start_ms: 0, end_ms: 500, channel: 0, + metadata: { + timing: { + source: "provider_word", + }, + }, }, ], [ @@ -227,7 +237,7 @@ describe("General Listener Slice", () => { expect( handleBatchResponse(sessionId, { - metadata: { duration: 2 }, + metadata: { duration: 2, timing_source: "provider_word" }, results: { channels: [ { @@ -251,12 +261,22 @@ describe("General Listener Slice", () => { start_ms: 0, end_ms: 1000, channel: 0, + metadata: { + timing: { + source: "synthetic_text", + }, + }, }, { text: " world", start_ms: 1000, end_ms: 2000, channel: 0, + metadata: { + timing: { + source: "synthetic_text", + }, + }, }, ], [], @@ -312,12 +332,22 @@ describe("General Listener Slice", () => { start_ms: 4000, end_ms: 5000, channel: 1, + metadata: { + timing: { + source: "provider_segment_interpolated", + }, + }, }, { text: " world", start_ms: 5000, end_ms: 6000, channel: 1, + metadata: { + timing: { + source: "provider_segment_interpolated", + }, + }, }, ], [], diff --git a/apps/desktop/src/store/zustand/listener/utils.ts b/apps/desktop/src/store/zustand/listener/utils.ts index 010f98845f..e91884584e 100644 --- a/apps/desktop/src/store/zustand/listener/utils.ts +++ b/apps/desktop/src/store/zustand/listener/utils.ts @@ -1,4 +1,9 @@ import type { RuntimeSpeakerHint, WordLike } from "~/stt/segment"; +import { + createTranscriptTimingMetadata, + type TranscriptTimingSource, + type TranscriptWordMetadata, +} from "~/stt/timing"; export function fixSpacingForWords( words: string[], @@ -36,12 +41,16 @@ export type WordEntry = { end: number; channel?: number; speaker?: number | null; + metadata?: TranscriptWordMetadata | null; }; export function transformWordEntries( wordEntries: WordEntry[] | null | undefined, transcript: string, channel: number, + options: { + timingSource?: TranscriptTimingSource; + } = {}, ): [WordLike[], RuntimeSpeakerHint[]] { const words: WordLike[] = []; const hints: RuntimeSpeakerHint[] = []; @@ -61,6 +70,10 @@ export function transformWordEntries( start_ms: Math.round(word.start * 1000), end_ms: Math.round(word.end * 1000), channel: typeof word.channel === "number" ? word.channel : channel, + metadata: createTranscriptTimingMetadata( + options.timingSource ?? "provider_word", + word.metadata, + ), }); if (typeof word.speaker === "number") { diff --git a/apps/desktop/src/stt/live-segment.ts b/apps/desktop/src/stt/live-segment.ts index 303f3b9751..4d232d6dca 100644 --- a/apps/desktop/src/stt/live-segment.ts +++ b/apps/desktop/src/stt/live-segment.ts @@ -6,6 +6,8 @@ import type { SegmentWord as BoundSegmentWord, } from "@hypr/plugin-transcription"; +import type { TranscriptWordMetadata } from "~/stt/timing"; + export enum ChannelProfile { DirectMic = 0, RemoteParty = 1, @@ -17,6 +19,7 @@ export type WordLike = { start_ms: number; end_ms: number; channel: ChannelProfile; + metadata?: TranscriptWordMetadata | null; }; export type PartialWord = WordLike; @@ -41,8 +44,18 @@ export type RenderLabelContext = { }; export type SegmentKey = BoundSegmentKey; -export type SegmentWord = BoundSegmentWord; -export type Segment = LiveTranscriptSegment | RenderedTranscriptSegment; +export type SegmentWord = BoundSegmentWord & { + metadata?: TranscriptWordMetadata | null; +}; +type SegmentWithWordMetadata = Omit< + T, + "words" +> & { + words: SegmentWord[]; +}; +export type Segment = + | SegmentWithWordMetadata + | SegmentWithWordMetadata; export type SegmentChannelProfile = BoundChannelProfile; export class SpeakerLabelManager { diff --git a/apps/desktop/src/stt/render-transcript.test.ts b/apps/desktop/src/stt/render-transcript.test.ts index 2cd945ffde..fa180fe323 100644 --- a/apps/desktop/src/stt/render-transcript.test.ts +++ b/apps/desktop/src/stt/render-transcript.test.ts @@ -205,4 +205,67 @@ describe("buildRenderTranscriptRequestFromStore", () => { humans: [], }); }); + + it("reattaches word metadata after Rust renders transcript segments", async () => { + renderTranscriptSegmentsCommand.mockResolvedValue({ + status: "ok", + data: [ + { + id: "segment-1", + key: { + channel: "DirectMic", + speaker_index: null, + speaker_human_id: null, + }, + speaker_label: "You", + start_ms: 10, + end_ms: 20, + text: "hello", + words: [ + { + id: "word-1", + text: "hello", + start_ms: 10, + end_ms: 20, + channel: "DirectMic", + is_final: true, + }, + ], + }, + ], + }); + + const segments = await renderTranscriptSegments({ + transcripts: [ + { + started_at: 1_000, + words: [ + { + id: "word-1", + text: " hello", + start_ms: 10, + end_ms: 20, + channel: 0, + speaker_index: null, + metadata: { + timing: { + source: "synthetic_text", + }, + }, + } as never, + ], + assignments: [], + }, + ], + participant_human_ids: [], + self_human_id: null, + humans: [], + }); + + expect(segments[0]?.words[0]?.metadata).toEqual({ + timing: { + source: "synthetic_text", + }, + }); + }); }); diff --git a/apps/desktop/src/stt/render-transcript.ts b/apps/desktop/src/stt/render-transcript.ts index 0bd7dd0762..a1358bb80c 100644 --- a/apps/desktop/src/stt/render-transcript.ts +++ b/apps/desktop/src/stt/render-transcript.ts @@ -12,8 +12,17 @@ import type { } from "@hypr/plugin-transcription"; import type * as main from "~/store/tinybase/store/main"; +import type { SegmentWord } from "~/stt/live-segment"; +import type { TranscriptWordMetadata } from "~/stt/timing"; import { parseTranscriptHints, parseTranscriptWords } from "~/stt/utils"; +export type RenderedTranscriptSegmentWithWordMetadata = Omit< + RenderedTranscriptSegment, + "words" +> & { + words: SegmentWord[]; +}; + type TranscriptRow = { started_at?: number | null; words?: Array<{ @@ -22,6 +31,7 @@ type TranscriptRow = { start_ms?: number | null; end_ms?: number | null; channel?: number | null; + metadata?: unknown; }> | null; speaker_hints?: Array< TranscriptSpeakerHint | { word_id?: string; type?: string; value?: unknown } @@ -35,15 +45,16 @@ type RenderTranscriptRequestHumans = { export async function renderTranscriptSegments( request: RenderTranscriptRequest, -): Promise { - const result = await listenerCommands.renderTranscriptSegments( - normalizeRenderTranscriptRequest(request), - ); +): Promise { + const normalizedRequest = normalizeRenderTranscriptRequest(request); + const metadataByWordId = collectWordMetadataById(normalizedRequest); + const result = + await listenerCommands.renderTranscriptSegments(normalizedRequest); if (result.status === "error") { throw new Error(result.error); } - return result.data; + return attachWordMetadata(result.data, metadataByWordId); } export function buildRenderTranscriptRequestFromStore( @@ -105,14 +116,19 @@ function buildRenderTranscriptRequest( } wordIndexById.set(word.id, words.length); - words.push({ + const metadata = normalizeWordMetadata(word.metadata); + const renderWord: RenderTranscriptInput["words"][number] & { + metadata?: TranscriptWordMetadata; + } = { id: word.id, text: word.text, start_ms: word.start_ms, end_ms: word.end_ms, channel: typeof word.channel === "number" ? word.channel : 0, speaker_index: null, - }); + ...(metadata ? { metadata } : {}), + }; + words.push(renderWord); } for (const hint of transcript.speaker_hints ?? []) { @@ -325,6 +341,67 @@ function normalizeRenderTranscriptRequest( }; } +function collectWordMetadataById( + request: RenderTranscriptRequest, +): Map { + const metadataByWordId = new Map(); + + for (const transcript of request.transcripts) { + for (const word of transcript.words) { + const metadata = normalizeWordMetadata( + (word as { metadata?: unknown }).metadata, + ); + if (metadata) { + metadataByWordId.set(word.id, metadata); + } + } + } + + return metadataByWordId; +} + +function attachWordMetadata( + segments: RenderedTranscriptSegment[], + metadataByWordId: Map, +): RenderedTranscriptSegmentWithWordMetadata[] { + if (metadataByWordId.size === 0) { + return segments as RenderedTranscriptSegmentWithWordMetadata[]; + } + + return segments.map((segment) => ({ + ...segment, + words: segment.words.map((word) => + attachMetadataToWord(word, metadataByWordId), + ), + })); +} + +function attachMetadataToWord( + word: RenderedTranscriptSegment["words"][number], + metadataByWordId: Map, +): SegmentWord { + if (!word.id) { + return word; + } + + const metadata = metadataByWordId.get(word.id); + return metadata ? { ...word, metadata } : word; +} + +function normalizeWordMetadata(value: unknown): TranscriptWordMetadata | null { + if (typeof value === "string") { + try { + return normalizeWordMetadata(JSON.parse(value)); + } catch { + return null; + } + } + + return value && typeof value === "object" && !Array.isArray(value) + ? (value as TranscriptWordMetadata) + : null; +} + function normalizeTranscriptMs(value: number): number { return Number.isFinite(value) ? Math.round(value) : value; } diff --git a/apps/desktop/src/stt/timing.ts b/apps/desktop/src/stt/timing.ts new file mode 100644 index 0000000000..44b648c32d --- /dev/null +++ b/apps/desktop/src/stt/timing.ts @@ -0,0 +1,60 @@ +export type TranscriptTimingSource = + | "provider_word" + | "provider_segment_interpolated" + | "synthetic_text"; + +export type TranscriptWordMetadata = Record; + +export function createTranscriptTimingMetadata( + source: TranscriptTimingSource, + metadata?: unknown, +): TranscriptWordMetadata { + const base = isRecord(metadata) ? metadata : {}; + const timing = isRecord(base.timing) ? base.timing : {}; + + return { + ...base, + timing: { + ...timing, + source, + }, + }; +} + +export function getTranscriptTimingSource(word: { + metadata?: unknown; +}): TranscriptTimingSource { + const metadata = word.metadata; + if (!isRecord(metadata)) { + return "provider_word"; + } + + const timing = metadata.timing; + if (!isRecord(timing)) { + return getValidTimingSource(metadata.timing_source) ?? "provider_word"; + } + + return ( + getValidTimingSource(timing.source) ?? + getValidTimingSource(metadata.timing_source) ?? + "provider_word" + ); +} + +export function isTranscriptWordSeekable(word: { metadata?: unknown }) { + return getTranscriptTimingSource(word) !== "synthetic_text"; +} + +export function getValidTimingSource( + source: unknown, +): TranscriptTimingSource | undefined { + return source === "provider_word" || + source === "provider_segment_interpolated" || + source === "synthetic_text" + ? source + : undefined; +} + +function isRecord(value: unknown): value is Record { + return Boolean(value) && typeof value === "object" && !Array.isArray(value); +} diff --git a/apps/desktop/src/stt/useRunBatch.ts b/apps/desktop/src/stt/useRunBatch.ts index c262a188cc..1f25995979 100644 --- a/apps/desktop/src/stt/useRunBatch.ts +++ b/apps/desktop/src/stt/useRunBatch.ts @@ -219,6 +219,9 @@ export const useRunBatch = (sessionId: string) => { start_ms: word.start_ms, end_ms: word.end_ms, channel: word.channel, + metadata: word.metadata + ? JSON.stringify(word.metadata) + : undefined, }); newWordIds.push(wordId); diff --git a/crates/owhisper-client/src/adapter/mistral/batch.rs b/crates/owhisper-client/src/adapter/mistral/batch.rs index ef61962d80..e022c9193f 100644 --- a/crates/owhisper-client/src/adapter/mistral/batch.rs +++ b/crates/owhisper-client/src/adapter/mistral/batch.rs @@ -150,66 +150,72 @@ fn strip_punctuation(s: &str) -> String { } fn convert_response(response: MistralBatchResponse) -> BatchResponse { - let words: Vec = if !response.words.is_empty() { - response - .words - .into_iter() - .map(|w| { - let normalized = strip_punctuation(&w.word); - Word { - word: if normalized.is_empty() { - w.word.clone() - } else { - normalized - }, - start: w.start, - end: w.end, - confidence: 1.0, - channel: 0, - speaker: None, - punctuated_word: Some(w.word), - } - }) - .collect() + let (words, timing_source): (Vec, &str) = if !response.words.is_empty() { + ( + response + .words + .into_iter() + .map(|w| { + let normalized = strip_punctuation(&w.word); + Word { + word: if normalized.is_empty() { + w.word.clone() + } else { + normalized + }, + start: w.start, + end: w.end, + confidence: 1.0, + channel: 0, + speaker: None, + punctuated_word: Some(w.word), + } + }) + .collect(), + "provider_word", + ) } else if !response.segments.is_empty() { - response - .segments - .iter() - .flat_map(|segment| { - let seg_duration = segment.end - segment.start; - let segment_words: Vec<&str> = segment.text.split_whitespace().collect(); - let word_count = segment_words.len(); - if word_count == 0 { - return vec![]; - } - let word_duration = seg_duration / word_count as f64; - - segment_words - .into_iter() - .enumerate() - .map(|(i, w)| { - let word_start = segment.start + (i as f64 * word_duration); - let word_end = word_start + word_duration; - let normalized = strip_punctuation(w); - Word { - word: if normalized.is_empty() { - w.to_string() - } else { - normalized - }, - start: word_start, - end: word_end, - confidence: 1.0, - channel: 0, - speaker: None, - punctuated_word: Some(w.to_string()), - } - }) - .collect::>() - }) - .collect() + ( + response + .segments + .iter() + .flat_map(|segment| { + let seg_duration = segment.end - segment.start; + let segment_words: Vec<&str> = segment.text.split_whitespace().collect(); + let word_count = segment_words.len(); + if word_count == 0 { + return vec![]; + } + let word_duration = seg_duration / word_count as f64; + + segment_words + .into_iter() + .enumerate() + .map(|(i, w)| { + let word_start = segment.start + (i as f64 * word_duration); + let word_end = word_start + word_duration; + let normalized = strip_punctuation(w); + Word { + word: if normalized.is_empty() { + w.to_string() + } else { + normalized + }, + start: word_start, + end: word_end, + confidence: 1.0, + channel: 0, + speaker: None, + punctuated_word: Some(w.to_string()), + } + }) + .collect::>() + }) + .collect(), + "provider_segment_interpolated", + ) } else { - Vec::new() + (Vec::new(), "synthetic_text") }; let alternatives = Alternatives { @@ -224,6 +230,7 @@ fn convert_response(response: MistralBatchResponse) -> BatchResponse { let metadata = serde_json::json!({ "language": response.language, + "timing_source": timing_source, }); BatchResponse { @@ -240,6 +247,29 @@ mod tests { use crate::adapter::BatchSttAdapter; use crate::http_client::create_client; + #[test] + fn convert_response_marks_segment_interpolated_words() { + let response = convert_response(MistralBatchResponse { + model: Some("voxtral-mini-latest".to_string()), + language: Some("en".to_string()), + text: "hello world".to_string(), + words: Vec::new(), + segments: vec![MistralSegment { + text: "hello world".to_string(), + start: 1.0, + end: 3.0, + }], + }); + + let alternative = &response.results.channels[0].alternatives[0]; + + assert_eq!(alternative.words.len(), 2); + assert_eq!( + response.metadata["timing_source"], + "provider_segment_interpolated" + ); + } + #[tokio::test] #[ignore] async fn test_mistral_transcribe() { diff --git a/crates/owhisper-client/src/adapter/openai/batch.rs b/crates/owhisper-client/src/adapter/openai/batch.rs index 3df87d173c..0c86a71405 100644 --- a/crates/owhisper-client/src/adapter/openai/batch.rs +++ b/crates/owhisper-client/src/adapter/openai/batch.rs @@ -3,7 +3,8 @@ use std::path::{Path, PathBuf}; use futures_util::StreamExt; use openai_transcription::batch::{ CreateTranscriptionOptions, CreateTranscriptionResponse, DiarizedTranscriptionResponse, - ParsedTranscriptionStreamEvent, TranscriptionStreamEventParser, TranscriptionUsage, + ParsedTranscriptionStreamEvent, TimestampGranularity, TranscriptionStreamEventParser, + TranscriptionUsage, }; use owhisper_interface::ListenParams; use owhisper_interface::batch::{Alternatives, Channel, Response as BatchResponse, Results, Word}; @@ -20,7 +21,6 @@ use super::OpenAIAdapter; const DEFAULT_API_BASE: &str = "https://api.openai.com/v1"; const OPENAI_PROGRESS_CAP: f64 = 0.99; -const SYNTHETIC_WORD_SECONDS: f64 = 0.4; impl BatchSttAdapter for OpenAIAdapter { fn provider_name(&self) -> &'static str { @@ -181,6 +181,14 @@ fn build_transcription_options( let mut options = CreateTranscriptionOptions::for_model(model, use_response_format, enable_streaming); + if let CreateTranscriptionOptions::Whisper(options) = &mut options { + if options.response_format.is_some() { + options + .timestamp_granularities + .push(TimestampGranularity::Word); + } + } + if let Some(lang) = params.languages.first() { options.push_language(lang.iso639().code().to_string()); } @@ -356,60 +364,12 @@ fn usage_duration_seconds(usage: Option<&TranscriptionUsage>) -> Option { seconds.is_finite().then_some(seconds).filter(|s| *s > 0.0) } -fn estimate_text_duration(transcript: &str) -> f64 { - let word_count = transcript.split_whitespace().count(); - word_count as f64 * SYNTHETIC_WORD_SECONDS -} - -fn synthesize_words(transcript: &str, duration: f64, channel: i32) -> Vec { - let tokens = transcript.split_whitespace().collect::>(); - if tokens.is_empty() { - return Vec::new(); - } - - let duration = if duration.is_finite() && duration > 0.0 { - duration - } else { - estimate_text_duration(transcript) - }; - let word_duration = duration / tokens.len() as f64; - - tokens - .iter() - .enumerate() - .map(|(index, token)| { - let normalized = strip_punctuation(token); - let start = word_duration * index as f64; - let end = if index + 1 == tokens.len() { - duration - } else { - word_duration * (index + 1) as f64 - }; - - Word { - word: if normalized.is_empty() { - (*token).to_string() - } else { - normalized - }, - start, - end, - confidence: 1.0, - channel, - speaker: None, - punctuated_word: Some((*token).to_string()), - } - }) - .collect() -} - fn convert_text_response(transcript: String, usage: Option) -> BatchResponse { let usage_duration = usage_duration_seconds(usage.as_ref()); - let duration = usage_duration.unwrap_or_else(|| estimate_text_duration(&transcript)); - let words = synthesize_words(&transcript, duration, 0); - let metadata = text_response_metadata(usage, usage_duration); + let mut metadata = text_response_metadata(usage, usage_duration); + insert_timing_source(&mut metadata, "synthetic_text"); - build_batch_response(transcript, words, metadata) + build_batch_response(transcript, Vec::new(), metadata) } fn convert_response(response: CreateTranscriptionResponse) -> BatchResponse { @@ -445,6 +405,7 @@ fn convert_response(response: CreateTranscriptionResponse) -> BatchResponse { serde_json::json!({ "language": response.language, "duration": response.duration, + "timing_source": "provider_word", }), ) } @@ -472,12 +433,19 @@ fn convert_response(response: CreateTranscriptionResponse) -> BatchResponse { "duration": response.duration, "speaker_labels": speaker_labels, "speaker_segments": speaker_segments, + "timing_source": "provider_segment_interpolated", }), ) } } } +fn insert_timing_source(metadata: &mut serde_json::Value, source: &'static str) { + if let Some(object) = metadata.as_object_mut() { + object.insert("timing_source".to_string(), serde_json::json!(source)); + } +} + fn convert_diarized_words(response: &DiarizedTranscriptionResponse) -> (Vec, Vec) { let mut speaker_labels = Vec::new(); let mut words = Vec::new(); @@ -587,6 +555,32 @@ mod tests { .expect("serialize multipart"); assert!(matches!(options, CreateTranscriptionOptions::Whisper(_))); assert!(!fields.iter().any(|field| field.name == "stream")); + assert!( + !fields + .iter() + .any(|field| field.name == "timestamp_granularities[]") + ); + } + + #[test] + fn build_transcription_options_requests_word_timestamps_for_whisper_batch() { + let options = build_transcription_options( + &ListenParams { + model: Some("whisper-1".to_string()), + ..Default::default() + }, + true, + false, + ); + + let fields = options + .multipart_text_fields() + .expect("serialize multipart"); + assert!( + fields.iter().any(|field| { + field.name == "timestamp_granularities[]" && field.value == "word" + }) + ); } #[test] @@ -632,12 +626,9 @@ mod tests { "hello world" ); let words = &response.results.channels[0].alternatives[0].words; - assert_eq!(words.len(), 2); - assert_eq!(words[0].word, "hello"); - assert_eq!(words[1].word, "world"); - assert_eq!(words[0].start, 0.0); - assert!(words[1].end > words[0].end); + assert!(words.is_empty()); assert_eq!(response.metadata["usage"]["type"], "tokens"); + assert_eq!(response.metadata["timing_source"], "synthetic_text"); } #[test] @@ -678,7 +669,7 @@ mod tests { } #[test] - fn convert_standard_response_synthesizes_words_from_text() { + fn convert_standard_response_preserves_text_without_words() { let response: CreateTranscriptionResponse = serde_json::from_str( r#"{ "text": " hello, world! ", @@ -695,17 +686,10 @@ mod tests { let words = &alternative.words; assert_eq!(alternative.transcript, "hello, world!"); - assert_eq!(words.len(), 2); - assert_eq!(words[0].word, "hello"); - assert_eq!(words[0].punctuated_word.as_deref(), Some("hello,")); - assert_eq!(words[0].start, 0.0); - assert_eq!(words[0].end, 1.0); - assert_eq!(words[1].word, "world"); - assert_eq!(words[1].punctuated_word.as_deref(), Some("world!")); - assert_eq!(words[1].start, 1.0); - assert_eq!(words[1].end, 2.0); + assert!(words.is_empty()); assert_eq!(batch.metadata["usage"]["type"], "duration"); assert_eq!(batch.metadata["duration"], 2.0); + assert_eq!(batch.metadata["timing_source"], "synthetic_text"); } #[test] @@ -744,6 +728,10 @@ mod tests { assert_eq!(words[0].channel, MIXED_CAPTURE_CHANNEL); assert_eq!(words[0].speaker, Some(0)); assert_eq!(words[2].speaker, Some(1)); + assert_eq!( + batch.metadata["timing_source"], + "provider_segment_interpolated" + ); assert_eq!(batch.metadata["speaker_labels"][0], "agent"); assert_eq!( batch.metadata["speaker_segments"].as_array().map(Vec::len),