Skip to content

Commit c3f6d98

Browse files
authored
[TRTLLM-13028][doc] Add VisualGen API walkthrough example and docs page (#14685)
Signed-off-by: Zhenhua Wang <zhenhuaw@nvidia.com>
1 parent 4ba59c0 commit c3f6d98

8 files changed

Lines changed: 129 additions & 2 deletions

File tree

docs/source/helper.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,18 @@ def extract_meta_info(filename: str) -> Optional[DocMeta]:
6464
def generate_examples():
6565
root_dir = Path(__file__).parent.parent.parent.resolve()
6666
ignore_list = {
67-
'__init__.py', 'quickstart_example.py', 'quickstart_advanced.py',
68-
'quickstart_multimodal.py', 'star_attention.py'
67+
'__init__.py',
68+
'quickstart_example.py',
69+
'quickstart_advanced.py',
70+
'quickstart_multimodal.py',
71+
'star_attention.py',
72+
# Older VisualGen example scripts without ### :title metadata; opt
73+
# in by adding the metadata block and removing the entry below.
74+
'visual_gen_flux.py',
75+
'visual_gen_ltx2.py',
76+
'visual_gen_wan_i2v.py',
77+
'visual_gen_wan_t2v.py',
78+
'visual_gen_mgmn_distributed.sh'
6979
}
7080
doc_dir = root_dir / "docs/source/examples"
7181

@@ -95,6 +105,13 @@ def collect_script_paths(examples_subdir: str) -> list[Path]:
95105
]
96106
serve_script_base_url = f"https://github.com/NVIDIA/TensorRT-LLM/blob/{commit_hash}/examples/serve"
97107

