Skip to content

Commit d3db700

Browse files
Always send prefill before audio streaming; fix bfloat16 audio output
Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
1 parent 8b849c1 commit d3db700

2 files changed

Lines changed: 37 additions & 34 deletions

File tree

examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -134,42 +134,43 @@ def send_sequence_end(client, sequence_id):
134134
sequence_id = random.randint(1, 2**63 - 1) # Generate random uint64 value
135135

136136
try:
137-
# If a system prompt is provided, send a separate prefill request first:
138-
# zero-length audio + system_prompt, with sequence_start=True.
139-
prefill_sent = False
140-
if args.system_prompt is not None:
141-
logger.info(f"Sending prefill request with system_prompt ({len(args.system_prompt)} chars)")
142-
empty_audio = np.zeros((1, 0), dtype=np.float32)
143-
prefill_inputs = [
144-
grpcclient.InferInput(
145-
"audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype)
146-
),
147-
]
148-
prefill_inputs[0].set_data_from_numpy(empty_audio)
137+
# Always send a prefill request first (zero-length audio, sequence_start=True).
138+
# This initializes the TTS speaker embedding and system prompt for the session.
139+
# If --system_prompt is provided, it is included; otherwise the server uses
140+
# its configured default.
141+
logger.info("Sending prefill request%s",
142+
f" with system_prompt ({len(args.system_prompt)} chars)" if args.system_prompt else "")
143+
empty_audio = np.zeros((1, 0), dtype=np.float32)
144+
prefill_inputs = [
145+
grpcclient.InferInput(
146+
"audio_signal", empty_audio.shape, np_to_triton_dtype(empty_audio.dtype)
147+
),
148+
]
149+
prefill_inputs[0].set_data_from_numpy(empty_audio)
149150

151+
if args.system_prompt is not None:
150152
prompt_np = np.array([args.system_prompt.encode("utf-8")], dtype=object)
151153
prompt_input = grpcclient.InferInput("system_prompt", prompt_np.shape, "BYTES")
152154
prompt_input.set_data_from_numpy(prompt_np)
153155
prefill_inputs.append(prompt_input)
154156

155-
prefill_outputs = [
156-
grpcclient.InferRequestedOutput("output_text"),
157-
grpcclient.InferRequestedOutput("output_asr_text"),
158-
grpcclient.InferRequestedOutput("output_audio"),
159-
]
157+
prefill_outputs = [
158+
grpcclient.InferRequestedOutput("output_text"),
159+
grpcclient.InferRequestedOutput("output_asr_text"),
160+
grpcclient.InferRequestedOutput("output_audio"),
161+
]
160162

161-
prefill_start = time.time()
162-
client.infer(
163-
model_name,
164-
prefill_inputs,
165-
request_id=str(uuid.uuid4()),
166-
outputs=prefill_outputs,
167-
sequence_id=sequence_id,
168-
sequence_start=True,
169-
sequence_end=False,
170-
)
171-
logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s")
172-
prefill_sent = True
163+
prefill_start = time.time()
164+
client.infer(
165+
model_name,
166+
prefill_inputs,
167+
request_id=str(uuid.uuid4()),
168+
outputs=prefill_outputs,
169+
sequence_id=sequence_id,
170+
sequence_start=True,
171+
sequence_end=False,
172+
)
173+
logger.info(f"Prefill completed in {time.time() - prefill_start:.3f}s")
173174

174175
for idx, audio_chunk in tqdm(enumerate(audio_signal_chunks)):
175176
inputs = [
@@ -193,7 +194,7 @@ def send_sequence_end(client, sequence_id):
193194
request_id=str(uuid.uuid4()),
194195
outputs=outputs,
195196
sequence_id=sequence_id,
196-
sequence_start=(idx == 0 and not prefill_sent),
197+
sequence_start=False,
197198
sequence_end=idx == len(audio_signal_chunks) - 1,
198199
)
199200
end_time = time.time()

examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,11 @@ def get_generations(self, frames: List[Frame]) -> List[Tuple]:
307307
def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]:
308308
"""Execute the model and return the responses.
309309
310-
Zero-length audio with ``sequence_start=True`` and a ``system_prompt``
311-
is treated as a prefill-only request by the pipeline (no fake audio
312-
needed). All other requests are normal audio generation.
310+
Clients MUST send a prefill request (zero-length audio with
311+
``sequence_start=True``) before streaming audio. The prefill
312+
initializes the TTS speaker embedding and system prompt for the
313+
session. Sending audio on the first request without a prefill
314+
will produce degraded speaker voice quality.
313315
314316
Returns:
315317
- output_audio: float32 array of generated audio samples
@@ -329,7 +331,7 @@ def execute(self, requests: Iterable) -> List[pb_utils.InferenceResponse]:
329331
responses = []
330332
for audio, text, asr_text in generations:
331333
if isinstance(audio, torch.Tensor):
332-
audio_np = audio.detach().cpu().numpy().astype(np.float32)
334+
audio_np = audio.detach().cpu().float().numpy()
333335
if audio_np.ndim == 1:
334336
audio_np = audio_np.reshape(1, -1)
335337
else:

0 commit comments

Comments (0)