Skip to content

Commit 1d8ac33

Browse files
committed
Add the Skip softmax diffusion
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent de55e8a commit 1d8ac33

File tree

13 files changed

+1429
-32
lines changed

13 files changed

+1429
-32
lines changed
Lines changed: 397 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,397 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16+
"""LTX-2 inference with skip-softmax sparse attention.
17+
18+
This example applies skip-softmax sparse attention to the LTX-2 video
19+
generation model using exponential model calibration
20+
(``scale_factor = a * exp(b * target_sparsity)``).
21+
22+
During calibration, ``flash_skip_softmax`` with the eager attention backend
23+
collects sparsity statistics across multiple threshold trials. The fitted
24+
exponential model then allows runtime control of the target sparsity ratio
25+
without recalibration.
26+
27+
Only the stage-1 backbone is sparsified. Stage 2 (spatial upsampler +
28+
distilled LoRA) runs unmodified.
29+
30+
Usage::
31+
32+
# With calibration (recommended)
33+
python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\
34+
--calibrate --target-sparsity 0.25
35+
36+
# Disable sparsity on first/last 2 layers (higher quality, less speedup)
37+
python ltx2_skip_softmax.py --prompt "A cat playing piano" --output out.mp4 \\
38+
--calibrate --target-sparsity 0.25 --skip-first-last 2
39+
"""
40+
41+
import argparse
42+
import functools
43+
import os
44+
45+
import torch
46+
from ltx_core.loader import LTXV_LORA_COMFY_RENAMING_MAP, LoraPathStrengthAndSDOps
47+
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
48+
from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
49+
from ltx_pipelines.utils.constants import (
50+
AUDIO_SAMPLE_RATE,
51+
DEFAULT_2_STAGE_HEIGHT,
52+
DEFAULT_2_STAGE_WIDTH,
53+
DEFAULT_AUDIO_GUIDER_PARAMS,
54+
DEFAULT_FRAME_RATE,
55+
DEFAULT_NEGATIVE_PROMPT,
56+
DEFAULT_NUM_INFERENCE_STEPS,
57+
DEFAULT_SEED,
58+
DEFAULT_VIDEO_GUIDER_PARAMS,
59+
)
60+
from ltx_pipelines.utils.media_io import encode_video
61+
62+
import modelopt.torch.sparsity.attention_sparsity as mtsa
63+
from modelopt.torch.sparsity.attention_sparsity.sparse_attention import SparseAttentionModule
64+
65+
# ---- Model paths (edit these or override via environment variables) ----
def _model_path(env_var: str, default: str) -> str:
    """Resolve a model artifact path, preferring the environment override."""
    return os.environ.get(env_var, default)


CHECKPOINT_PATH = _model_path(
    "LTX2_CHECKPOINT",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-dev.safetensors",
)
DISTILLED_LORA_PATH = _model_path(
    "LTX2_DISTILLED_LORA",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-19b-distilled-lora-384.safetensors",
)
SPATIAL_UPSAMPLER_PATH = _model_path(
    "LTX2_SPATIAL_UPSAMPLER",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/ltx-2-spatial-upscaler-x2-1.0.safetensors",
)
GEMMA_ROOT = _model_path(
    "LTX2_GEMMA_ROOT",
    "/home/scratch.omniml_data_2/jingyux/models/LTX-2/gemma-3-12b-it-qat-q4_0-unquantized",
)
82+
83+
# Default number of frames generated per video.
DEFAULT_NUM_FRAMES = 121
# Depth of the stage-1 transformer; used to address first/last layers by index.
NUM_TRANSFORMER_BLOCKS = 48