108+
# Collect source paths for VisualGen examples
109+
visual_gen_script_paths = collect_script_paths("visual_gen")
110+
visual_gen_doc_paths = [
111+
doc_dir / f"{path.stem}.rst" for path in visual_gen_script_paths
112+
]
113+
visual_gen_script_base_url = f"https://github.com/NVIDIA/TensorRT-LLM/blob/{commit_hash}/examples/visual_gen"
114+
98115
def _get_lines_without_metadata(filename: str) -> str:
99116
"""Get line ranges that exclude metadata lines.
100117
Returns a string like "5-10,15-20" for use in :lines: directive.
@@ -267,6 +284,18 @@ def write_index(metas: list[DocMeta], doc_template_path: Path,
267284
example_name="Online Serving Examples",
268285
section_order=[])
269286

287+
# Generate the toctree for VisualGen example scripts. No section_order
288+
# while the example set is small; add one alongside ### :section
289+
# metadata on the scripts once we have enough examples to group.
290+
visual_gen_metas = write_scripts(visual_gen_script_base_url,
291+
visual_gen_script_paths,
292+
visual_gen_doc_paths)
293+
write_index(metas=visual_gen_metas,
294+
doc_template_path=doc_dir / "llm_examples_index.template.rst_",
295+
doc_path=doc_dir / "visual_gen_examples.rst",
296+
example_name="VisualGen Examples",
297+
section_order=[])
298+
270299

271300
def extract_all_and_eval(file_path):
272301
''' Extract the __all__ variable from a Python file.

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ Welcome to TensorRT LLM's Documentation!
2323
:name: Deployment Guide
2424

2525
examples/llm_api_examples.rst
26+
examples/visual_gen_examples.rst
2627
examples/trtllm_serve_examples
2728
examples/dynamo_k8s_example.rst
2829
deployment-guide/index.rst
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
### :title API walkthrough
2+
### :order 0
3+
from tensorrt_llm import VisualGen, VisualGenArgs
4+
from tensorrt_llm.visual_gen.args import CompilationConfig
5+
6+
7+
def main():
    """Walk through the VisualGen API: discover, configure, generate, and save.

    The walkthrough prints the supported models, the pipeline-config defaults
    and the extra-parameter specs for one model, then runs a single generation
    request and writes the result to ``api_walkthrough_output.avi``.
    """
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

    # Step 1 -- enumerate the models registered with the pipeline registry.
    print("\n=== Supported models ===")
    for supported_id in VisualGen.supported_models():
        print(f" - {supported_id}")

    # Step 2 -- look up the default ``pipeline_config`` knobs for the chosen
    # model. These are per-architecture runtime knobs (for example
    # Lightricks/LTX-2's ``text_encoder_path``). This model registers none,
    # so the returned dict is empty.
    pipeline_defaults = VisualGen.pipeline_config(model_id)
    print("\n=== Pipeline config defaults for Wan-AI/Wan2.1-T2V-1.3B-Diffusers ===")
    print(f" {pipeline_defaults or '(none)'}")

    # Step 3 -- assemble VisualGenArgs. ``pipeline_config`` carries the
    # per-architecture knobs from step 2; the registered defaults are simply
    # forwarded here, whereas real callers would override entries such as
    # ``text_encoder_path``. ``compilation_config.skip_warmup`` skips the
    # post-load warmup pass.
    engine_args = VisualGenArgs(
        pipeline_config=pipeline_defaults,
        compilation_config=CompilationConfig(skip_warmup=True),
    )
    visual_gen = VisualGen(model=model_id, args=engine_args)

    # Step 4 -- discover the model-specific ``extra_params`` the loaded
    # pipeline accepts. This model declares none;
    # Wan-AI/Wan2.2-T2V-A14B-Diffusers surfaces ``guidance_scale_2`` and
    # ``boundary_ratio`` here.
    specs = visual_gen.extra_param_specs
    print("\n=== Extra param specs (extra_params keys) ===")
    for param_name, param_spec in specs.items():
        print(f" - {param_name}: {param_spec}")
    if not specs:
        print(" (none for this model)")

    # Step 5 -- start from the pipeline's resolved defaults
    # (height/width/steps/etc.) and override individual fields.
    # ``default_params`` already pre-populates ``params.extra_params`` with
    # each declared spec's default, so the loop below only demonstrates how a
    # caller would set a model-specific knob. It is a no-op for this model,
    # but the wiring is the same on Wan-AI/Wan2.2-T2V-A14B-Diffusers, where
    # ``extra_params["guidance_scale_2"]`` is honored.
    params = visual_gen.default_params
    # Wan requires num_frames of the form 4k+1. 1.25x the model default (81)
    # is 101.25, which rounds to the nearest valid value: 101 (= 4*25 + 1).
    params.num_frames = 101
    for param_name, param_spec in specs.items():
        params.extra_params[param_name] = param_spec.default

    print("\n=== Request params ===")
    print(params.model_dump_json(indent=2))

    prompt = "A cute cat playing piano in a sunny room"
    output = visual_gen.generate(inputs=prompt, params=params)

    # Step 6 -- persist the result. ``save`` infers the container from the
    # file extension (.avi/.mp4) and uses the frame_rate carried on the
    # output.
    saved = output.save("api_walkthrough_output.avi")
    print(f"\nSaved: {saved}")
68+
69+
70+
if __name__ == "__main__":
71+
main()

tests/integration/defs/examples/test_visual_gen.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,28 @@ def test_visual_gen_quickstart(_visual_gen_deps, llm_root, llm_venv):
11881188
assert os.path.isfile(output_path), f"Quickstart did not produce output.avi at {output_path}"
11891189

11901190

1191+
def test_visual_gen_api_walkthrough(_visual_gen_deps, llm_root, llm_venv):
    """Run examples/visual_gen/api_walkthrough.py end-to-end.

    The test is skipped when the Wan T2V checkpoint is missing from the models
    root. Otherwise the checkpoint is symlinked under ``Wan-AI/`` in the venv
    working directory, the script is executed, and the test asserts that
    ``api_walkthrough_output.avi`` exists in that working directory.
    """
    models_root = conftest.llm_models_root()
    model_src = os.path.join(models_root, WAN_T2V_MODEL_SUBPATH)
    if not os.path.isdir(model_src):
        skip_reason = (
            f"Model not found: {model_src} "
            f"(set LLM_MODELS_ROOT or place {WAN_T2V_MODEL_SUBPATH} under scratch)"
        )
        pytest.skip(skip_reason)

    # Link the checkpoint into the working directory unless a link is already
    # in place.
    model_dst = os.path.join(
        llm_venv.get_working_directory(), "Wan-AI", WAN_T2V_MODEL_SUBPATH
    )
    if not os.path.islink(model_dst):
        link_parent = os.path.dirname(model_dst)
        os.makedirs(link_parent, exist_ok=True)
        os.symlink(model_src, model_dst, target_is_directory=True)

    walkthrough_script = os.path.join(
        llm_root, "examples", "visual_gen", "api_walkthrough.py"
    )
    venv_check_call(llm_venv, [walkthrough_script])

    # The script saves its result under a fixed file name; check for it in
    # the venv working directory.
    output_path = os.path.join(
        llm_venv.get_working_directory(), "api_walkthrough_output.avi"
    )
    assert os.path.isfile(output_path), f"API walkthrough did not produce {output_path}"
1211+
1212+
11911213
# =============================================================================
11921214
# Core example tests — run per-model scripts from examples/visual_gen/models/
11931215
# with shared YAML configs from examples/visual_gen/configs/.

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ l0_dgx_b200:
304304
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=True-torch_compile=False]
305305
- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2[torch_compile=False-enable_gemm_allreduce_fusion=False]
306306
- examples/test_visual_gen.py::test_visual_gen_quickstart
307+
- examples/test_visual_gen.py::test_visual_gen_api_walkthrough
307308
- examples/test_visual_gen.py::test_wan_t2v_example
308309
- examples/test_visual_gen.py::test_flux1_lpips_against_golden
309310
- examples/test_visual_gen.py::test_flux2_lpips_against_golden

tests/integration/test_lists/test-db/l0_gh200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ l0_gh200:
2424
- unittest/llmapi/test_llm_quant.py
2525
- llmapi/test_llm_examples.py::test_llmapi_quickstart_atexit
2626
- examples/test_visual_gen.py::test_visual_gen_quickstart
27+
- examples/test_visual_gen.py::test_visual_gen_api_walkthrough
2728
- unittest/test_model_runner_cpp.py
2829
- accuracy/test_cli_flow.py::TestGptNext::test_auto_dtype
2930
- examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] TIMEOUT (90)

tests/integration/test_lists/test-db/l0_h100.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ l0_h100:
272272
- test_e2e.py::test_mistral_large_hidden_vocab_size
273273
- llmapi/test_llm_examples.py::test_llmapi_quickstart_atexit
274274
- examples/test_visual_gen.py::test_visual_gen_quickstart
275+
- examples/test_visual_gen.py::test_visual_gen_api_walkthrough
275276
- unittest/trt/attention/test_gpt_attention_IFB.py
276277
- accuracy/test_cli_flow.py::TestLlama3_1_8BInstruct::test_fp8_prequantized
277278
- accuracy/test_cli_flow.py::TestLlama2_7B::test_fp8

tests/integration/test_lists/test-db/l0_l40s.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ l0_l40s:
6464
- examples/test_nemotron_nas.py::test_nemotron_nas_summary_1gpu[DeciLM-7B]
6565
- llmapi/test_llm_examples.py::test_llmapi_quickstart
6666
- examples/test_visual_gen.py::test_visual_gen_quickstart
67+
- examples/test_visual_gen.py::test_visual_gen_api_walkthrough
6768
- llmapi/test_llm_examples.py::test_llmapi_example_inference
6869
- llmapi/test_llm_examples.py::test_llmapi_example_inference_async
6970
- llmapi/test_llm_examples.py::test_llmapi_example_inference_async_streaming

0 commit comments

Comments
 (0)