@@ -545,6 +545,14 @@ def _remap_layer_key(self, layer_idx: str, suffix: str) -> str:
545545 else :
546546 return f"{ prefix } .{ suffix } "
547547
548+ @staticmethod
549+ def _format_duration (seconds : float ) -> str :
550+ """Format a duration in seconds as 00:MM:SS.mmm."""
551+ mins = int (seconds // 60 )
552+ secs = int (seconds % 60 )
553+ ms = int ((seconds % 1 ) * 1000 )
554+ return f"00:{ mins :02d} :{ secs :02d} .{ ms :03d} "
555+
548556 def generate (
549557 self ,
550558 text : str ,
@@ -554,6 +562,8 @@ def generate(
554562 top_p : float = 0.95 ,
555563 max_tokens : int = 4096 ,
556564 verbose : bool = False ,
565+ stream : bool = False ,
566+ streaming_interval : float = 2.0 ,
557567 ** kwargs ,
558568 ) -> GenerationResult :
559569 """Generate speech from text.
@@ -566,9 +576,16 @@ def generate(
566576 top_p: Nucleus sampling threshold.
567577 max_tokens: Maximum number of audio tokens to generate.
568578 verbose: Show progress bar.
579+ stream: Enable streaming output. When True, intermediate audio
580+ chunks are yielded during generation for lower latency.
581+ streaming_interval: Approximate seconds of audio per streaming
582+ chunk. Each frame is 80 ms, so the interval is converted to
583+ a frame count (``max(1, int(streaming_interval / 0.08))``).
569584
570585 Yields:
571- GenerationResult with audio waveform.
586+ GenerationResult with audio waveform. When *stream* is True,
587+ intermediate results have ``is_streaming_chunk=True`` and the
588+ last result additionally has ``is_final_chunk=True``.
572589 """
573590 from mlx_lm .models .cache import make_prompt_cache
574591
@@ -605,6 +622,17 @@ def generate(
605622 )
606623
607624 all_codes = []
625+ yielded_frames = 0
626+ chunk_idx = 0
627+ # Convert streaming_interval (seconds) to frames (1 frame = 80 ms)
628+ frames_per_chunk = max (1 , int (streaming_interval / 0.08 ))
629+ # Context frames for overlap-add decoding. The codec decoder uses
630+ # sliding-window attention with windows up to 16 (at the final stage).
631+ # Including context frames from the previous chunk ensures smooth
632+ # transitions without boundary artifacts.
633+ context_frames = 16
634+ # Each codec frame produces 1920 samples (8x upsample × 240 patch)
635+ samples_per_frame = 1920
608636
609637 for i in tqdm (range (max_tokens ), disable = not verbose ):
610638 h = hidden [:, - 1 , :] # (1, dim)
@@ -635,53 +663,118 @@ def generate(
635663 if i % 50 == 0 :
636664 mx .clear_cache ()
637665
666+ # Streaming: yield chunk when buffer is full
667+ if stream and len (all_codes ) - yielded_frames >= frames_per_chunk :
668+ # Include context frames from earlier in the sequence so the
669+ # codec decoder's sliding-window attention has proper context,
670+ # avoiding boundary artifacts between chunks.
671+ ctx_start = max (0 , yielded_frames - context_frames )
672+ chunk_codes = mx .concatenate (all_codes [ctx_start :], axis = 1 )
673+ full_waveform = self .audio_tokenizer .decode (chunk_codes )
674+ full_waveform = full_waveform .squeeze (0 )
675+
676+ # Trim the context portion — keep only new audio
677+ ctx_used = yielded_frames - ctx_start
678+ trim_samples = ctx_used * samples_per_frame
679+ chunk_waveform = full_waveform [trim_samples :]
680+
681+ chunk_samples = chunk_waveform .shape [0 ]
682+ chunk_duration = chunk_samples / self .config .sample_rate
683+ chunk_time = time .time () - time_start
684+ chunk_token_count = len (all_codes ) - yielded_frames
685+
686+ yield GenerationResult (
687+ audio = chunk_waveform ,
688+ sample_rate = self .config .sample_rate ,
689+ samples = chunk_samples ,
690+ segment_idx = chunk_idx ,
691+ token_count = chunk_token_count ,
692+ audio_samples = {
693+ "samples" : chunk_samples ,
694+ "samples-per-sec" : self .config .sample_rate ,
695+ },
696+ audio_duration = self ._format_duration (chunk_duration ),
697+ real_time_factor = (
698+ chunk_duration / chunk_time if chunk_time > 0 else 0
699+ ),
700+ prompt = {
701+ "tokens" : chunk_token_count ,
702+ "tokens-per-sec" : (
703+ round (chunk_token_count / chunk_time , 2 )
704+ if chunk_time > 0
705+ else 0
706+ ),
707+ },
708+ processing_time_seconds = chunk_time ,
709+ peak_memory_usage = mx .get_peak_memory () / 1e9 ,
710+ is_streaming_chunk = True ,
711+ is_final_chunk = False ,
712+ )
713+ yielded_frames = len (all_codes )
714+ chunk_idx += 1
715+ time_start = time .time ()
716+
638717 if not all_codes :
639718 raise RuntimeError ("No audio frames generated" )
640719
641- audio_codes = mx .concatenate (all_codes , axis = 1 ) # (1, N_frames, 37)
642-
643- time_acoustic = time .time ()
644-
645- # Decode to waveform
646- waveform = self .audio_tokenizer .decode (audio_codes ) # (1, samples)
647- waveform = waveform .squeeze (0 ) # (samples,)
720+ # Final chunk: decode remaining frames (or all frames if not streaming)
721+ remaining = len (all_codes ) - yielded_frames
722+ if stream and yielded_frames > 0 and remaining > 0 :
723+ # Decode remainder with context for smooth transition
724+ ctx_start = max (0 , yielded_frames - context_frames )
725+ final_codes = mx .concatenate (all_codes [ctx_start :], axis = 1 )
726+ full_waveform = self .audio_tokenizer .decode (final_codes ).squeeze (0 )
727+ ctx_used = yielded_frames - ctx_start
728+ trim_samples = ctx_used * samples_per_frame
729+ waveform = full_waveform [trim_samples :]
730+ elif stream and yielded_frames > 0 and remaining == 0 :
731+ # Everything already yielded — emit a zero-length final marker
732+ waveform = mx .zeros ((0 ,))
733+ else :
734+ # Non-streaming (or no intermediate chunks were yielded):
735+ # decode everything at once — identical to the original path
736+ audio_codes = mx .concatenate (all_codes , axis = 1 )
737+ waveform = self .audio_tokenizer .decode (audio_codes ).squeeze (0 )
648738
649739 time_end = time .time ()
650740
651741 audio_samples = waveform .shape [0 ]
652742 audio_duration = audio_samples / self .config .sample_rate
653743
654- duration_mins = int (audio_duration // 60 )
655- duration_secs = int (audio_duration % 60 )
656- duration_ms = int ((audio_duration % 1 ) * 1000 )
657- duration_str = f"00:{ duration_mins :02d} :{ duration_secs :02d} .{ duration_ms :03d} "
658-
659744 processing_time = time_end - time_start
660745
661746 yield GenerationResult (
662747 audio = waveform ,
663748 sample_rate = self .config .sample_rate ,
664749 samples = audio_samples ,
665- segment_idx = 0 ,
666- token_count = audio_codes . shape [ 1 ] ,
750+ segment_idx = chunk_idx if stream else 0 ,
751+ token_count = remaining if stream and yielded_frames > 0 else len ( all_codes ) ,
667752 audio_samples = {
668753 "samples" : audio_samples ,
669754 "samples-per-sec" : self .config .sample_rate ,
670755 },
671- audio_duration = duration_str ,
756+ audio_duration = self . _format_duration ( audio_duration ) ,
672757 real_time_factor = (
673758 audio_duration / processing_time if processing_time > 0 else 0
674759 ),
675760 prompt = {
676- "tokens" : audio_codes .shape [1 ],
761+ "tokens" : (
762+ remaining if stream and yielded_frames > 0 else len (all_codes )
763+ ),
677764 "tokens-per-sec" : (
678- round (audio_codes .shape [1 ] / processing_time , 2 )
765+ round (
766+ (remaining if stream and yielded_frames > 0 else len (all_codes ))
767+ / processing_time ,
768+ 2 ,
769+ )
679770 if processing_time > 0
680771 else 0
681772 ),
682773 },
683774 processing_time_seconds = processing_time ,
684775 peak_memory_usage = mx .get_peak_memory () / 1e9 ,
776+ is_streaming_chunk = stream ,
777+ is_final_chunk = stream ,
685778 )
686779 mx .clear_cache ()
687780
0 commit comments