|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""LTX-2 inference with skip-softmax sparse attention. |
| 17 | +
|
| 18 | +This example applies skip-softmax sparse attention to the LTX-2 video |
| 19 | +generation model using exponential model calibration |
| 20 | +(``scale_factor = a * exp(b * target_sparsity)``). |
| 21 | +
|
| 22 | +During calibration, ``flash_skip_softmax`` with the eager attention backend |
| 23 | +collects sparsity statistics across multiple threshold trials. The fitted |
| 24 | +exponential model then allows runtime control of the target sparsity ratio |
| 25 | +without recalibration. |
| 26 | +
|
| 27 | +Only the stage-1 backbone is sparsified. Stage 2 (spatial upsampler + |
| 28 | +distilled LoRA) runs unmodified. |
| 29 | +
|
| 30 | +Usage:: |
| 31 | +
|
| 32 | + # With calibration (recommended) |
| 33 | + python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ |
| 34 | + --calibrate --target-sparsity 0.25 |
| 35 | +
|
| 36 | + # Disable sparsity on first/last 2 layers (higher quality, less speedup) |
| 37 | + python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\ |
| 38 | + --calibrate --target-sparsity 0.25 --skip-first-last 2 |
| 39 | +""" |
| 40 | + |
| 41 | +import argparse |
| 42 | +import functools |
| 43 | +import os |
| 44 | + |
| 45 | +import torch |
| 46 | +from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps |
| 47 | +from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number |
| 48 | +from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline |
| 49 | +from ltx_pipelines.utils.constants import ( |
| 50 | + AUDIO_SAMPLE_RATE, |
| 51 | + DEFAULT_2_STAGE_HEIGHT, |
| 52 | + DEFAULT_2_STAGE_WIDTH, |
| 53 | + DEFAULT_AUDIO_GUIDER_PARAMS, |
| 54 | + DEFAULT_FRAME_RATE, |
| 55 | + DEFAULT_NEGATIVE_PROMPT, |
| 56 | + DEFAULT_NUM_INFERENCE_STEPS, |
| 57 | + DEFAULT_SEED, |
| 58 | + DEFAULT_VIDEO_GUIDER_PARAMS, |
| 59 | +) |
| 60 | +from ltx_pipelines.utils.media_io import encode_video |
| 61 | + |
| 62 | +import modelopt.torch.sparsity.attention_sparsity as mtsa |
| 63 | +from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule |
| 64 | + |
# ---- Model paths (edit these or override via environment variables) ----
# NOTE(review): the fallback defaults below point at a personal scratch area;
# any other user must set the LTX2_* environment variables for this script
# to find the checkpoints.
CHECKPOINT_PATH = os.environ.get(
    "LTX2_CHECKPOINT",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
)
DISTILLED_LORA_PATH = os.environ.get(
    "LTX2_DISTILLED_LORA",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors",
)
SPATIAL_UPSAMPLER_PATH = os.environ.get(
    "LTX2_SPATIAL_UPSAMPLER",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors",
)
GEMMA_ROOT = os.environ.get(
    "LTX2_GEMMA_ROOT",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/gemma-3-12b-it-qat-q4_0-unquantized",
)

# Default number of output frames for generation (CLI --num-frames).
DEFAULT_NUM_FRAMES = 121
# Stage-1 backbone depth; used to resolve "last N layers" in --skip-first-last.
NUM_TRANSFORMER_BLOCKS = 48

# Default threshold trials for calibration.
# Spans roughly six orders of magnitude (1e-6 .. 0.7) so the exponential fit
# observes both near-dense and highly sparse operating points.
DEFAULT_THRESHOLD_TRIALS = [
    1e-6,
    5e-6,
    1e-5,
    5e-5,
    1e-4,
    5e-4,
    1e-3,
    5e-3,
    1e-2,
    2e-2,
    5e-2,
    1e-1,
    2e-1,
    3e-1,
    5e-1,
    7e-1,
]
| 105 | + |
| 106 | + |
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the skip-softmax LTX-2 example."""
    ap = argparse.ArgumentParser(
        description="LTX-2 video generation with skip-softmax sparse attention"
    )

    # Prompt / output selection.
    ap.add_argument("--prompt", type=str, default=None, help="Text prompt for generation")
    ap.add_argument(
        "--prompt-dir",
        type=str,
        default=None,
        help="Directory of .txt prompt files (one prompt per file). Overrides --prompt.",
    )
    ap.add_argument("--output", type=str, default="output.mp4", help="Output video path")
    ap.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory to save videos when using --prompt-dir",
    )

    # Generation geometry / determinism.
    ap.add_argument("--num-frames", type=int, default=DEFAULT_NUM_FRAMES, help="Number of frames")
    ap.add_argument("--height", type=int, default=DEFAULT_2_STAGE_HEIGHT, help="Video height")
    ap.add_argument("--width", type=int, default=DEFAULT_2_STAGE_WIDTH, help="Video width")
    ap.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed")

    # Sparse attention options.
    ap.add_argument(
        "--skip-first-last",
        type=int,
        default=0,
        help="Number of first/last transformer layers to keep dense (default: 0)",
    )

    # Calibration options.
    ap.add_argument(
        "--calibrate",
        action="store_true",
        help="Calibrate threshold via exponential model (recommended)",
    )
    ap.add_argument(
        "--target-sparsity",
        type=float,
        default=0.25,
        help="Target sparsity ratio for calibration (0.0-1.0)",
    )
    ap.add_argument(
        "--calib-steps", type=int, default=10, help="Inference steps per calibration sample"
    )
    ap.add_argument(
        "--calib-frames", type=int, default=81, help="Number of frames per calibration sample"
    )
    ap.add_argument(
        "--calib-size", type=int, default=1, help="Number of prompts to use for calibration"
    )
    return ap.parse_args()
