|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import logging |
| 17 | +from pathlib import Path |
| 18 | +from typing import Any |
| 19 | + |
| 20 | +from models_utils import MODEL_DEFAULTS, ModelType |
| 21 | +from pipeline_manager import PipelineManager |
| 22 | +from quantize_config import CalibrationConfig |
| 23 | +from tqdm import tqdm |
| 24 | +from utils import load_calib_prompts |
| 25 | + |
| 26 | + |
class Calibrator:
    """Drives calibration forward passes for post-training quantization.

    Calibration simply executes the pipeline on a set of prompts so that
    quantization observers attached to the model can collect activation
    statistics; every output is discarded — only the forward passes matter.
    """

    def __init__(
        self,
        pipeline_manager: PipelineManager,
        config: CalibrationConfig,
        model_type: ModelType,
        logger: logging.Logger,
    ):
        """
        Initialize calibrator.

        Args:
            pipeline_manager: Pipeline manager with main and upsampler pipelines
            config: Calibration configuration
            model_type: Type of model being calibrated
            logger: Logger instance
        """
        self.pipeline_manager = pipeline_manager
        # Cache the pipelines locally for convenience. ``pipe_upsample`` may be
        # None for models without a latent upsampler (only LTX-Video uses it).
        self.pipe = pipeline_manager.pipe
        self.pipe_upsample = pipeline_manager.pipe_upsample
        self.config = config
        self.model_type = model_type
        self.logger = logger

    def load_and_batch_prompts(self) -> list[list[str]]:
        """
        Load calibration prompts and group them into batches.

        ``config.prompts_dataset`` is either a local file path or a dataset
        spec mapping with "name"/"split"/"column" keys.

        Returns:
            List of prompt batches, each of size ``config.batch_size``.
        """
        self.logger.info("Loading calibration prompts from %s", self.config.prompts_dataset)
        if isinstance(self.config.prompts_dataset, Path):
            # Local prompt file.
            return load_calib_prompts(
                self.config.batch_size,
                self.config.prompts_dataset,
            )

        # Dataset spec: assumes the mapping carries "name", "split" and
        # "column" keys — validated upstream by CalibrationConfig.
        return load_calib_prompts(
            self.config.batch_size,
            self.config.prompts_dataset["name"],
            self.config.prompts_dataset["split"],
            self.config.prompts_dataset["column"],
        )

    def run_calibration(self, batched_prompts: list[list[str]]) -> None:
        """
        Run calibration steps on the pipeline.

        Consumes at most ``config.num_batches`` batches, dispatching each one
        to the model-specific calibration path; all outputs are discarded.

        Args:
            batched_prompts: List of batched calibration prompts
        """
        self.logger.info("Starting calibration with %d batches", self.config.num_batches)
        extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {})

        with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar:
            for i, prompt_batch in enumerate(batched_prompts):
                if i >= self.config.num_batches:
                    break

                if self.model_type == ModelType.LTX2:
                    self._run_ltx2_calibration(prompt_batch, extra_args)
                elif self.model_type == ModelType.LTX_VIDEO_DEV:
                    # Special handling for LTX-Video's multi-stage pipeline.
                    self._run_ltx_video_calibration(prompt_batch, extra_args)
                elif self.model_type in (ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b):
                    # Special handling for WAN video models.
                    self._run_wan_video_calibration(prompt_batch, extra_args)
                else:
                    # Default image path: plain text-to-image forward pass.
                    common_args = {
                        "prompt": prompt_batch,
                        "num_inference_steps": self.config.n_steps,
                    }
                    self.pipe(**common_args, **extra_args).images
                # Progress/debug must tick for every branch, not only the
                # default one.
                pbar.update(1)
                self.logger.debug(
                    "Completed calibration batch %d/%d", i + 1, self.config.num_batches
                )
        self.logger.info("Calibration completed successfully")

    def _run_wan_video_calibration(
        self, prompt_batch: list[str], extra_args: dict[str, Any]
    ) -> None:
        """Run one WAN text-to-video calibration pass.

        Assumes ``extra_args`` carries the required WAN keys (negative_prompt,
        height, width, num_frames, guidance_scale) — a missing key raises
        KeyError, which surfaces a misconfigured MODEL_DEFAULTS entry early.
        """
        kwargs: dict[str, Any] = {
            "negative_prompt": extra_args["negative_prompt"],
            "height": extra_args["height"],
            "width": extra_args["width"],
            "num_frames": extra_args["num_frames"],
            "guidance_scale": extra_args["guidance_scale"],
            "num_inference_steps": self.config.n_steps,
        }
        # Only the 14b dual-guidance variant defines a second scale.
        if "guidance_scale_2" in extra_args:
            kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]

        self.pipe(prompt=prompt_batch, **kwargs).frames

    def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None:
        """Run one LTX2 calibration pass.

        Only the first prompt of the batch is used — presumably the LTX2
        pipeline accepts a single prompt per call; confirm against the
        pipeline API. User-supplied ``extra_params`` override model defaults
        from ``extra_args``.
        """
        # Imported lazily: ltx_core is only required when calibrating LTX2.
        from ltx_core.model.video_vae import TilingConfig

        prompt = prompt_batch[0]
        extra_params = self.pipeline_manager.config.extra_params
        # Build the (potentially expensive) default tiling config only when no
        # override is present.
        tiling_config = (
            extra_params["tiling_config"]
            if "tiling_config" in extra_params
            else TilingConfig.default()
        )
        kwargs = {
            "negative_prompt": extra_args.get(
                "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
            ),
            "seed": extra_params.get("seed", 0),
            "height": extra_params.get("height", extra_args.get("height", 1024)),
            "width": extra_params.get("width", extra_args.get("width", 1536)),
            "num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)),
            "frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)),
            "num_inference_steps": self.config.n_steps,
            "cfg_guidance_scale": extra_params.get(
                "cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0)
            ),
            "images": extra_params.get("images", []),
            "tiling_config": tiling_config,
        }
        self.pipe(prompt=prompt, **kwargs)

    def _run_ltx_video_calibration(
        self, prompt_batch: list[str], extra_args: dict[str, Any]
    ) -> None:
        """
        Run calibration for LTX-Video model using the full multi-stage pipeline.

        Args:
            prompt_batch: Batch of prompts
            extra_args: Model-specific arguments
        """
        # Extract specific args for LTX-Video
        expected_height = extra_args.get("height", 512)
        expected_width = extra_args.get("width", 704)
        num_frames = extra_args.get("num_frames", 121)
        negative_prompt = extra_args.get(
            "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
        )

        def round_to_nearest_resolution_acceptable_by_vae(height, width):
            # The VAE requires dimensions divisible by its spatial
            # compression ratio; round both down to the nearest multiple.
            height = height - (height % self.pipe.vae_spatial_compression_ratio)
            width = width - (width % self.pipe.vae_spatial_compression_ratio)
            return height, width

        downscale_factor = 2 / 3
        # Part 1: Generate video at smaller resolution
        downscaled_height, downscaled_width = (
            int(expected_height * downscale_factor),
            int(expected_width * downscale_factor),
        )
        downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
            downscaled_height, downscaled_width
        )

        # Generate initial latents at lower resolution
        latents = self.pipe(
            conditions=None,
            prompt=prompt_batch,
            negative_prompt=negative_prompt,
            width=downscaled_width,
            height=downscaled_height,
            num_frames=num_frames,
            num_inference_steps=self.config.n_steps,
            output_type="latent",
        ).frames

        # Part 2: Upscale generated video using latent upsampler (if available)
        if self.pipe_upsample is not None:
            _ = self.pipe_upsample(latents=latents, output_type="latent").frames

        # Part 3: Denoise the upscaled video with few steps to improve texture
        # However, in this example code, we will omit the upscale step since it's optional.
0 commit comments