|
49 | 49 |
|
50 | 50 | import torch |
51 | 51 | import torch.distributed as dist |
52 | | - |
53 | 52 | # LTX imports |
54 | | -from ltx_trainer.model_loader import load_transformer |
55 | 53 | from ltx_trainer.config import LtxTrainerConfig |
56 | 54 | from ltx_trainer.datasets import PrecomputedDataset |
| 55 | +from ltx_trainer.model_loader import load_transformer |
57 | 56 | from ltx_trainer.timestep_samplers import SAMPLERS |
58 | 57 | from ltx_trainer.trainer import LtxvTrainer |
59 | 58 | from ltx_trainer.training_strategies import get_training_strategy |
@@ -311,9 +310,7 @@ def forward(self, student_output, teacher_output): |
311 | 310 | audio_student.float(), audio_teacher.float() |
312 | 311 | ) |
313 | 312 | return loss |
314 | | - return torch.nn.functional.mse_loss( |
315 | | - student_output.float(), teacher_output.float() |
316 | | - ) |
| 313 | + return torch.nn.functional.mse_loss(student_output.float(), teacher_output.float()) |
317 | 314 |
|
318 | 315 |
|
319 | 316 | # ─── QAD Trainer ────────────────────────────────────────────────────────────── |
@@ -423,9 +420,7 @@ def calibration_forward_loop(model): |
423 | 420 | if text_encoder is not None and "conditions" in batch: |
424 | 421 | apply_connectors(batch, text_encoder) |
425 | 422 |
|
426 | | - model_inputs = strategy.prepare_training_inputs( |
427 | | - batch, timestep_sampler |
428 | | - ) |
| 423 | + model_inputs = strategy.prepare_training_inputs(batch, timestep_sampler) |
429 | 424 | model( |
430 | 425 | video=model_inputs.video, |
431 | 426 | audio=model_inputs.audio, |
@@ -507,9 +502,7 @@ def _training_step(self, batch): |
507 | 502 | audio=model_inputs.audio, |
508 | 503 | perturbations=None, |
509 | 504 | ) |
510 | | - hard_loss = self._training_strategy.compute_loss( |
511 | | - video_pred, audio_pred, model_inputs |
512 | | - ) |
| 505 | + hard_loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs) |
513 | 506 |
|
514 | 507 | unwrapped = self._accelerator.unwrap_model(self._transformer) |
515 | 508 | if isinstance(unwrapped, DistillationModel) and unwrapped.training: |
|
0 commit comments