| 171 | + |
| 172 | + |
def _patch_vae_requires_grad(pipeline: TI2VidTwoStagesPipeline):
    """Ensure VAE decoder weights have requires_grad=False to avoid autograd issues."""

    def _wrap(loader):
        # Factory function captures a fresh binding of ``loader`` per patched
        # attribute, avoiding the late-binding-closure pitfall in the loops below.
        @functools.wraps(loader)
        def patched():
            decoder = loader()
            decoder.requires_grad_(False)
            return decoder

        return patched

    for stage_attr in ("stage_1_model_ledger", "stage_2_model_ledger"):
        model_ledger = getattr(pipeline, stage_attr, None)
        if model_ledger is None:
            continue
        for decoder_attr in ("video_decoder", "audio_decoder"):
            loader_fn = getattr(model_ledger, decoder_attr, None)
            if loader_fn is not None:
                setattr(model_ledger, decoder_attr, _wrap(loader_fn))
| 194 | + |
| 195 | + |
def build_pipeline() -> TI2VidTwoStagesPipeline:
    """Build the LTX-2 two-stage video generation pipeline."""
    distilled = LoraPathStrengthAndSDOps(
        DISTILLED_LORA_PATH, 0.8, LTXV_LORA_COMFY_RENAMING_MAP
    )
    pipe = TI2VidTwoStagesPipeline(
        checkpoint_path=CHECKPOINT_PATH,
        distilled_lora=[distilled],
        spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH,
        gemma_root=GEMMA_ROOT,
        loras=[],
    )
    # Freeze the VAE decoder weights up-front so decoding never tracks gradients.
    _patch_vae_requires_grad(pipe)
    return pipe