# Threshold sweep used during calibration, ascending from 1e-6 to 0.7.
DEFAULT_THRESHOLD_TRIALS = [
    1e-6, 5e-6,
    1e-5, 5e-5,
    1e-4, 5e-4,
    1e-3, 5e-3,
    1e-2, 2e-2, 5e-2,
    1e-1, 2e-1, 3e-1, 5e-1, 7e-1,
]
105+
106+
107+
def _add_generation_args(parser: argparse.ArgumentParser) -> None:
    # Prompt/output selection plus basic generation geometry.
    parser.add_argument("--prompt", type=str, default=None, help="Text prompt for generation")
    parser.add_argument(
        "--prompt-dir",
        type=str,
        default=None,
        help="Directory of .txt prompt files (one prompt per file). Overrides --prompt.",
    )
    parser.add_argument("--output", type=str, default="output.mp4", help="Output video path")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory to save videos when using --prompt-dir",
    )
    parser.add_argument(
        "--num-frames", type=int, default=DEFAULT_NUM_FRAMES, help="Number of frames"
    )
    parser.add_argument("--height", type=int, default=DEFAULT_2_STAGE_HEIGHT, help="Video height")
    parser.add_argument("--width", type=int, default=DEFAULT_2_STAGE_WIDTH, help="Video width")
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed")


def _add_sparsity_args(parser: argparse.ArgumentParser) -> None:
    # Skip-softmax sparse attention knobs.
    parser.add_argument(
        "--skip-first-last",
        type=int,
        default=0,
        help="Number of first/last transformer layers to keep dense (default: 0)",
    )


