Skip to content

Commit ee162b3

Browse files
committed
addressed the ltx2 issue and the import issue
Signed-off-by: Jingyu Xin <jingyux@nvidia.com>
1 parent f86580c commit ee162b3

2 files changed

Lines changed: 17 additions & 14 deletions

File tree

examples/diffusers/sparsity/wan22_skip_softmax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import os
5454

5555
import torch
56+
from datasets import load_dataset
5657
from diffusers import AutoencoderKLWan, WanPipeline
5758
from diffusers.utils import export_to_video
5859

@@ -258,8 +259,6 @@ def build_sparse_config(args: argparse.Namespace, num_blocks: int) -> dict:
258259

259260
def load_calib_prompts(calib_size: int) -> list[str]:
260261
"""Load calibration prompts from OpenVid-1M dataset."""
261-
from datasets import load_dataset
262-
263262
dataset = load_dataset("nkp37/OpenVid-1M", split="train")
264263
prompts = list(dataset["caption"][:calib_size])
265264
print(f"Loaded {len(prompts)} calibration prompts from OpenVid-1M")

modelopt/torch/sparsity/attention_sparsity/kernels/ltx_triton_attention.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import warnings
2626

2727
import torch
28-
from ltx_core.model.transformer.attention import Attention
2928

3029
from modelopt.torch.kernels import attention, attention_calibrate
3130

@@ -171,19 +170,24 @@ def __call__(self, q, k, v, heads, mask=None):
171170

172171
def register_ltx_triton_attention(model: torch.nn.Module) -> None:
173172
"""Patch all ``ltx_core.Attention`` modules for Triton dispatch."""
174-
warnings.warn(
175-
"LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by Lightricks "
176-
"and are NOT covered by the Apache 2.0 license governing NVIDIA Model Optimizer. "
177-
"You MUST comply with the LTX Community License Agreement when installing and using "
178-
"LTX-2 with NVIDIA Model Optimizer. Any derivative models or fine-tuned weights from "
179-
"LTX-2 (including quantized or distilled checkpoints) remain subject to the LTX "
180-
"Community License Agreement, not Apache 2.0. "
181-
"See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE",
182-
UserWarning,
183-
stacklevel=2,
184-
)
173+
from ltx_core.model.transformer.attention import Attention
174+
175+
_warned = False
185176
for module in model.modules():
186177
if isinstance(module, Attention):
178+
if not _warned:
179+
warnings.warn(
180+
"LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by "
181+
"Lightricks and are NOT covered by the Apache 2.0 license governing NVIDIA "
182+
"Model Optimizer. You MUST comply with the LTX Community License Agreement "
183+
"when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative "
184+
"models or fine-tuned weights from LTX-2 (including quantized or distilled "
185+
"checkpoints) remain subject to the LTX Community License Agreement, not "
186+
"Apache 2.0. See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE",
187+
UserWarning,
188+
stacklevel=2,
189+
)
190+
_warned = True
187191
fn = module.attention_function
188192
if not isinstance(fn, _TritonLTXAttentionWrapper):
189193
module.attention_function = _TritonLTXAttentionWrapper(fn)

0 commit comments

Comments
 (0)