|
25 | 25 | import warnings |
26 | 26 |
|
27 | 27 | import torch |
28 | | -from ltx_core.model.transformer.attention import Attention |
29 | 28 |
|
30 | 29 | from modelopt.torch.kernels import attention, attention_calibrate |
31 | 30 |
|
@@ -171,19 +170,24 @@ def __call__(self, q, k, v, heads, mask=None): |
171 | 170 |
|
172 | 171 | def register_ltx_triton_attention(model: torch.nn.Module) -> None: |
173 | 172 | """Patch all ``ltx_core.Attention`` modules for Triton dispatch.""" |
174 | | - warnings.warn( |
175 | | - "LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by Lightricks " |
176 | | - "and are NOT covered by the Apache 2.0 license governing NVIDIA Model Optimizer. " |
177 | | - "You MUST comply with the LTX Community License Agreement when installing and using " |
178 | | - "LTX-2 with NVIDIA Model Optimizer. Any derivative models or fine-tuned weights from " |
179 | | - "LTX-2 (including quantized or distilled checkpoints) remain subject to the LTX " |
180 | | - "Community License Agreement, not Apache 2.0. " |
181 | | - "See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE", |
182 | | - UserWarning, |
183 | | - stacklevel=2, |
184 | | - ) |
| 173 | + from ltx_core.model.transformer.attention import Attention |
| 174 | + |
| 175 | + _warned = False |
185 | 176 | for module in model.modules(): |
186 | 177 | if isinstance(module, Attention): |
| 178 | + if not _warned: |
| 179 | + warnings.warn( |
| 180 | + "LTX-2 packages (ltx-core, ltx-pipelines, ltx-trainer) are provided by " |
| 181 | + "Lightricks and are NOT covered by the Apache 2.0 license governing NVIDIA " |
| 182 | + "Model Optimizer. You MUST comply with the LTX Community License Agreement " |
| 183 | + "when installing and using LTX-2 with NVIDIA Model Optimizer. Any derivative " |
| 184 | + "models or fine-tuned weights from LTX-2 (including quantized or distilled " |
| 185 | + "checkpoints) remain subject to the LTX Community License Agreement, not " |
| 186 | + "Apache 2.0. See: https://github.com/Lightricks/LTX-2/blob/main/LICENSE", |
| 187 | + UserWarning, |
| 188 | + stacklevel=2, |
| 189 | + ) |
| 190 | + _warned = True |
187 | 191 | fn = module.attention_function |
188 | 192 | if not isinstance(fn, _TritonLTXAttentionWrapper): |
189 | 193 | module.attention_function = _TritonLTXAttentionWrapper(fn) |
0 commit comments