
Commit f50ca53

[TRTLLM-13120][feat] Cosmos3 Audio Output Support (#14827)
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
1 parent 2ffab8d commit f50ca53

21 files changed

Lines changed: 2333 additions & 274 deletions


examples/visual_gen/configs/cosmos3-nano-1gpu.yaml

Lines changed: 2 additions & 8 deletions
@@ -13,20 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# 1-GPU Cosmos3 (Nano / Super) with FP8 dynamic quantization.
+# 1-GPU Cosmos3 (Nano / Super).
 # Model: nvidia/Cosmos3-Nano or nvidia/Cosmos3-Super
 # Shared by offline examples (--visual_gen_args) and trtllm-serve.
 #
 # Cosmos3 constraints: VANILLA attention only;
-# no Attention2D / Ring. Use CFG + Ulysses for multi-GPU (see cosmos3-super-4gpu.yaml).
-quant_config:
-  quant_algo: FP8
-  dynamic: true
-  ignore: ["language_model.*", "vae2llm", "llm2vae", "time_embedder.*"]
+# Use CFG + Ulysses for multi-GPU (see cosmos3-super-4gpu.yaml).
 attention_config:
   backend: VANILLA
 parallel_config:
   cfg_size: 1
   ulysses_size: 1
-cuda_graph_config:
-  enable: false

examples/visual_gen/configs/cosmos3-super-4gpu.yaml

Lines changed: 1 addition & 5 deletions
@@ -13,16 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# 4-GPU Cosmos3-Super with FP8 dynamic quantization (CFG + Ulysses + parallel VAE).
+# 4-GPU Cosmos3-Super with (CFG + Ulysses + parallel VAE).
 # Launch with 4 processes, e.g. torchrun --nproc_per_node=4 ...
 # Model: nvidia/Cosmos3-Super
 # Shared by offline examples (--visual_gen_args) and trtllm-serve.
 #
 # GPU layout: cfg_size=2 (positive | negative) x ulysses_size=2 (sequence split).
-quant_config:
-  quant_algo: FP8
-  dynamic: true
-  ignore: ["language_model.*", "vae2llm", "llm2vae", "time_embedder.*"]
 attention_config:
   backend: VANILLA
 parallel_config:
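
The layout comment in this config pairs `cfg_size=2` with `ulysses_size=2` and asks for 4 processes, while the 1-GPU config uses 1 x 1. That suggests the process count is the product of the two parallel sizes. A minimal sketch of that arithmetic follows; `required_world_size` is a hypothetical helper written for illustration, not a TensorRT-LLM function, and the product rule is an assumption inferred from the two config comments.

```python
def required_world_size(cfg_size: int, ulysses_size: int) -> int:
    """Return the number of processes implied by a parallel_config.

    Assumption: world size = cfg_size * ulysses_size, as the config
    comments suggest (2 x 2 = 4 GPUs, 1 x 1 = 1 GPU).
    """
    if cfg_size < 1 or ulysses_size < 1:
        raise ValueError("cfg_size and ulysses_size must be positive integers")
    return cfg_size * ulysses_size


if __name__ == "__main__":
    # Values taken from the two YAML files in this commit.
    print("nano-1gpu :", required_world_size(cfg_size=1, ulysses_size=1))
    print("super-4gpu:", required_world_size(cfg_size=2, ulysses_size=2))
```

A check like this could run before launching `torchrun --nproc_per_node=...` to catch a mismatch between the YAML and the launch command early.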
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Cosmos3 Text(+Image)-to-Video(+Audio) generation
+
+Cosmos3 supports four generation modes from a single checkpoint:
+
+- **T2V** — text-to-video (`prompts/t2v.json`).
+- **T2I** — text-to-image (`prompts/t2i.json`); emits a still frame (use `--output_type image` / a non-video `--output_path`).
+- **I2V / TI2V** — image-conditioned video (`prompts/i2v.json`). Condition on a reference frame via the prompt file's `vision_path` or `--image_path`. The image may be a local path, a `file://` / `http(s)://` URL, or a `data:` URI.
+- **T2AV** — text-to-video with synchronized audio (`prompts/t2av.json` with `enable_audio: true`, or pass `--enable_audio`). Combine with a `vision_path` for image-conditioned audio-video (TI2AV).
+
+## Checkpoints
+
+Pass the Hub ID or local path via `--model`:
+
+- [`nvidia/Cosmos3-Nano`](https://huggingface.co/nvidia/Cosmos3-Nano)
+- [`nvidia/Cosmos3-Super`](https://huggingface.co/nvidia/Cosmos3-Super)
+
+## Guardrails
+
+Guardrails are enabled by default (required by the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license)). Install and authenticate as follows:
+
+```bash
+pip install cosmos_guardrail==0.3.0 && pip uninstall opencv-python
+```
+
+Accept the terms for the guardrail checkpoint at https://huggingface.co/nvidia/Cosmos-1.0-Guardrail and set a valid `HF_TOKEN` (the checkpoint is downloaded automatically on first run).
+
+To run without guardrails (you are responsible for safe deployment):
+
+```bash
+export TRTLLM_DISABLE_COSMOS3_GUARDRAILS=1
+```
+
+## Deployment configs
+
+See `examples/visual_gen/configs/`:
+
+- `cosmos3-nano-1gpu.yaml` — 1 GPU
+- `cosmos3-super-4gpu.yaml` — 4 GPU, CFG + Ulysses + parallel VAE
+
+Example prompts live under `prompts/` (mirroring `cosmos3-internal/inputs/omni`).
+
+## Usage
+
+```bash
+# T2V: text-to-video
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt_file prompts/t2v.json \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml
+
+# I2V/TI2V: image-conditioned video (vision_path is read from the prompt file;
+# local path, file://, http(s):// URL, or data: URI are all accepted)
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt_file prompts/i2v.json \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml
+
+# I2V with an explicit conditioning image (overrides the prompt file)
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt_file prompts/i2v.json \
+    --image_path https://example.com/frame.jpg \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml
+
+# T2AV: text-to-video with synchronized audio
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt_file prompts/t2av.json \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml
+
+# T2I: text-to-image
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt_file prompts/t2i.json \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml \
+    --output_path output.png
+
+# Inline prompt (--prompt or a JSON file path)
+python cosmos3.py --model nvidia/Cosmos3-Nano \
+    --prompt "A cute puppy playing with a ball in a park" \
+    --visual_gen_args ../configs/cosmos3-nano-1gpu.yaml
+```
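
The README refers to prompt files such as `prompts/t2av.json` without showing their contents. A minimal sketch of building one is given here, using only the keys that `cosmos3.py` in this commit reads (`prompt`, `vision_path`, `enable_audio`). The helper name `build_prompt_file` and the sample prompt text are hypothetical, and the real prompt files in the repository may carry additional fields.

```python
import json
from typing import Any, Dict, Optional


def build_prompt_file(
    prompt: str,
    *,
    vision_path: Optional[str] = None,
    enable_audio: bool = False,
) -> Dict[str, Any]:
    """Build a prompt-file dict with the keys cosmos3.py looks up."""
    if not prompt:
        # The example script rejects prompt files without a non-empty 'prompt'.
        raise ValueError("prompt must be a non-empty string")
    data: Dict[str, Any] = {"prompt": prompt}
    if vision_path is not None:
        # Local path, file:// or http(s):// URL, or data: URI per the README.
        data["vision_path"] = vision_path
    if enable_audio:
        data["enable_audio"] = True
    return data


if __name__ == "__main__":
    # T2AV-style prompt: text plus synchronized audio.
    print(json.dumps(build_prompt_file("Rain falling on a tin roof", enable_audio=True), indent=2))
```

Adding `vision_path` to the same dict would correspond to the image-conditioned audio-video (TI2AV) mode the README describes.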
Lines changed: 207 additions & 0 deletions
@@ -0,0 +1,207 @@
+#!/usr/bin/env python3
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import json
+import os
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+from tensorrt_llm import VisualGen, VisualGenArgs
+
+_SCRIPT_DIR = Path(__file__).resolve().parent
+
+
+def _resolve_path(path: str) -> str:
+    candidate = Path(path)
+    if candidate.is_file():
+        return str(candidate.resolve())
+    relative_to_script = _SCRIPT_DIR / path
+    if relative_to_script.is_file():
+        return str(relative_to_script.resolve())
+    return path
+
+
+def load_prompt_file(path: str) -> Dict[str, Any]:
+    """Load a Cosmos3 omni prompt JSON (``prompt``, optional ``vision_path``, etc.)."""
+    resolved = _resolve_path(path)
+    with open(resolved, encoding="utf-8") as f:
+        data = json.load(f)
+    if not isinstance(data, dict):
+        raise ValueError(f"Prompt file must be a JSON object, got {type(data)!r}.")
+    if not data.get("prompt"):
+        raise ValueError(f"Prompt file {resolved!r} is missing a non-empty 'prompt' field.")
+    return data
+
+
+def resolve_prompt_and_options(
+    *,
+    prompt: Optional[str],
+    prompt_file: Optional[str],
+    image_path: Optional[str],
+    enable_audio: bool,
+    output_type: str,
+) -> tuple[str, Optional[str], bool, str]:
+    """Merge CLI args with optional prompt-file defaults."""
+    prompt_data: Dict[str, Any] = {}
+    if prompt_file is not None:
+        prompt_data = load_prompt_file(prompt_file)
+
+    resolved_prompt = prompt
+    if resolved_prompt is None:
+        resolved_prompt = prompt_data.get("prompt")
+    if not resolved_prompt:
+        raise ValueError("Provide --prompt or --prompt_file with a 'prompt' field.")
+
+    resolved_image = image_path
+    if resolved_image is None:
+        resolved_image = prompt_data.get("vision_path") or prompt_data.get("image_path")
+
+    resolved_enable_audio = enable_audio or bool(prompt_data.get("enable_audio", False))
+
+    resolved_output_type = output_type
+    model_mode = str(prompt_data.get("model_mode", "")).lower()
+    if model_mode == "text2image" and output_type == "video":
+        resolved_output_type = "image"
+
+    return resolved_prompt, resolved_image, resolved_enable_audio, resolved_output_type
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Cosmos3 Text(+Image)-to-Video(+Audio) example")
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="nvidia/Cosmos3-Nano",
+        help="Model path or HuggingFace Hub ID (nvidia/Cosmos3-Nano, nvidia/Cosmos3-Super)",
+    )
+    parser.add_argument(
+        "--visual_gen_args",
+        dest="visual_gen_args",
+        type=str,
+        default=None,
+        help="Path to YAML config (same as trtllm-serve --visual_gen_args)",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default=None,
+        help="Text prompt for generation (overrides --prompt_file when both are set)",
+    )
+    parser.add_argument(
+        "--prompt_file",
+        type=str,
+        default="prompts/t2v.json",
+        help="Path to a JSON prompt file (default: prompts/t2v.json)",
+    )
+    parser.add_argument(
+        "--negative_prompt",
+        type=str,
+        default="cosmos3_negative_prompt.json",
+        help="Text prompt or path to JSON file for negative prompt",
+    )
+    parser.add_argument(
+        "--image_path",
+        type=str,
+        default=None,
+        help="Optional conditioning image path or URL for I2V/TI2V",
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="cosmos3_output.mp4",
+        help="Path to save the output video",
+    )
+    parser.add_argument(
+        "--disable_duration_template",
+        action="store_true",
+        help="Disable duration metadata template (enabled by default, matching cosmos-framework CLI)",
+    )
+    parser.add_argument(
+        "--disable_resolution_template",
+        action="store_true",
+        help="Disable resolution metadata template (enabled by default, matching cosmos-framework CLI)",
+    )
+    parser.add_argument(
+        "--use_system_prompt", action="store_true", help="Use system prompt in prompt"
+    )
+    parser.add_argument("--enable_audio", action="store_true", help="Enable audio generation")
+    parser.add_argument(
+        "--output_type", type=str, default="video", help="Output type (video, image)"
+    )
+
+    # Guardrails
+    parser.add_argument(
+        "--disable_guardrails", action="store_true", help="NOT RECOMMENDED: Disable guardrails"
+    )
+    args = parser.parse_args()
+
+    prompt, image_path, enable_audio, output_type = resolve_prompt_and_options(
+        prompt=args.prompt,
+        prompt_file=args.prompt_file,
+        image_path=args.image_path,
+        enable_audio=args.enable_audio,
+        output_type=args.output_type,
+    )
+
+    # Engine config from shared YAML (optional); model-specific defaults apply otherwise.
+    extra_args = VisualGenArgs.from_yaml(args.visual_gen_args) if args.visual_gen_args else None
+    visual_gen = VisualGen(model=args.model, args=extra_args)
+
+    # --- Model-specific: T2V / TI2V request construction ---
+    # Query per-model defaults (resolution, steps, guidance, seed, etc.).
+    params = visual_gen.default_params
+    if image_path is not None:
+        params.image = image_path
+
+    negative_prompt_path = _resolve_path(args.negative_prompt)
+    if args.negative_prompt is not None:
+        if os.path.isfile(negative_prompt_path) and negative_prompt_path.endswith(".json"):
+            with open(negative_prompt_path, encoding="utf-8") as f:
+                negative_prompt = json.load(f)
+        else:
+            negative_prompt = args.negative_prompt
+    else:
+        negative_prompt = None
+
+    if args.disable_duration_template:
+        params.extra_params["use_duration_template"] = False
+    if args.disable_resolution_template:
+        params.extra_params["use_resolution_template"] = False
+    params.extra_params["use_system_prompt"] = args.use_system_prompt
+    params.extra_params["enable_audio"] = enable_audio
+    params.extra_params["use_guardrails"] = not args.disable_guardrails
+    params.extra_params["output_type"] = output_type
+
+    if negative_prompt is None:
+        params.negative_prompt = None
+    elif isinstance(negative_prompt, str):
+        params.negative_prompt = negative_prompt
+    else:
+        params.negative_prompt = json.dumps(negative_prompt)
+
+    output = visual_gen.generate(
+        inputs=prompt,
+        params=params,
+    )
+
+    output.save(args.output_path)
+    print(f"Saved: {args.output_path}")
+    print(output.metrics)
+
+
+if __name__ == "__main__":
+    main()
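
The precedence rules in `resolve_prompt_and_options` are easier to see without the file I/O. The sketch below restates them over an already-loaded dict: CLI values win for the prompt and image, audio is the logical OR of both sources, and a `model_mode` of `text2image` switches the default `video` output type to `image`. The function name `merge_options` is hypothetical and exists only for illustration; the logic is a reading of the diff above, not a separate API.

```python
from typing import Any, Dict, Optional, Tuple


def merge_options(
    *,
    cli_prompt: Optional[str],
    cli_image: Optional[str],
    cli_audio: bool,
    cli_output_type: str,
    prompt_data: Dict[str, Any],
) -> Tuple[str, Optional[str], bool, str]:
    """Merge CLI values with prompt-file values (CLI takes precedence)."""
    prompt = cli_prompt if cli_prompt is not None else prompt_data.get("prompt")
    if not prompt:
        raise ValueError("a prompt is required from the CLI or the prompt file")

    image = cli_image
    if image is None:
        # 'vision_path' is preferred; 'image_path' is accepted as a fallback key.
        image = prompt_data.get("vision_path") or prompt_data.get("image_path")

    # Either source can turn audio on; neither can turn it off.
    audio = cli_audio or bool(prompt_data.get("enable_audio", False))

    output_type = cli_output_type
    mode = str(prompt_data.get("model_mode", "")).lower()
    if mode == "text2image" and cli_output_type == "video":
        output_type = "image"

    return prompt, image, audio, output_type


if __name__ == "__main__":
    file_data = {"prompt": "from file", "vision_path": "a.jpg", "enable_audio": True}
    print(merge_options(cli_prompt="from cli", cli_image=None, cli_audio=False,
                        cli_output_type="video", prompt_data=file_data))
```

Two consequences follow from the script as written. Because `--prompt_file` defaults to `prompts/t2v.json`, that file is loaded and validated even when `--prompt` is supplied. And since `--enable_audio` is a `store_true` flag, a prompt file with `enable_audio: true` cannot be overridden from the command line.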