| 209 | + |
| 210 | + |
| 211 | +def build_sparse_config(args: argparse.Namespace) -> dict: |
| 212 | + """Build sparse attention config from CLI args. |
| 213 | +
|
| 214 | + Uses flash_skip_softmax which supports both calibration (eager attention |
| 215 | + with F.softmax patching) and inference. Calibration fits an exponential |
| 216 | + model: scale_factor = a * exp(b * sparsity). |
| 217 | + """ |
| 218 | + attn_cfg: dict = { |
| 219 | + "method": "flash_skip_softmax", |
| 220 | + "thresholds": {"prefill": [1e-3]}, |
| 221 | + "br": 128, |
| 222 | + "bc": 128, |
| 223 | + "backend": "pytorch", |
| 224 | + "is_causal": False, # Diffusion = bidirectional attention |
| 225 | + "collect_stats": True, |
| 226 | + "enable": True, |
| 227 | + } |
| 228 | + |
| 229 | + sparse_cfg: dict = { |
| 230 | + "*.attn1": attn_cfg, # Self-attention only |
| 231 | + # Disable on all cross-attention and cross-modal attention |
| 232 | + "*.attn2": {"enable": False}, |
| 233 | + "*audio_attn1*": {"enable": False}, |
| 234 | + "*audio_attn2*": {"enable": False}, |
| 235 | + "*audio_to_video_attn*": {"enable": False}, |
| 236 | + "*video_to_audio_attn*": {"enable": False}, |
| 237 | + "default": {"enable": False}, |
| 238 | + } |
| 239 | + |
| 240 | + # Keep first/last N layers dense for quality |
| 241 | + for i in range(args.skip_first_last): |
| 242 | + sparse_cfg[f"*transformer_blocks.{i}.attn*"] = {"enable": False} |
| 243 | + sparse_cfg[f"*transformer_blocks.{NUM_TRANSFORMER_BLOCKS - 1 - i}.attn*"] = { |
| 244 | + "enable": False |
| 245 | + } |
| 246 | + |
| 247 | + config: dict = {"sparse_cfg": sparse_cfg} |
| 248 | + |
| 249 | + # Add calibration config with threshold trials |
| 250 | + if args.calibrate: |
| 251 | + sparse_cfg["calibration"] = { |
| 252 | + "target_sparse_ratio": {"prefill": args.target_sparsity}, |
| 253 | + "samples": args.calib_size, |
| 254 | + "threshold_trials": DEFAULT_THRESHOLD_TRIALS, |
| 255 | + } |
| 256 | + |
| 257 | + return config |
| 258 | + |
| 259 | + |
def load_calib_prompts(calib_size: int) -> list[str]:
    """Load calibration prompts from OpenVid-1M dataset."""
    from datasets import load_dataset

    captions = load_dataset("nkp37/OpenVid-1M")["train"]["caption"]
    prompts = list(captions[:calib_size])
    print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
    return prompts
| 268 | + |
| 269 | + |
def build_calibration_forward_loop(
    pipeline: TI2VidTwoStagesPipeline,
    num_steps: int = 10,
    num_frames: int = 81,
    calib_size: int = 1,
):
    """Build a forward loop for exponential model calibration.

    Generates short videos to exercise the attention mechanism at various
    threshold trials, collecting sparsity statistics for the exponential fit.
    """
    prompts = load_calib_prompts(calib_size)
    tiling = TilingConfig.default()
    total = len(prompts)

    def forward_loop(model):
        # ``model`` is intentionally unused: the pipeline already holds the
        # transformer being calibrated, so invoking the pipeline exercises it.
        for idx, text in enumerate(prompts, start=1):
            print(f"Calibration [{idx}/{total}]: {text[:60]}...")
            pipeline(
                prompt=text,
                negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                seed=DEFAULT_SEED,
                height=DEFAULT_2_STAGE_HEIGHT,
                width=DEFAULT_2_STAGE_WIDTH,
                num_frames=num_frames,
                frame_rate=DEFAULT_FRAME_RATE,
                num_inference_steps=num_steps,
                video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS,
                audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS,
                images=[],
                tiling_config=tiling,
            )

    return forward_loop
