Commit eb5d549

jingyu-ml authored and kevalmorabia97 committed
Update the LTX2 API calls during the calibration (#926)
## What does this PR do?

**Type of change:** Bug fix

**Overview:** Update LTX-2 integration to match latest upstream API

The LTX-2 codebase removed/replaced several APIs. This MR updates all affected files:

1. Replace `cfg_guidance_scale` with `MultiModalGuiderParams`: The pipeline `__call__` no longer accepts a single `cfg_guidance_scale` float. It now requires two `MultiModalGuiderParams` objects (`video_guider_params` and `audio_guider_params`) that control CFG, STG, rescale, cross-modality guidance, and skip-step settings. Updated in `ltx-2.py`, `ltx-2-fp8.py`, `ltx-2-onestage.py`, `calibration.py`, and `models_utils.py`.
2. Replace `fp8transformer` with `QuantizationPolicy`: The `TI2VidTwoStagesPipeline` constructor no longer accepts the `fp8transformer` boolean flag. FP8 quantization is now configured via `quantization=QuantizationPolicy.fp8_cast()`. Updated in `ltx-2-fp8.py` and `pipeline_manager.py` (with backwards-compatible support for the old `--extra-param fp8transformer=true` CLI flag).
3. Remove `DEFAULT_CFG_GUIDANCE_SCALE` constant: Replaced by `DEFAULT_VIDEO_GUIDER_PARAMS` and `DEFAULT_AUDIO_GUIDER_PARAMS` in all import sites.

## Usage

```bash
python quantize.py \
    --model ltx-2 --format fp4 --batch-size 1 --calib-size 1 --n-steps 40 \
    --extra-param checkpoint_path=./ltx-2-19b-dev-fp8.safetensors \
    --extra-param distilled_lora_path=./ltx-2-19b-distilled-lora-384.safetensors \
    --extra-param spatial_upsampler_path=./ltx-2-spatial-upscaler-x2-1.0.safetensors \
    --extra-param gemma_root=./gemma-3-12b-it-qat-q4_0-unquantized \
    --extra-param fp8transformer=true \
    --hf-ckpt-dir ./ltx2-nvfp4
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes/No
- **Did you write any new necessary tests?**: Yes/No
- **Did you add or update any necessary documentation?**: Yes/No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: Yes/No

## Additional Information

## Summary by CodeRabbit

## Release Notes

* **Updates**
  * Default resolution for LTX2 models adjusted to 768x1280
  * Guidance parameter configuration updated for video and audio pipelines
  * FP8 quantization parameter handling refined

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent a30b2e8 commit eb5d549

3 files changed

Lines changed: 13 additions & 8 deletions

examples/diffusers/quantization/calibration.py

Lines changed: 6 additions & 3 deletions
@@ -121,6 +121,10 @@ def _run_wan_video_calibration(

     def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None:
         from ltx_core.model.video_vae import TilingConfig
+        from ltx_pipelines.utils.constants import (
+            DEFAULT_AUDIO_GUIDER_PARAMS,
+            DEFAULT_VIDEO_GUIDER_PARAMS,
+        )

         prompt = prompt_batch[0]
         extra_params = self.pipeline_manager.config.extra_params
@@ -134,9 +138,8 @@ def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, A
             "num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)),
             "frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)),
             "num_inference_steps": self.config.n_steps,
-            "cfg_guidance_scale": extra_params.get(
-                "cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0)
-            ),
+            "video_guider_params": DEFAULT_VIDEO_GUIDER_PARAMS,
+            "audio_guider_params": DEFAULT_AUDIO_GUIDER_PARAMS,
             "images": extra_params.get("images", []),
             "tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
         }
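
The unchanged keyword arguments in this hunk are resolved with a two-level fallback: a value in `extra_params` wins over one in `extra_args`, which wins over a hard-coded default. The guider parameters, by contrast, are now taken directly from the upstream constants without consulting either dictionary. A minimal sketch of the lookup order, using plain dictionaries and a hypothetical helper `resolve` that does not exist in the repository:

```python
from typing import Any


def resolve(key: str, extra_params: dict[str, Any], extra_args: dict[str, Any], default: Any) -> Any:
    """Mirror the expression extra_params.get(key, extra_args.get(key, default))."""
    return extra_params.get(key, extra_args.get(key, default))


# Per-model defaults, matching the values shown for LTX-2 in models_utils.py.
extra_args = {"num_frames": 121, "frame_rate": 24.0}

# User-supplied overrides (hypothetical values, for illustration only).
extra_params = {"frame_rate": 30.0}

num_frames = resolve("num_frames", extra_params, extra_args, 121)   # falls back to extra_args
frame_rate = resolve("frame_rate", extra_params, extra_args, 24.0)  # override wins
missing = resolve("not_set_anywhere", extra_params, extra_args, "fallback")  # hard-coded default
```

One consequence, assuming nothing else reads the key, is that a `cfg_guidance_scale` entry left in `extra_params` no longer influences the calibration call built here.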

examples/diffusers/quantization/models_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,11 +163,10 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
163163
"backbone": "transformer",
164164
"dataset": _SD_PROMPTS_DATASET,
165165
"inference_extra_args": {
166-
"height": 1024,
167-
"width": 1536,
166+
"height": 768,
167+
"width": 1280,
168168
"num_frames": 121,
169169
"frame_rate": 24.0,
170-
"cfg_guidance_scale": 4.0,
171170
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
172171
},
173172
},
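
As a quick check of what the new default resolution means for calibration cost, the arithmetic below compares per-frame pixel counts. Only the four dimensions come from the diff; the variable names are illustrative, and how pixel count translates into memory or runtime depends on the pipeline and is not shown here.

```python
# Old and new default resolutions for LTX-2, taken from the diff.
old_height, old_width = 1024, 1536
new_height, new_width = 768, 1280

old_pixels = old_height * old_width  # 1_572_864 pixels per frame
new_pixels = new_height * new_width  # 983_040 pixels per frame

# Fraction of the old per-frame pixel count that remains.
ratio = new_pixels / old_pixels  # 0.625
```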

examples/diffusers/quantization/pipeline_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,9 @@ def _create_ltx2_pipeline(self) -> Any:
213213
distilled_lora_strength = params.pop("distilled_lora_strength", 0.8)
214214
spatial_upsampler_path = params.pop("spatial_upsampler_path", None)
215215
gemma_root = params.pop("gemma_root", None)
216-
fp8transformer = params.pop("fp8transformer", False)
216+
fp8_quantization = params.pop("fp8_quantization", None) or params.pop(
217+
"fp8transformer", False
218+
)
217219

218220
if not checkpoint_path:
219221
raise ValueError("Missing required extra_param: checkpoint_path.")
@@ -225,6 +227,7 @@ def _create_ltx2_pipeline(self) -> Any:
225227
raise ValueError("Missing required extra_param: gemma_root.")
226228

227229
from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
230+
from ltx_core.quantization import QuantizationPolicy
228231
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
229232

230233
distilled_lora = [
@@ -240,7 +243,7 @@ def _create_ltx2_pipeline(self) -> Any:
240243
"spatial_upsampler_path": str(spatial_upsampler_path),
241244
"gemma_root": str(gemma_root),
242245
"loras": [],
243-
"fp8transformer": bool(fp8transformer),
246+
"quantization": QuantizationPolicy.fp8_cast() if fp8_quantization else None,
244247
}
245248
pipeline_kwargs.update(params)
246249
return TI2VidTwoStagesPipeline(**pipeline_kwargs)
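
The backwards-compatible flag handling above relies on `or` between two `dict.pop` calls. A minimal sketch of how that expression behaves on plain dictionaries follows; the helper names `pop_fp8_flag` and `pop_fp8_flag_eager` are hypothetical and not part of the repository. Because `or` short-circuits, the legacy key is only popped when the new key is absent or falsy, so supplying both keys leaves `fp8transformer` in `params`, from where `pipeline_kwargs.update(params)` would forward it. Whether that causes a problem depends on how the upstream constructor treats unexpected keyword arguments, which the diff does not show.

```python
from typing import Any


def pop_fp8_flag(params: dict[str, Any]) -> Any:
    """Same expression as in the diff: prefer the new key, fall back to the legacy one."""
    return params.pop("fp8_quantization", None) or params.pop("fp8transformer", False)


def pop_fp8_flag_eager(params: dict[str, Any]) -> Any:
    """Hypothetical variant that always consumes both keys before combining them."""
    new_value = params.pop("fp8_quantization", None)
    legacy_value = params.pop("fp8transformer", False)
    return new_value or legacy_value


# Legacy key only: the key is consumed and the flag is truthy.
legacy_only = {"fp8transformer": True}
legacy_flag = pop_fp8_flag(legacy_only)

# Both keys set: the second pop never runs, so the legacy key is left behind.
both_keys = {"fp8_quantization": True, "fp8transformer": True}
both_flag = pop_fp8_flag(both_keys)

# The eager variant removes both keys regardless of their values.
both_keys_eager = {"fp8_quantization": True, "fp8transformer": True}
eager_flag = pop_fp8_flag_eager(both_keys_eager)
```

The sketch uses Python booleans; how `--extra-param fp8transformer=true` is parsed into a value before it reaches `params` is outside this diff.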