def _add_calibration_args(parser: argparse.ArgumentParser) -> None:
    # Exponential-model calibration knobs.
    parser.add_argument(
        "--calibrate",
        action="store_true",
        help="Calibrate threshold via exponential model (recommended)",
    )
    parser.add_argument(
        "--target-sparsity",
        type=float,
        default=0.25,
        help="Target sparsity ratio for calibration (0.0-1.0)",
    )
    parser.add_argument(
        "--calib-steps",
        type=int,
        default=10,
        help="Inference steps per calibration sample",
    )
    parser.add_argument(
        "--calib-frames",
        type=int,
        default=81,
        help="Number of frames per calibration sample",
    )
    parser.add_argument(
        "--calib-size",
        type=int,
        default=1,
        help="Number of prompts to use for calibration",
    )


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for skip-softmax LTX-2 generation."""
    parser = argparse.ArgumentParser(
        description="LTX-2 video generation with skip-softmax sparse attention"
    )
    _add_generation_args(parser)
    _add_sparsity_args(parser)
    _add_calibration_args(parser)
    return parser.parse_args()
171+
172+
173+
def _patch_vae_requires_grad(pipeline: TI2VidTwoStagesPipeline):
    """Wrap decoder loaders so freshly loaded VAE weights are frozen.

    Each ``video_decoder``/``audio_decoder`` loader on the stage-1 and
    stage-2 model ledgers is replaced by a wrapper that calls the original
    loader and then sets ``requires_grad_(False)`` on the returned model,
    avoiding autograd issues downstream.
    """

    def _freeze_after_load(load_fn):
        # Closure over the original loader; @wraps keeps its metadata intact.
        @functools.wraps(load_fn)
        def wrapper():
            decoder = load_fn()
            decoder.requires_grad_(False)
            return decoder

        return wrapper

    for ledger_attr in ("stage_1_model_ledger", "stage_2_model_ledger"):
        ledger = getattr(pipeline, ledger_attr, None)
        if ledger is None:
            continue
        for attr in ("video_decoder", "audio_decoder"):
            loader = getattr(ledger, attr, None)
            if loader is not None:
                setattr(ledger, attr, _freeze_after_load(loader))
194+
195+
196+
def build_pipeline() -> TI2VidTwoStagesPipeline:
    """Construct the two-stage LTX-2 pipeline with frozen VAE decoders.

    Loads the base checkpoint with the distilled LoRA (strength 0.8,
    ComfyUI key renaming), the stage-2 spatial upsampler, and the Gemma
    text encoder, then patches decoder loaders via
    ``_patch_vae_requires_grad``.
    """
    distilled_lora_spec = LoraPathStrengthAndSDOps(
        DISTILLED_LORA_PATH, 0.8, LTXV_LORA_COMFY_RENAMING_MAP
    )
    pipeline = TI2VidTwoStagesPipeline(
        checkpoint_path=CHECKPOINT_PATH,
        distilled_lora=[distilled_lora_spec],
        spatial_upsampler_path=SPATIAL_UPSAMPLER_PATH,
        gemma_root=GEMMA_ROOT,
        loras=[],
    )
    _patch_vae_requires_grad(pipeline)
    return pipeline
209+
210+
211+
def build_sparse_config(args: argparse.Namespace) -> dict:
212+
"""Build sparse attention config from CLI args.
213+
214+
Uses flash_skip_softmax which supports both calibration (eager attention
215+
with F.softmax patching) and inference. Calibration fits an exponential
216+
model: scale_factor = a * exp(b * sparsity).
217+
"""
218+
attn_cfg: dict = {
219+
"method": "flash_skip_softmax",
220+
"thresholds": {"prefill": [1e-3]},
221+
"br": 128,
222+
"bc": 128,
223+
"backend": "pytorch",
224+
"is_causal": False, # Diffusion = bidirectional attention
225+
"collect_stats": True,
226+
"enable": True,
227+
}
228+
229+
sparse_cfg: dict = {
230+
"*.attn1": attn_cfg, # Self-attention only
231+
# Disable on all cross-attention and cross-modal attention
232+
"*.attn2": {"enable": False},
233+
"*audio_attn1*": {"enable": False},
234+
"*audio_attn2*": {"enable": False},
235+
"*audio_to_video_attn*": {"enable": False},
236+
"*video_to_audio_attn*": {"enable": False},
237+
"default": {"enable": False},
238+
}
239+
240+
# Keep first/last N layers dense for quality
241+
for i in range(args.skip_first_last):
242+
sparse_cfg[f"*transformer_blocks.{i}.attn*"] = {"enable": False}
243+
sparse_cfg[f"*transformer_blocks.{NUM_TRANSFORMER_BLOCKS - 1 - i}.attn*"] = {
244+
"enable": False
245+
}
246+
247+
config: dict = {"sparse_cfg": sparse_cfg}
248+
249+
# Add calibration config with threshold trials
250+
if args.calibrate:
251+
sparse_cfg["calibration"] = {
252+
"target_sparse_ratio": {"prefill": args.target_sparsity},
253+
"samples": args.calib_size,
254+
"threshold_trials": DEFAULT_THRESHOLD_TRIALS,
255+
}
256+
257+
return config
258+
259+
260+
def load_calib_prompts(calib_size: int) -> list[str]:
    """Load the first ``calib_size`` captions from the OpenVid-1M dataset.

    Uses the Hugging Face datasets streaming mode so only the rows actually
    needed are fetched, instead of downloading/materializing the full
    OpenVid-1M dataset just to read a handful of captions.

    Args:
        calib_size: Number of calibration prompts to load.

    Returns:
        List of caption strings (may be shorter than ``calib_size`` if the
        dataset has fewer rows).
    """
    import itertools

    from datasets import load_dataset

    dataset = load_dataset("nkp37/OpenVid-1M", split="train", streaming=True)
    prompts = [row["caption"] for row in itertools.islice(dataset, calib_size)]
    print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")
    return prompts
268+
269+
270+
def build_calibration_forward_loop(
    pipeline: TI2VidTwoStagesPipeline,
    num_steps: int = 10,
    num_frames: int = 81,
    calib_size: int = 1,
):
    """Create a forward loop used for exponential-model calibration.

    The returned callable generates one short video per calibration prompt
    so the sparse-attention modules can collect sparsity statistics across
    their threshold trials.

    Args:
        pipeline: The two-stage LTX-2 pipeline to drive.
        num_steps: Inference steps per calibration sample.
        num_frames: Frames per calibration sample.
        calib_size: Number of calibration prompts to run.
    """
    prompts = load_calib_prompts(calib_size)
    tiling = TilingConfig.default()
    n_prompts = len(prompts)

    def forward_loop(model):
        # `model` is supplied by the sparsify API but not used directly;
        # the pipeline already references the transformer being calibrated.
        for idx, text in enumerate(prompts, start=1):
            print(f"Calibration [{idx}/{n_prompts}]: {text[:60]}...")
            pipeline(
                prompt=text,
                negative_prompt=DEFAULT_NEGATIVE_PROMPT,
                seed=DEFAULT_SEED,
                height=DEFAULT_2_STAGE_HEIGHT,
                width=DEFAULT_2_STAGE_WIDTH,
                num_frames=num_frames,
                frame_rate=DEFAULT_FRAME_RATE,
                num_inference_steps=num_steps,
                video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS,
                audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS,
                images=[],
                tiling_config=tiling,
            )

    return forward_loop
303+
304+
305+
def print_sparsity_summary(transformer: torch.nn.Module) -> None:
    """Print which sparse-attention modules are active and their thresholds."""
    active = []
    inactive = []
    for name, mod in transformer.named_modules():
        if not isinstance(mod, SparseAttentionModule):
            continue
        if mod.is_enabled:
            active.append((name, mod))
        else:
            inactive.append(name)

    print(f"\nSparse attention: {len(active)} enabled, {len(inactive)} disabled")
    for name, mod in active:
        print(f" {name}: {mod.get_threshold_info()}")
319+
320+
321+
def _collect_jobs(args: argparse.Namespace) -> list[tuple[str, str]]:
    """Build the list of (prompt, output_path) pairs to generate.

    ``--prompt-dir`` takes precedence: each ``.txt`` file becomes one job
    writing ``<stem>.mp4`` under the output directory. Otherwise a single
    job is built from ``--prompt``/``--output``.

    Raises:
        ValueError: If neither --prompt nor --prompt-dir was given.
    """
    if args.prompt_dir:
        out_dir = args.output_dir or "output_videos"
        os.makedirs(out_dir, exist_ok=True)
        jobs: list[tuple[str, str]] = []
        for fname in sorted(os.listdir(args.prompt_dir)):
            if not fname.endswith(".txt"):
                continue
            with open(os.path.join(args.prompt_dir, fname)) as fh:
                text = fh.read().strip()
            stem = os.path.splitext(fname)[0]
            jobs.append((text, os.path.join(out_dir, f"{stem}.mp4")))
        return jobs
    if args.prompt:
        return [(args.prompt, args.output)]
    raise ValueError("Either --prompt or --prompt-dir must be provided")


def main() -> None:
    """Entry point: sparsify the stage-1 transformer, then generate videos."""
    args = parse_args()

    # ---- Build pipeline ----
    print("Building LTX-2 pipeline...")
    pipeline = build_pipeline()

    # ---- Get and sparsify the stage-1 transformer ----
    transformer = pipeline.stage_1_model_ledger.transformer()
    # Pin the transformer so the pipeline reuses the sparsified instance.
    pipeline.stage_1_model_ledger.transformer = lambda: transformer

    config = build_sparse_config(args)
    forward_loop = (
        build_calibration_forward_loop(
            pipeline,
            num_steps=args.calib_steps,
            num_frames=args.calib_frames,
            calib_size=args.calib_size,
        )
        if args.calibrate
        else None
    )

    print("Applying skip-softmax sparse attention...")
    mtsa.sparsify(transformer, config, forward_loop=forward_loop)

    # ---- Generate one video per job ----
    jobs = _collect_jobs(args)
    tiling = TilingConfig.default()
    n_jobs = len(jobs)
    for idx, (prompt, out_path) in enumerate(jobs, start=1):
        print(f"\nGenerating [{idx}/{n_jobs}]: {prompt[:80]}...")

        video, audio = pipeline(
            prompt=prompt,
            negative_prompt=DEFAULT_NEGATIVE_PROMPT,
            seed=args.seed,
            height=args.height,
            width=args.width,
            num_frames=args.num_frames,
            frame_rate=DEFAULT_FRAME_RATE,
            num_inference_steps=DEFAULT_NUM_INFERENCE_STEPS,
            video_guider_params=DEFAULT_VIDEO_GUIDER_PARAMS,
            audio_guider_params=DEFAULT_AUDIO_GUIDER_PARAMS,
            images=[],
            tiling_config=tiling,
        )

        encode_video(
            video=video,
            fps=DEFAULT_FRAME_RATE,
            audio=audio,
            audio_sample_rate=AUDIO_SAMPLE_RATE,
            output_path=out_path,
            video_chunks_number=get_video_chunks_number(args.num_frames, tiling),
        )
        print(f"Saved to {out_path}")

    # ---- Print stats ----
    print_sparsity_summary(transformer)
394+
395+
396+
if __name__ == "__main__":
397+
main()

0 commit comments

Comments
 (0)