| 303 | + |
| 304 | + |
def print_sparsity_summary(transformer: torch.nn.Module) -> None:
    """Print per-module sparsity statistics."""
    active = []
    inactive = []
    for name, module in transformer.named_modules():
        if not isinstance(module, SparseAttentionModule):
            continue
        if module.is_enabled:
            active.append((name, module))
        else:
            inactive.append(name)

    print(f"\nSparse attention: {len(active)} enabled, {len(inactive)} disabled")
    for name, module in active:
        info = module.get_threshold_info()
        print(f"  {name}: {info}")
| 319 | + |
| 320 | + |
def _collect_jobs(args: argparse.Namespace) -> list[tuple[str, str]]:
    """Resolve the list of (prompt, output_path) generation jobs from CLI args.

    With ``--prompt-dir``, one job is created per ``.txt`` file (sorted by
    filename), writing ``<stem>.mp4`` into ``--output-dir`` (default
    ``output_videos``). Otherwise a single job uses ``--prompt``/``--output``.

    Raises:
        ValueError: if neither --prompt nor --prompt-dir is given, or if
            --prompt-dir contains no .txt files (previously this generated
            nothing silently).
    """
    jobs: list[tuple[str, str]] = []
    if args.prompt_dir:
        output_dir = args.output_dir or "output_videos"
        os.makedirs(output_dir, exist_ok=True)
        prompt_files = sorted(f for f in os.listdir(args.prompt_dir) if f.endswith(".txt"))
        if not prompt_files:
            raise ValueError(f"No .txt prompt files found in {args.prompt_dir}")
        for pf in prompt_files:
            with open(os.path.join(args.prompt_dir, pf)) as f:
                prompt = f.read().strip()
            stem = os.path.splitext(pf)[0]
            jobs.append((prompt, os.path.join(output_dir, f"{stem}.mp4")))
    elif args.prompt:
        jobs.append((args.prompt, args.output))
    else:
        raise ValueError("Either --prompt or --prompt-dir must be provided")
    return jobs


def main() -> None:
    """Entry point: sparsify the stage-1 transformer, then generate videos."""
    args = parse_args()

    # Resolve prompts/outputs up front so bad arguments fail fast, before the
    # expensive pipeline build and (optional) calibration run.
    jobs = _collect_jobs(args)

    # ---- Build pipeline ----
    print("Building LTX-2 pipeline...")
    pipeline = build_pipeline()

    # ---- Get and sparsify the stage-1 transformer ----
    transformer = pipeline.stage_1_model_ledger.transformer()
    # Pin the transformer in memory so the pipeline reuses the sparsified
    # instance instead of re-loading a dense copy from the ledger.
    pipeline.stage_1_model_ledger.transformer = lambda: transformer

    config = build_sparse_config(args)
    forward_loop = None
    if args.calibrate:
        forward_loop = build_calibration_forward_loop(
            pipeline,
            num_steps=args.calib_steps,
            num_frames=args.calib_frames,
            calib_size=args.calib_size,
        )

    print("Applying skip-softmax sparse attention...")
    mtsa.sparsify(transformer, config, forward_loop=forward_loop)

    # ---- Generate ----
    tiling_config = TilingConfig.default()
    for i, (prompt, output_path) in enumerate(jobs):
        print(f"\nGenerating [{i + 1}/{len(jobs)}]: {prompt[:80]}...")

        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=DEFAULT_NEGATIVE_PROMPT,
            seed=args.seed,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
            video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS,
            audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS,
            images=[],
            tiling_config=tiling_config,
        )

        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            audio_sample_rate=AUDIO_SAMPLE_RATE,
            output_path=output_path,
            video_chunks_number=get_video_chunks_number(args.num_frames, tiling_config),
        )
        print(f"Saved to {output_path}")

    # ---- Print stats ----
    print_sparsity_summary(transformer)


if __name__ == "__main__":
    main()