Skip to content

Commit 3f92bb2

Browse files
feat(voxtral_tts): add streaming support with overlap-add decoding (#618)
* feat(voxtral_tts): add streaming support for chunked audio generation

  Add stream and streaming_interval parameters to Voxtral TTS generate(). When streaming is enabled, audio is decoded and yielded in chunks during generation rather than waiting for all tokens to complete. This reduces time-to-first-audio (TTFA) significantly for longer inputs, enabling real-time conversational use cases.

  Benchmark results (M-series Mac, 4-bit quantized model):
  - Short text: ~2.4x TTFA speedup (974ms -> 402ms)
  - Medium text: ~4.2x TTFA speedup (1856ms -> 440ms)
  - Long text: ~11.7x TTFA speedup (5188ms -> 443ms)

  Follows the existing streaming patterns from Orpheus (#384) and Qwen3-TTS (#435).

* fix(voxtral_tts): use overlap-add decoding to eliminate streaming boundary artifacts

  Each streaming chunk is now decoded with 16 frames of context from the preceding chunk, then the context portion is trimmed from the output. This gives the codec decoder's sliding-window attention (up to window size 16) proper context at chunk boundaries, eliminating audible clicks and noise between chunks.

  The first chunk has no context (decoded as-is). Subsequent chunks prepend up to 16 previous frames, decode the combined sequence, then return only the new audio samples.

* Batch conditional and unconditional velocity predictions

* format

---------

Co-authored-by: Prince Canuma <prince.gdt@gmail.com>
1 parent c3667bd commit 3f92bb2

2 files changed

Lines changed: 121 additions & 20 deletions

File tree

mlx_audio/tts/models/voxtral_tts/acoustic_head.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,20 @@ def decode_one_frame(self, llm_output: mx.array) -> mx.array:
209209
n_steps = args.n_denoising_steps
210210
timesteps = [i / (n_steps - 1) for i in range(n_steps)]
211211

212+
llm_uncond = mx.zeros_like(llm_output)
212213
for step in range(n_steps - 1):
213214
t_val = timesteps[step]
214215
dt = timesteps[step + 1] - t_val
215216
t = mx.full((B,), t_val)
216-
v_cond = self._predict_velocity(x_t, t, llm_output)
217-
v_uncond = self._predict_velocity(x_t, t, mx.zeros_like(llm_output))
217+
218+
# Batch conditional and unconditional velocity predictions
219+
# into a single forward pass (B=2) through the acoustic transformer
220+
x_t_batch = mx.concatenate([x_t, x_t], axis=0)
221+
t_batch = mx.concatenate([t, t], axis=0)
222+
llm_batch = mx.concatenate([llm_output, llm_uncond], axis=0)
223+
v_both = self._predict_velocity(x_t_batch, t_batch, llm_batch)
224+
v_cond, v_uncond = v_both[:B], v_both[B:]
225+
218226
v = args.cfg_alpha * v_cond + (1.0 - args.cfg_alpha) * v_uncond
219227
x_t = x_t + v * dt
220228

mlx_audio/tts/models/voxtral_tts/voxtral_tts.py

Lines changed: 111 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ def _remap_layer_key(self, layer_idx: str, suffix: str) -> str:
545545
else:
546546
return f"{prefix}.{suffix}"
547547

548+
@staticmethod
def _format_duration(seconds: float) -> str:
    """Format a duration in seconds as ``HH:MM:SS.mmm``.

    Args:
        seconds: Duration in seconds. Expected to be non-negative.

    Returns:
        A zero-padded ``HH:MM:SS.mmm`` string. The millisecond field is
        truncated (not rounded) from the fractional part of *seconds*.
    """
    total_minutes = int(seconds // 60)
    secs = int(seconds % 60)
    # Carry whole minutes into an hours field rather than hard-coding "00";
    # otherwise a duration of one hour or more would render with an
    # overflowing minutes field (e.g. "00:60:00.000").
    hours, mins = divmod(total_minutes, 60)
    ms = int((seconds % 1) * 1000)
    return f"{hours:02d}:{mins:02d}:{secs:02d}.{ms:03d}"
555+
548556
def generate(
549557
self,
550558
text: str,
@@ -554,6 +562,8 @@ def generate(
554562
top_p: float = 0.95,
555563
max_tokens: int = 4096,
556564
verbose: bool = False,
565+
stream: bool = False,
566+
streaming_interval: float = 2.0,
557567
**kwargs,
558568
) -> GenerationResult:
559569
"""Generate speech from text.
@@ -566,9 +576,16 @@ def generate(
566576
top_p: Nucleus sampling threshold.
567577
max_tokens: Maximum number of audio tokens to generate.
568578
verbose: Show progress bar.
579+
stream: Enable streaming output. When True, intermediate audio
580+
chunks are yielded during generation for lower latency.
581+
streaming_interval: Approximate seconds of audio per streaming
582+
chunk. Each frame is 80 ms, so the interval is converted to
583+
a frame count (``max(1, int(streaming_interval / 0.08))``).
569584
570585
Yields:
571-
GenerationResult with audio waveform.
586+
GenerationResult with audio waveform. When *stream* is True,
587+
intermediate results have ``is_streaming_chunk=True`` and the
588+
last result additionally has ``is_final_chunk=True``.
572589
"""
573590
from mlx_lm.models.cache import make_prompt_cache
574591

@@ -605,6 +622,17 @@ def generate(
605622
)
606623

607624
all_codes = []
625+
yielded_frames = 0
626+
chunk_idx = 0
627+
# Convert streaming_interval (seconds) to frames (1 frame = 80 ms)
628+
frames_per_chunk = max(1, int(streaming_interval / 0.08))
629+
# Context frames for overlap-add decoding. The codec decoder uses
630+
# sliding-window attention with windows up to 16 (at the final stage).
631+
# Including context frames from the previous chunk ensures smooth
632+
# transitions without boundary artifacts.
633+
context_frames = 16
634+
# Each codec frame produces 1920 samples (8x upsample × 240 patch)
635+
samples_per_frame = 1920
608636

609637
for i in tqdm(range(max_tokens), disable=not verbose):
610638
h = hidden[:, -1, :] # (1, dim)
@@ -635,53 +663,118 @@ def generate(
635663
if i % 50 == 0:
636664
mx.clear_cache()
637665

666+
# Streaming: yield chunk when buffer is full
667+
if stream and len(all_codes) - yielded_frames >= frames_per_chunk:
668+
# Include context frames from earlier in the sequence so the
669+
# codec decoder's sliding-window attention has proper context,
670+
# avoiding boundary artifacts between chunks.
671+
ctx_start = max(0, yielded_frames - context_frames)
672+
chunk_codes = mx.concatenate(all_codes[ctx_start:], axis=1)
673+
full_waveform = self.audio_tokenizer.decode(chunk_codes)
674+
full_waveform = full_waveform.squeeze(0)
675+
676+
# Trim the context portion — keep only new audio
677+
ctx_used = yielded_frames - ctx_start
678+
trim_samples = ctx_used * samples_per_frame
679+
chunk_waveform = full_waveform[trim_samples:]
680+
681+
chunk_samples = chunk_waveform.shape[0]
682+
chunk_duration = chunk_samples / self.config.sample_rate
683+
chunk_time = time.time() - time_start
684+
chunk_token_count = len(all_codes) - yielded_frames
685+
686+
yield GenerationResult(
687+
audio=chunk_waveform,
688+
sample_rate=self.config.sample_rate,
689+
samples=chunk_samples,
690+
segment_idx=chunk_idx,
691+
token_count=chunk_token_count,
692+
audio_samples={
693+
"samples": chunk_samples,
694+
"samples-per-sec": self.config.sample_rate,
695+
},
696+
audio_duration=self._format_duration(chunk_duration),
697+
real_time_factor=(
698+
chunk_duration / chunk_time if chunk_time > 0 else 0
699+
),
700+
prompt={
701+
"tokens": chunk_token_count,
702+
"tokens-per-sec": (
703+
round(chunk_token_count / chunk_time, 2)
704+
if chunk_time > 0
705+
else 0
706+
),
707+
},
708+
processing_time_seconds=chunk_time,
709+
peak_memory_usage=mx.get_peak_memory() / 1e9,
710+
is_streaming_chunk=True,
711+
is_final_chunk=False,
712+
)
713+
yielded_frames = len(all_codes)
714+
chunk_idx += 1
715+
time_start = time.time()
716+
638717
if not all_codes:
639718
raise RuntimeError("No audio frames generated")
640719

641-
audio_codes = mx.concatenate(all_codes, axis=1) # (1, N_frames, 37)
642-
643-
time_acoustic = time.time()
644-
645-
# Decode to waveform
646-
waveform = self.audio_tokenizer.decode(audio_codes) # (1, samples)
647-
waveform = waveform.squeeze(0) # (samples,)
720+
# Final chunk: decode remaining frames (or all frames if not streaming)
721+
remaining = len(all_codes) - yielded_frames
722+
if stream and yielded_frames > 0 and remaining > 0:
723+
# Decode remainder with context for smooth transition
724+
ctx_start = max(0, yielded_frames - context_frames)
725+
final_codes = mx.concatenate(all_codes[ctx_start:], axis=1)
726+
full_waveform = self.audio_tokenizer.decode(final_codes).squeeze(0)
727+
ctx_used = yielded_frames - ctx_start
728+
trim_samples = ctx_used * samples_per_frame
729+
waveform = full_waveform[trim_samples:]
730+
elif stream and yielded_frames > 0 and remaining == 0:
731+
# Everything already yielded — emit a zero-length final marker
732+
waveform = mx.zeros((0,))
733+
else:
734+
# Non-streaming (or no intermediate chunks were yielded):
735+
# decode everything at once — identical to the original path
736+
audio_codes = mx.concatenate(all_codes, axis=1)
737+
waveform = self.audio_tokenizer.decode(audio_codes).squeeze(0)
648738

649739
time_end = time.time()
650740

651741
audio_samples = waveform.shape[0]
652742
audio_duration = audio_samples / self.config.sample_rate
653743

654-
duration_mins = int(audio_duration // 60)
655-
duration_secs = int(audio_duration % 60)
656-
duration_ms = int((audio_duration % 1) * 1000)
657-
duration_str = f"00:{duration_mins:02d}:{duration_secs:02d}.{duration_ms:03d}"
658-
659744
processing_time = time_end - time_start
660745

661746
yield GenerationResult(
662747
audio=waveform,
663748
sample_rate=self.config.sample_rate,
664749
samples=audio_samples,
665-
segment_idx=0,
666-
token_count=audio_codes.shape[1],
750+
segment_idx=chunk_idx if stream else 0,
751+
token_count=remaining if stream and yielded_frames > 0 else len(all_codes),
667752
audio_samples={
668753
"samples": audio_samples,
669754
"samples-per-sec": self.config.sample_rate,
670755
},
671-
audio_duration=duration_str,
756+
audio_duration=self._format_duration(audio_duration),
672757
real_time_factor=(
673758
audio_duration / processing_time if processing_time > 0 else 0
674759
),
675760
prompt={
676-
"tokens": audio_codes.shape[1],
761+
"tokens": (
762+
remaining if stream and yielded_frames > 0 else len(all_codes)
763+
),
677764
"tokens-per-sec": (
678-
round(audio_codes.shape[1] / processing_time, 2)
765+
round(
766+
(remaining if stream and yielded_frames > 0 else len(all_codes))
767+
/ processing_time,
768+
2,
769+
)
679770
if processing_time > 0
680771
else 0
681772
),
682773
},
683774
processing_time_seconds=processing_time,
684775
peak_memory_usage=mx.get_peak_memory() / 1e9,
776+
is_streaming_chunk=stream,
777+
is_final_chunk=stream,
685778
)
686779
mx.clear_cache()
687780

0 commit comments

Comments
 (0)