Skip to content

Commit 8b849c1

Browse files
update triton code; bugfix for vllm dtype/device
Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
1 parent de230d9 commit 8b849c1

5 files changed

Lines changed: 65 additions & 36 deletions

File tree

examples/speechlm2/nemo_inference_pipelines/triton/client_streaming.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ def send_sequence_end(client, sequence_id):
105105

106106
outputs = [
107107
grpcclient.InferRequestedOutput("output_text"),
108+
grpcclient.InferRequestedOutput("output_asr_text"),
108109
grpcclient.InferRequestedOutput("output_audio"),
109110
]
110111

@@ -115,7 +116,7 @@ def send_sequence_end(client, sequence_id):
115116
outputs=outputs,
116117
sequence_id=sequence_id,
117118
sequence_start=False,
118-
sequence_end=True, # This is the key - properly end the sequence
119+
sequence_end=True,
119120
)
120121
logger.info("Sequence ended successfully")
121122

examples/speechlm2/nemo_inference_pipelines/triton/model_repo_s2s/voicechat/1/infer_streaming.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -43,35 +43,45 @@ def _resolve_env_overrides(self, cfg):
4343
env vars, while sharing the same s2s_streaming.yaml used by the CLI.
4444
4545
Env var mapping (cfg key -> env var, default):
46-
s2s.model_path -> S2S_MODEL_PATH (required)
47-
s2s.llm_checkpoint_path -> S2S_LLM_CHECKPOINT_PATH (required)
48-
s2s.speaker_reference -> S2S_SPEAKER_REFERENCE (required)
49-
s2s.engine_type -> S2S_ENGINE_TYPE (default: native)
50-
s2s.system_prompt -> S2S_SYSTEM_PROMPT (default: none)
51-
s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (default: none)
46+
s2s.model_path -> S2S_MODEL_PATH (required)
47+
s2s.speaker_reference -> S2S_SPEAKER_REFERENCE (optional)
48+
s2s.speaker_name -> S2S_SPEAKER_NAME (optional)
49+
s2s.engine_type -> S2S_ENGINE_TYPE (default: native)
50+
s2s.deterministic -> S2S_DETERMINISTIC (default: false)
51+
s2s.use_llm_cache -> S2S_USE_LLM_CACHE (default: true)
52+
s2s.use_tts_subword_cache -> S2S_USE_TTS_SUBWORD_CACHE (default: false)
53+
s2s.system_prompt -> S2S_SYSTEM_PROMPT (optional)
54+
s2s.tts_system_prompt -> S2S_TTS_SYSTEM_PROMPT (optional)
5255
streaming.chunk_size_in_secs -> S2S_CHUNK_SIZE_IN_SECS (default: 0.08)
5356
streaming.buffer_size_in_secs -> S2S_BUFFER_SIZE_IN_SECS (default: 5.6)
5457
"""
5558
env_overrides = {
5659
# Required
57-
"s2s.model_path": ("S2S_MODEL_PATH", None),
58-
"s2s.llm_checkpoint_path": ("S2S_LLM_CHECKPOINT_PATH", None),
59-
"s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None),
60-
# Optional (with defaults)
61-
"s2s.engine_type": ("S2S_ENGINE_TYPE", "native"),
62-
"s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None),
63-
"s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None),
60+
"s2s.model_path": ("S2S_MODEL_PATH", None),
61+
# Speaker identity (set one or both)
62+
"s2s.speaker_reference": ("S2S_SPEAKER_REFERENCE", None),
63+
"s2s.speaker_name": ("S2S_SPEAKER_NAME", None),
64+
# Engine & precision
65+
"s2s.engine_type": ("S2S_ENGINE_TYPE", "native"),
66+
"s2s.deterministic": ("S2S_DETERMINISTIC", False),
67+
# Cache / speedup flags
68+
"s2s.use_llm_cache": ("S2S_USE_LLM_CACHE", True),
69+
"s2s.use_tts_subword_cache": ("S2S_USE_TTS_SUBWORD_CACHE", False),
70+
# Prompts
71+
"s2s.system_prompt": ("S2S_SYSTEM_PROMPT", None),
72+
"s2s.tts_system_prompt": ("S2S_TTS_SYSTEM_PROMPT", None),
73+
# Streaming
6474
"streaming.chunk_size_in_secs": ("S2S_CHUNK_SIZE_IN_SECS", 0.08),
65-
"streaming.buffer_size_in_secs": ("S2S_BUFFER_SIZE_IN_SECS", 5.6),
75+
"streaming.buffer_size_in_secs":("S2S_BUFFER_SIZE_IN_SECS", 5.6),
6676
}
6777
for cfg_key, (env_var, default) in env_overrides.items():
68-
val = os.environ.get(env_var)
69-
if val is not None:
70-
if default is not None and isinstance(default, bool):
78+
val = os.environ.get(env_var, "")
79+
if val:
80+
if isinstance(default, bool):
7181
val = val.lower() in ("true", "1", "yes")
72-
elif default is not None and isinstance(default, float):
82+
elif isinstance(default, float):
7383
val = float(val)
74-
elif default is not None and isinstance(default, int):
84+
elif isinstance(default, int):
7585
val = int(val)
7686
OmegaConf.update(cfg, cfg_key, val, force_add=True)
7787
elif default is not None:

examples/speechlm2/nemo_inference_pipelines/triton/start_triton.sh

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,26 @@
1919
# Fields marked ??? in the YAML are resolved from environment variables below.
2020
#
2121
# Usage:
22-
# S2S_MODEL_PATH=/path/to/eartts_ckpt \
23-
# S2S_LLM_CHECKPOINT_PATH=/path/to/llm_ckpt \
24-
# S2S_SPEAKER_REFERENCE=/path/to/speaker.wav \
22+
# S2S_MODEL_PATH=/path/to/hf_checkpoint \
23+
# S2S_SPEAKER_NAME=MySpeaker \
2524
# ./start_triton.sh
2625
#
2726
# Environment variables (required):
28-
# S2S_MODEL_PATH - Path to the EarTTS / S2S checkpoint
29-
# S2S_LLM_CHECKPOINT_PATH - Path to the LLM checkpoint
27+
# S2S_MODEL_PATH - Path to the HF-format checkpoint directory
28+
#
29+
# Environment variables (speaker identity — set at least one):
3030
# S2S_SPEAKER_REFERENCE - Path to a speaker reference .wav file
31+
# S2S_SPEAKER_NAME - Registered speaker name from the checkpoint
3132
#
3233
# Environment variables (optional):
3334
# S2S_ENGINE_TYPE - Engine type (default: native)
35+
# S2S_DETERMINISTIC - "true"/"false": deterministic mode (default: false)
36+
# S2S_USE_LLM_CACHE - "true"/"false": LLM KV cache (default: true)
37+
# S2S_USE_TTS_SUBWORD_CACHE - "true"/"false": TTS subword cache (default: false)
3438
# S2S_SYSTEM_PROMPT - LLM system prompt text (default: none)
35-
# S2S_TTS_SYSTEM_PROMPT - TTS system prompt, (default: none)
39+
# S2S_TTS_SYSTEM_PROMPT - TTS system prompt (default: none)
3640
# S2S_CHUNK_SIZE_IN_SECS - Chunk size in seconds, multiple of 0.08 (default: 0.08)
3741
# S2S_BUFFER_SIZE_IN_SECS - Audio buffer size in seconds (default: 5.6)
38-
# S2S_USE_CODEC_CACHE - "true"/"false": incremental codec decode (default: true)
3942
# S2S_TRITON_CONFIG_PATH - Override the YAML config file path
4043
# MODEL_REPO_DIR - Override the Triton model repository path
4144

@@ -45,33 +48,45 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
4548
# backend (infer_streaming.py reads them via os.environ).
4649

4750
# ========================
48-
# Model paths (required)
51+
# Model path (required)
52+
# ========================
53+
export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the HF-format checkpoint directory}"
54+
55+
# ========================
56+
# Speaker identity (set at least one)
4957
# ========================
50-
export S2S_MODEL_PATH="${S2S_MODEL_PATH:?Please set S2S_MODEL_PATH to the EarTTS / S2S checkpoint path}"
51-
export S2S_LLM_CHECKPOINT_PATH="${S2S_LLM_CHECKPOINT_PATH:?Please set S2S_LLM_CHECKPOINT_PATH to the LLM checkpoint path}"
52-
export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:?Please set S2S_SPEAKER_REFERENCE to a speaker reference .wav file}"
58+
export S2S_SPEAKER_REFERENCE="${S2S_SPEAKER_REFERENCE:-}"
59+
export S2S_SPEAKER_NAME="${S2S_SPEAKER_NAME:-}"
60+
if [ -z "${S2S_SPEAKER_REFERENCE}" ] && [ -z "${S2S_SPEAKER_NAME}" ]; then
61+
echo "ERROR: Set at least one of S2S_SPEAKER_REFERENCE or S2S_SPEAKER_NAME"
62+
exit 1
63+
fi
5364

5465
# ========================
5566
# Optional overrides
5667
# ========================
5768
export S2S_ENGINE_TYPE="${S2S_ENGINE_TYPE:-native}"
69+
export S2S_DETERMINISTIC="${S2S_DETERMINISTIC:-}"
70+
export S2S_USE_LLM_CACHE="${S2S_USE_LLM_CACHE:-}"
71+
export S2S_USE_TTS_SUBWORD_CACHE="${S2S_USE_TTS_SUBWORD_CACHE:-}"
5872
export S2S_SYSTEM_PROMPT="${S2S_SYSTEM_PROMPT:-}"
5973
export S2S_TTS_SYSTEM_PROMPT="${S2S_TTS_SYSTEM_PROMPT:-}"
6074
export S2S_CHUNK_SIZE_IN_SECS="${S2S_CHUNK_SIZE_IN_SECS:-0.08}"
6175
export S2S_BUFFER_SIZE_IN_SECS="${S2S_BUFFER_SIZE_IN_SECS:-5.6}"
62-
export S2S_USE_CODEC_CACHE="${S2S_USE_CODEC_CACHE:-true}"
6376
export S2S_TRITON_CONFIG_PATH="${S2S_TRITON_CONFIG_PATH:-${SCRIPT_DIR}/../conf/s2s_streaming.yaml}"
6477
export MODEL_REPO_DIR="${MODEL_REPO_DIR:-${SCRIPT_DIR}/model_repo_s2s}"
6578

6679

6780
echo "=== S2S Triton Server ==="
6881
echo " S2S_MODEL_PATH: ${S2S_MODEL_PATH}"
69-
echo " S2S_LLM_CHECKPOINT_PATH: ${S2S_LLM_CHECKPOINT_PATH}"
70-
echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE}"
82+
echo " S2S_SPEAKER_REFERENCE: ${S2S_SPEAKER_REFERENCE:-<not set>}"
83+
echo " S2S_SPEAKER_NAME: ${S2S_SPEAKER_NAME:-<not set>}"
7184
echo " S2S_ENGINE_TYPE: ${S2S_ENGINE_TYPE}"
85+
echo " S2S_DETERMINISTIC: ${S2S_DETERMINISTIC:-<default>}"
86+
echo " S2S_USE_LLM_CACHE: ${S2S_USE_LLM_CACHE:-<default>}"
87+
echo " S2S_USE_TTS_SUBWORD_CACHE: ${S2S_USE_TTS_SUBWORD_CACHE:-<default>}"
7288
echo " S2S_CHUNK_SIZE_IN_SECS: ${S2S_CHUNK_SIZE_IN_SECS}"
7389
echo " S2S_BUFFER_SIZE_IN_SECS: ${S2S_BUFFER_SIZE_IN_SECS}"
74-
echo " S2S_USE_CODEC_CACHE: ${S2S_USE_CODEC_CACHE}"
7590
echo " S2S_SYSTEM_PROMPT: ${S2S_SYSTEM_PROMPT:-<not set>}"
7691
echo " S2S_TTS_SYSTEM_PROMPT: ${S2S_TTS_SYSTEM_PROMPT:-<not set>}"
7792
echo " MODEL_REPO_DIR: ${MODEL_REPO_DIR}"

nemo/collections/speechlm2/inference/model_wrappers/model_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def _sample_text_token(
145145
unique_prev = unique_prev[~torch.isin(unique_prev, ids_t)]
146146

147147
if unique_prev.numel() > 0:
148+
if unique_prev.device != batch_logits.device:
149+
unique_prev = unique_prev.to(batch_logits.device)
148150
prev_logits = batch_logits[unique_prev]
149151
# Positive logits are divided, negative logits are multiplied
150152
# (same as the standard repetition_penalty convention)

nemo/collections/speechlm2/inference/model_wrappers/nemotron_voicechat_inference_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,8 @@ def _initialize_model(self):
241241

242242
# Convert some S2S components to the configured dtype
243243
logging.info(f"Converting some S2S components to {self.dtype} (keeping perception & TTS in float32)...")
244-
self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype)
244+
if self.model.stt_model.llm is not None:
245+
self.model.stt_model.llm = self.model.stt_model.llm.to(self.dtype)
245246
self.model.stt_model.lm_head = self.model.stt_model.lm_head.to(self.dtype)
246247
self.model.stt_model.embed_tokens = self.model.stt_model.embed_tokens.to(self.dtype)
247248
self.model.stt_model.asr_head = self.model.stt_model.asr_head.to(self.dtype)

0 commit comments

Comments
 (0)