Commit 2e43c80

[2/4] Diffusion Quantized ckpt export (#810)
## What does this PR do?

**Type of change:** New feature

**Overview:** This PR adds HuggingFace checkpoint export support for LTX-2 by treating `TI2VidTwoStagesPipeline` as a diffusion-like pipeline, exporting only the stage-1 transformer (with QKV-fusion-enabled dummy inputs), and falling back to writing `model.safetensors` when `save_pretrained` isn't available. It also preserves the original forward in `DynamicModule` patching (`_forward_pre_dm`) so downstream callers can still invoke the pre-patched forward implementation.

**Changes**

1. Added calibration and quantization support for LTX-2, including FP8 precision.
2. Preserve the original forward before `DynamicModule` patching: when patching forward, we now stash the pre-patched implementation in `self._forward_pre_dm` (once) so downstream code can still call the original forward, then re-bind forward to the class implementation. This is needed for LTX-2 FP8 calibration.
3. Added an LTX-2 HF export path: `export_hf_checkpoint()` now also treats `ltx_pipelines.ti2vid_two_stages.TI2VidTwoStagesPipeline` as a "diffusion-like" object and routes it through `_export_diffusers_checkpoint()` (import guarded; no hard dependency).
4. Generalized component discovery: introduced `get_diffusion_components()` (aliasing the old `get_diffusers_components()`) to support non-diffusers pipelines; for LTX-2 it returns only `stage_1_transformer`.
5. Enabled QKV fusion for the LTX-2 backbone: added a model-aware dummy forward generator (`generate_diffusion_dummy_forward_fn`) that builds minimal LTX `Modality` inputs (including correct `timesteps` broadcasting) so shared-input hooks can run and fuse QKV when applicable.
6. Export fallback for modules without `save_pretrained`: when a component lacks `save_pretrained` (the LTX-2 transformer), export now writes `model.safetensors` plus a minimal `config.json` instead of `pytorch_model.bin`.
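The stash-and-rebind pattern from change (2) can be illustrated with a small standalone sketch. Only `_forward_pre_dm` comes from this PR; `Layer`, `PatchedLayer`, and `dm_patch` are hypothetical stand-ins, not ModelOpt's actual `DynamicModule` code:

```python
class Layer:
    """Stand-in for a module about to be patched (hypothetical)."""

    def forward(self, x):
        return x + 1  # original implementation


class PatchedLayer(Layer):
    """Stand-in for the dynamically patched class (hypothetical)."""

    def forward(self, x):
        # The patched path can still reach the original via _forward_pre_dm.
        return self._forward_pre_dm(x) * 10


def dm_patch(module, patched_cls):
    # Stash the pre-patch bound forward exactly once, so downstream
    # callers can still invoke the original implementation.
    if not hasattr(module, "_forward_pre_dm"):
        module._forward_pre_dm = module.forward
    # Re-bind forward to the (patched) class implementation.
    module.__class__ = patched_cls
    return module


layer = dm_patch(Layer(), PatchedLayer)
print(layer.forward(1))          # patched path: (1 + 1) * 10 = 20
print(layer._forward_pre_dm(1))  # original path: 1 + 1 = 2
```

Because the stashed bound method keeps a reference to the original function, swapping the class afterwards does not affect it.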
## Plans

- [x] [1/4] Add the basic functionality to support a limited set of image models with NVFP4 + FP8, with some refactoring of the previous LLM code and the diffusers example. PIC: @jingyu-ml
- [x] [2/4] Add support for more video-generation models. PIC: @jingyu-ml
- [ ] [3/4] Add test cases, refactor the docs, and update all related READMEs. PIC: @jingyu-ml
- [ ] [4/4] Add the final support for ComfyUI. PIC: @jingyu-ml

## Usage

```bash
python quantize.py --model ltx-2 --format fp4 --batch-size 64 --calib-size 1 --n-steps 40 \
  --extra-param checkpoint_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev-fp8.safetensors \
  --extra-param distilled_lora_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors \
  --extra-param spatial_upsampler_path=/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors \
  --extra-param gemma_root=/home/scratch.omniml_data_2/jingyux/models/LTX-2/gemma-3-12b-it-qat-q4_0-unquantized \
  --extra-param fp8transformer=true --hf-ckpt-dir ./ltx2-nvfp4
```

## Testing

## Before your PR is "*Ready for review*"

- **Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md)** and your commits are signed.
- **Is this change backward compatible?**: Yes
- **Did you write any new necessary tests?**: No
- **Did you add or update any necessary documentation?**: No
- **Did you update [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?**: No

## Additional Information

## Summary by CodeRabbit

### Release Notes

* **New Features**
  * Added LTX-2 video model support with complete quantization and export pipeline integration
  * Introduced `--extra-param` CLI option for flexible model configuration and parameter passing
  * Enhanced export capabilities with broader diffusion model compatibility
* **Chores**
  * Changed the default model data type from Half to BFloat16 for improved numerical stability

Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
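The export fallback described in change (6) can be sketched as follows. The helper name and `config.json` contents here are illustrative assumptions, not the PR's actual code; the real logic lives in the HF export path:

```python
import json
from pathlib import Path


def export_component(component, out_dir):
    """Sketch: prefer save_pretrained when present, else write raw
    safetensors weights plus a minimal config (hypothetical helper)."""
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    if hasattr(component, "save_pretrained"):
        # Normal diffusers-style export path.
        component.save_pretrained(out)
        return
    # Fallback: model.safetensors + minimal config.json
    # instead of pytorch_model.bin.
    from safetensors.torch import save_file

    save_file(component.state_dict(), str(out / "model.safetensors"))
    (out / "config.json").write_text(
        json.dumps({"model_type": type(component).__name__})
    )
```

The lazy `safetensors` import mirrors the PR's import-guarded approach: the dependency is only touched when the fallback branch actually runs.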
1 parent 87237e7 commit 2e43c80

13 files changed: 1008 additions & 503 deletions

Lines changed: 195 additions & 0 deletions

@@ -0,0 +1,195 @@
```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from pathlib import Path
from typing import Any

from models_utils import MODEL_DEFAULTS, ModelType
from pipeline_manager import PipelineManager
from quantize_config import CalibrationConfig
from tqdm import tqdm
from utils import load_calib_prompts


class Calibrator:
    """Handles model calibration for quantization."""

    def __init__(
        self,
        pipeline_manager: PipelineManager,
        config: CalibrationConfig,
        model_type: ModelType,
        logger: logging.Logger,
    ):
        """
        Initialize calibrator.

        Args:
            pipeline_manager: Pipeline manager with main and upsampler pipelines
            config: Calibration configuration
            model_type: Type of model being calibrated
            logger: Logger instance
        """
        self.pipeline_manager = pipeline_manager
        self.pipe = pipeline_manager.pipe
        self.pipe_upsample = pipeline_manager.pipe_upsample
        self.config = config
        self.model_type = model_type
        self.logger = logger

    def load_and_batch_prompts(self) -> list[list[str]]:
        """
        Load calibration prompts from file.

        Returns:
            List of batched calibration prompts
        """
        self.logger.info(f"Loading calibration prompts from {self.config.prompts_dataset}")
        if isinstance(self.config.prompts_dataset, Path):
            return load_calib_prompts(
                self.config.batch_size,
                self.config.prompts_dataset,
            )

        return load_calib_prompts(
            self.config.batch_size,
            self.config.prompts_dataset["name"],
            self.config.prompts_dataset["split"],
            self.config.prompts_dataset["column"],
        )

    def run_calibration(self, batched_prompts: list[list[str]]) -> None:
        """
        Run calibration steps on the pipeline.

        Args:
            batched_prompts: List of batched calibration prompts
        """
        self.logger.info(f"Starting calibration with {self.config.num_batches} batches")
        extra_args = MODEL_DEFAULTS.get(self.model_type, {}).get("inference_extra_args", {})

        with tqdm(total=self.config.num_batches, desc="Calibration", unit="batch") as pbar:
            for i, prompt_batch in enumerate(batched_prompts):
                if i >= self.config.num_batches:
                    break

                if self.model_type == ModelType.LTX2:
                    self._run_ltx2_calibration(prompt_batch, extra_args)
                elif self.model_type == ModelType.LTX_VIDEO_DEV:
                    # Special handling for LTX-Video
                    self._run_ltx_video_calibration(prompt_batch, extra_args)
                elif self.model_type in [ModelType.WAN22_T2V_14b, ModelType.WAN22_T2V_5b]:
                    # Special handling for WAN video models
                    self._run_wan_video_calibration(prompt_batch, extra_args)
                else:
                    common_args = {
                        "prompt": prompt_batch,
                        "num_inference_steps": self.config.n_steps,
                    }
                    self.pipe(**common_args, **extra_args).images
                pbar.update(1)
                self.logger.debug(f"Completed calibration batch {i + 1}/{self.config.num_batches}")
        self.logger.info("Calibration completed successfully")

    def _run_wan_video_calibration(
        self, prompt_batch: list[str], extra_args: dict[str, Any]
    ) -> None:
        kwargs = {}
        kwargs["negative_prompt"] = extra_args["negative_prompt"]
        kwargs["height"] = extra_args["height"]
        kwargs["width"] = extra_args["width"]
        kwargs["num_frames"] = extra_args["num_frames"]
        kwargs["guidance_scale"] = extra_args["guidance_scale"]
        if "guidance_scale_2" in extra_args:
            kwargs["guidance_scale_2"] = extra_args["guidance_scale_2"]
        kwargs["num_inference_steps"] = self.config.n_steps

        self.pipe(prompt=prompt_batch, **kwargs).frames

    def _run_ltx2_calibration(self, prompt_batch: list[str], extra_args: dict[str, Any]) -> None:
        from ltx_core.model.video_vae import TilingConfig

        prompt = prompt_batch[0]
        extra_params = self.pipeline_manager.config.extra_params
        kwargs = {
            "negative_prompt": extra_args.get(
                "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
            ),
            "seed": extra_params.get("seed", 0),
            "height": extra_params.get("height", extra_args.get("height", 1024)),
            "width": extra_params.get("width", extra_args.get("width", 1536)),
            "num_frames": extra_params.get("num_frames", extra_args.get("num_frames", 121)),
            "frame_rate": extra_params.get("frame_rate", extra_args.get("frame_rate", 24.0)),
            "num_inference_steps": self.config.n_steps,
            "cfg_guidance_scale": extra_params.get(
                "cfg_guidance_scale", extra_args.get("cfg_guidance_scale", 4.0)
            ),
            "images": extra_params.get("images", []),
            "tiling_config": extra_params.get("tiling_config", TilingConfig.default()),
        }
        self.pipe(prompt=prompt, **kwargs)

    def _run_ltx_video_calibration(
        self, prompt_batch: list[str], extra_args: dict[str, Any]
    ) -> None:
        """
        Run calibration for LTX-Video model using the full multi-stage pipeline.

        Args:
            prompt_batch: Batch of prompts
            extra_args: Model-specific arguments
        """
        # Extract specific args for LTX-Video
        expected_height = extra_args.get("height", 512)
        expected_width = extra_args.get("width", 704)
        num_frames = extra_args.get("num_frames", 121)
        negative_prompt = extra_args.get(
            "negative_prompt", "worst quality, inconsistent motion, blurry, jittery, distorted"
        )

        def round_to_nearest_resolution_acceptable_by_vae(height, width):
            height = height - (height % self.pipe.vae_spatial_compression_ratio)
            width = width - (width % self.pipe.vae_spatial_compression_ratio)
            return height, width

        downscale_factor = 2 / 3
        # Part 1: Generate video at smaller resolution
        downscaled_height, downscaled_width = (
            int(expected_height * downscale_factor),
            int(expected_width * downscale_factor),
        )
        downscaled_height, downscaled_width = round_to_nearest_resolution_acceptable_by_vae(
            downscaled_height, downscaled_width
        )

        # Generate initial latents at lower resolution
        latents = self.pipe(
            conditions=None,
            prompt=prompt_batch,
            negative_prompt=negative_prompt,
            width=downscaled_width,
            height=downscaled_height,
            num_frames=num_frames,
            num_inference_steps=self.config.n_steps,
            output_type="latent",
        ).frames

        # Part 2: Upscale generated video using latent upsampler (if available)
        if self.pipe_upsample is not None:
            _ = self.pipe_upsample(latents=latents, output_type="latent").frames

        # Part 3: Denoise the upscaled video with a few steps to improve texture.
        # However, in this example code we omit the upscale step since it's optional.
```
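The two-stage LTX-Video calibration above first downscales the target resolution by 2/3 and then snaps each dimension down to a multiple of the VAE's spatial compression ratio. A standalone sketch of that arithmetic, assuming a ratio of 32 (the actual value comes from `self.pipe.vae_spatial_compression_ratio`):

```python
def round_to_vae(height, width, ratio=32):
    # Snap each dimension down to a multiple of the VAE's
    # spatial compression ratio (ratio=32 is an assumed value).
    return height - height % ratio, width - width % ratio


expected_h, expected_w = 512, 704  # defaults from extra_args
down = 2 / 3                       # downscale_factor in the code above
h, w = int(expected_h * down), int(expected_w * down)
print((h, w))              # (341, 469)
print(round_to_vae(h, w))  # (320, 448)
```

So the stage-1 generation runs at 320x448 rather than the requested 512x704, and the latent upsampler recovers the higher resolution afterwards.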

examples/diffusers/quantization/models_utils.py

Lines changed: 63 additions & 1 deletion
```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from collections.abc import Callable
 from enum import Enum
 from typing import Any
@@ -42,6 +43,7 @@ class ModelType(str, Enum):
     FLUX_DEV = "flux-dev"
     FLUX_SCHNELL = "flux-schnell"
     LTX_VIDEO_DEV = "ltx-video-dev"
+    LTX2 = "ltx-2"
     WAN22_T2V_14b = "wan2.2-t2v-14b"
     WAN22_T2V_5b = "wan2.2-t2v-5b"
@@ -64,6 +66,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
     ModelType.SD3_MEDIUM: filter_func_default,
     ModelType.SD35_MEDIUM: filter_func_default,
     ModelType.LTX_VIDEO_DEV: filter_func_ltx_video,
+    ModelType.LTX2: filter_func_ltx_video,
     ModelType.WAN22_T2V_14b: filter_func_wan_video,
     ModelType.WAN22_T2V_5b: filter_func_wan_video,
 }
@@ -80,18 +83,20 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
     ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
     ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
     ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
+    ModelType.LTX2: "Lightricks/LTX-2",
     ModelType.WAN22_T2V_14b: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
     ModelType.WAN22_T2V_5b: "Wan-AI/Wan2.2-TI2V-5B-Diffusers",
 }
 
-MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline]] = {
+MODEL_PIPELINE: dict[ModelType, type[DiffusionPipeline] | None] = {
     ModelType.SDXL_BASE: DiffusionPipeline,
     ModelType.SDXL_TURBO: DiffusionPipeline,
     ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
     ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
     ModelType.FLUX_DEV: FluxPipeline,
     ModelType.FLUX_SCHNELL: FluxPipeline,
     ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
+    ModelType.LTX2: None,
     ModelType.WAN22_T2V_14b: WanPipeline,
     ModelType.WAN22_T2V_5b: WanPipeline,
 }
@@ -154,6 +159,18 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
             "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
         },
     },
+    ModelType.LTX2: {
+        "backbone": "transformer",
+        "dataset": _SD_PROMPTS_DATASET,
+        "inference_extra_args": {
+            "height": 1024,
+            "width": 1536,
+            "num_frames": 121,
+            "frame_rate": 24.0,
+            "cfg_guidance_scale": 4.0,
+            "negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
+        },
+    },
     ModelType.WAN22_T2V_14b: {
         **_WAN_BASE_CONFIG,
         "from_pretrained_extra_args": {
@@ -192,3 +209,48 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
         },
     },
 }
+
+
+def _coerce_extra_param_value(value: str) -> Any:
+    lowered = value.lower()
+    if lowered in {"true", "false"}:
+        return lowered == "true"
+    try:
+        return int(value)
+    except ValueError:
+        pass
+    try:
+        return float(value)
+    except ValueError:
+        return value
+
+
+def parse_extra_params(
+    kv_args: list[str], unknown_args: list[str], logger: logging.Logger
+) -> dict[str, Any]:
+    extra_params: dict[str, Any] = {}
+    for item in kv_args:
+        if "=" not in item:
+            raise ValueError(f"Invalid --extra-param value: '{item}'. Expected KEY=VALUE.")
+        key, value = item.split("=", 1)
+        extra_params[key] = _coerce_extra_param_value(value)
+
+    i = 0
+    while i < len(unknown_args):
+        token = unknown_args[i]
+        if token.startswith("--extra_param."):
+            key = token[len("--extra_param.") :]
+            value = "true"
+            if i + 1 < len(unknown_args) and not unknown_args[i + 1].startswith("--"):
+                value = unknown_args[i + 1]
+                i += 1
+            extra_params[key] = _coerce_extra_param_value(value)
+        elif token.startswith("--extra_param"):
+            raise ValueError(
+                "Use --extra_param.KEY VALUE or --extra-param KEY=VALUE for extra parameters."
+            )
+        else:
+            logger.warning("Ignoring unknown argument: %s", token)
+        i += 1
+
+    return extra_params
```
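The value coercion used by `--extra-param` above (booleans first, then ints, then floats, else the raw string) can be demonstrated with a standalone mirror of `_coerce_extra_param_value` (`coerce` is a local stand-in, not the module's function):

```python
def coerce(value: str):
    # Booleans first: "true"/"false" (case-insensitive) become bool.
    lowered = value.lower()
    if lowered in {"true", "false"}:
        return lowered == "true"
    # Then ints, then floats; anything else stays a string.
    try:
        return int(value)
    except ValueError:
        pass
    try:
        return float(value)
    except ValueError:
        return value


print(coerce("true"))   # True
print(coerce("40"))     # 40
print(coerce("4.0"))    # 4.0
print(coerce("ltx-2"))  # ltx-2
```

This is why `--extra-param fp8transformer=true` in the usage example arrives as a real `bool`, while checkpoint paths pass through unchanged as strings.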
