Skip to content

Commit ccc0a50

Browse files
committed
[bugfix] Use legacy LTX2 behavior for SSIM compatibility
1 parent 02a601f commit ccc0a50

4 files changed

Lines changed: 32 additions & 8 deletions

File tree

fastvideo/fastvideo_args.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -206,6 +206,8 @@ class FastVideoArgs:
206206
ltx2_refine_add_noise: bool = True
207207
ltx2_refine_noise_path: str | None = None
208208
ltx2_refine_audio_noise_path: str | None = None
209+
ltx2_legacy_native_noise_order: bool = False
210+
ltx2_use_distilled_sigmas: bool = True
209211

210212
# model paths for correct deallocation
211213
model_paths: dict[str, str] = field(default_factory=dict)

fastvideo/pipelines/basic/ltx2/stages/ltx2_denoising.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,8 @@ def forward(
243243
logger.info("[LTX2] Using override sigma schedule, %s", self.sigmas_override)
244244
else:
245245
# Use distilled hardcoded schedule (or subsets) when enabled.
246-
use_distilled_sigmas = os.getenv("LTX2_USE_DISTILLED_SIGMAS", "1") == "1"
246+
use_distilled_sigmas = (fastvideo_args.ltx2_use_distilled_sigmas
247+
and os.getenv("LTX2_USE_DISTILLED_SIGMAS", "1") == "1")
247248
max_distilled_steps = len(DISTILLED_SIGMA_VALUES) - 1
248249
if use_distilled_sigmas and num_inference_steps <= max_distilled_steps:
249250
sigmas, distilled_indices = _distilled_subset_sigmas(

fastvideo/pipelines/basic/ltx2/stages/ltx2_latent_preparation.py

Lines changed: 23 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -160,6 +160,14 @@ def forward(
160160
loaded_latents = self._load_initial_latent(latent_path, device, dtype)
161161
if loaded_latents is not None:
162162
latents = loaded_latents
163+
elif fastvideo_args.ltx2_legacy_native_noise_order:
164+
latents = randn_tensor(
165+
shape,
166+
generator=generator,
167+
device=device,
168+
dtype=dtype,
169+
)
170+
self._save_initial_latent(latent_path, latents)
163171
else:
164172
latents = _randn_ltx2_video_latents(
165173
shape=shape,
@@ -170,13 +178,21 @@ def forward(
170178
)
171179
self._save_initial_latent(latent_path, latents)
172180
else:
173-
latents = _randn_ltx2_video_latents(
174-
shape=shape,
175-
transformer=self.transformer,
176-
generator=generator,
177-
device=device,
178-
dtype=dtype,
179-
)
181+
if fastvideo_args.ltx2_legacy_native_noise_order:
182+
latents = randn_tensor(
183+
shape,
184+
generator=generator,
185+
device=device,
186+
dtype=dtype,
187+
)
188+
else:
189+
latents = _randn_ltx2_video_latents(
190+
shape=shape,
191+
transformer=self.transformer,
192+
generator=generator,
193+
device=device,
194+
dtype=dtype,
195+
)
180196
else:
181197
latents = latents.to(device)
182198

fastvideo/tests/ssim/test_ltx2_similarity.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -123,5 +123,10 @@ def test_ltx2_distilled_inference_similarity(
123123
full_quality_params_map=FULL_QUALITY_LTX2_DISTILLED_MODEL_TO_PARAMS,
124124
slice_cosine_threshold=SLICE_COSINE_DISTANCE_THRESHOLD,
125125
full_cosine_threshold=FULL_COSINE_DISTANCE_THRESHOLD,
126+
init_kwargs_override={
127+
"dit_cpu_offload": True,
128+
"ltx2_legacy_native_noise_order": True,
129+
"ltx2_use_distilled_sigmas": False,
130+
},
126131
generation_kwargs_override=LTX2_DISTILLED_REFERENCE_GUIDANCE_OVERRIDES,
127132
)

0 commit comments

Comments
 (0)