From 8bf68817ddf9471f7993f857b5d091060888a6e1 Mon Sep 17 00:00:00 2001
From: Sahith
Date: Tue, 7 Apr 2026 22:35:53 -0500
Subject: [PATCH 1/2] whisper: add carry_initial_prompt to maintain context
 over sliding windows

---
 whisper/mlx_whisper/transcribe.py | 28 +++++++++++++++++++++++++++-
 1 file changed, 27 insertions(+), 1 deletion(-)

diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py
index bced16a58..f82ff575d 100644
--- a/whisper/mlx_whisper/transcribe.py
+++ b/whisper/mlx_whisper/transcribe.py
@@ -70,6 +70,7 @@ def transcribe(
     no_speech_threshold: Optional[float] = 0.6,
     condition_on_previous_text: bool = True,
     initial_prompt: Optional[str] = None,
+    carry_initial_prompt: bool = False,
     word_timestamps: bool = False,
     prepend_punctuations: str = "\"'“¿([{-",
     append_punctuations: str = "\"'.。,,!!??::”)]}、",
@@ -126,6 +127,11 @@
         "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
         to make it more likely to predict those word correctly.
 
+    carry_initial_prompt: bool
+        If True, the `initial_prompt` is forcefully carried forward into the context window for all
+        subsequent text segments. This ensures the model does not forget prompt-engineered context
+        over long audio files.
+
     decode_options: dict
         Keyword arguments to construct `DecodingOptions` instances
 
@@ -293,7 +299,27 @@ def new_segment(
 
             segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
             mel_segment = pad_or_trim(mel_segment, N_FRAMES, axis=-2).astype(dtype)
 
-            decode_options["prompt"] = all_tokens[prompt_reset_since:]
+            prompt_tokens = all_tokens[prompt_reset_since:]
+            if carry_initial_prompt and initial_prompt_tokens:
+                # Extract previous text tokens, removing the initial prompt if it's currently at the beginning
+                prev_tokens = (
+                    prompt_tokens[len(initial_prompt_tokens) :]
+                    if prompt_reset_since == 0
+                    else prompt_tokens
+                )
+                # Calculate available space across context window
+                max_prompt_length = model.dims.n_text_ctx // 2 - 1
+                max_prev_length = max(
+                    0, max_prompt_length - len(initial_prompt_tokens)
+                )
+                # Retain initial prompt and latest previous tokens
+                prompt_tokens = (
+                    initial_prompt_tokens + prev_tokens[-max_prev_length:]
+                    if max_prev_length > 0
+                    else initial_prompt_tokens
+                )
+
+            decode_options["prompt"] = prompt_tokens
             result: DecodingResult = decode_with_fallback(mel_segment)
             tokens = np.array(result.tokens)
 

From 2985e38d75e0a45e4cca44375eaf558b261cd585 Mon Sep 17 00:00:00 2001
From: Sahith
Date: Tue, 7 Apr 2026 22:48:06 -0500
Subject: [PATCH 2/2] whisper: add unit test for carry_initial_prompt logic

---
 whisper/test.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/whisper/test.py b/whisper/test.py
index f0acb3cd9..17fd27690 100644
--- a/whisper/test.py
+++ b/whisper/test.py
@@ -197,6 +197,17 @@ def test_transcribe(self):
             ),
         )
 
+    def test_carry_initial_prompt(self):
+        result = mlx_whisper.transcribe(
+            TEST_AUDIO,
+            path_or_hf_repo=MLX_FP32_MODEL_PATH,
+            fp16=False,
+            initial_prompt="A test prompt.",
+            carry_initial_prompt=True,
+        )
+        self.assertIn("text", result)
+        self.assertGreater(len(result["text"]), 0)
+
     def test_transcribe_alice(self):
         audio_file = os.path.join(
             os.path.expanduser("~"),