|
| 1 | +### :title API walkthrough |
| 2 | +### :order 0 |
| 3 | +from tensorrt_llm import VisualGen, VisualGenArgs |
| 4 | +from tensorrt_llm.visual_gen.args import CompilationConfig |
| 5 | + |
| 6 | + |
| 7 | +def main(): |
| 8 | + # 1. List supported models registered with the pipeline registry. |
| 9 | + print("\n=== Supported models ===") |
| 10 | + for hf_id in VisualGen.supported_models(): |
| 11 | + print(f" - {hf_id}") |
| 12 | + |
| 13 | + # 2. Inspect default pipeline_config knobs for the chosen model. These |
| 14 | + # are per-architecture runtime knobs (e.g. Lightricks/LTX-2's |
| 15 | + # ``text_encoder_path``); Wan-AI/Wan2.1-T2V-1.3B-Diffusers registers |
| 16 | + # none, so the dict is empty. |
| 17 | + pipeline_defaults = VisualGen.pipeline_config("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") |
| 18 | + print("\n=== Pipeline config defaults for Wan-AI/Wan2.1-T2V-1.3B-Diffusers ===") |
| 19 | + print(f" {pipeline_defaults or '(none)'}") |
| 20 | + |
| 21 | + # 3. Build VisualGenArgs. ``pipeline_config`` carries the per-architecture |
| 22 | + # knobs from step 2 (here we just forward the registered defaults; |
| 23 | + # real callers would override entries like ``text_encoder_path``). |
| 24 | + # ``compilation_config.skip_warmup`` skips the post-load warmup pass. |
| 25 | + visual_gen = VisualGen( |
| 26 | + model="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", |
| 27 | + args=VisualGenArgs( |
| 28 | + pipeline_config=pipeline_defaults, |
| 29 | + compilation_config=CompilationConfig(skip_warmup=True), |
| 30 | + ), |
| 31 | + ) |
| 32 | + |
| 33 | + # 4. Discover model-specific ``extra_params`` accepted by the loaded |
| 34 | + # pipeline. Wan-AI/Wan2.1-T2V-1.3B-Diffusers declares none; |
| 35 | + # Wan-AI/Wan2.2-T2V-A14B-Diffusers surfaces ``guidance_scale_2`` and |
| 36 | + # ``boundary_ratio`` here. |
| 37 | + specs = visual_gen.extra_param_specs |
| 38 | + print("\n=== Extra param specs (extra_params keys) ===") |
| 39 | + for name, spec in specs.items(): |
| 40 | + print(f" - {name}: {spec}") |
| 41 | + if not specs: |
| 42 | + print(" (none for this model)") |
| 43 | + |
| 44 | + # 5. Take the pipeline's resolved defaults (height/width/steps/etc.) |
| 45 | + # and override fields. ``default_params`` already pre-populates |
| 46 | + # ``params.extra_params`` with each declared spec's default, so the |
| 47 | + # override below shows how a caller would set a model-specific knob |
| 48 | + # -- no-op on Wan-AI/Wan2.1-T2V-1.3B-Diffusers, but the wiring is |
| 49 | + # the same on Wan-AI/Wan2.2-T2V-A14B-Diffusers where |
| 50 | + # ``extra_params["guidance_scale_2"]`` is honored. |
| 51 | + params = visual_gen.default_params |
| 52 | + # Wan requires num_frames of the form 4k+1; 1.25x the model default (81) |
| 53 | + # is 101.25, so we round to the nearest valid value, 101 (= 4*25 + 1). |
| 54 | + params.num_frames = 101 |
| 55 | + for name, spec in specs.items(): |
| 56 | + params.extra_params[name] = spec.default |
| 57 | + |
| 58 | + print("\n=== Request params ===") |
| 59 | + print(params.model_dump_json(indent=2)) |
| 60 | + |
| 61 | + output = visual_gen.generate(inputs="A cute cat playing piano in a sunny room", params=params) |
| 62 | + |
| 63 | + # 6. Persist to disk. ``save`` infers the container from the file |
| 64 | + # extension (.avi/.mp4) and uses the frame_rate carried on the |
| 65 | + # output. |
| 66 | + saved = output.save("api_walkthrough_output.avi") |
| 67 | + print(f"\nSaved: {saved}") |
| 68 | + |
| 69 | + |
| 70 | +if __name__ == "__main__": |
| 71 | + main() |
0 commit comments