Skip to content

Commit 8093507

Browse files
committed
rebase to latest training code and handle review commnets
Signed-off-by: ynankani <ynankani@nvidia.com>
1 parent f117065 commit 8093507

1 file changed

Lines changed: 4 additions & 11 deletions

File tree

examples/windows/torch_onnx/diffusers/qad_example/sample_example_qad_diffusers.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,10 @@
4949

5050
import torch
5151
import torch.distributed as dist
52-
5352
# LTX imports
54-
from ltx_trainer.model_loader import load_transformer
5553
from ltx_trainer.config import LtxTrainerConfig
5654
from ltx_trainer.datasets import PrecomputedDataset
55+
from ltx_trainer.model_loader import load_transformer
5756
from ltx_trainer.timestep_samplers import SAMPLERS
5857
from ltx_trainer.trainer import LtxvTrainer
5958
from ltx_trainer.training_strategies import get_training_strategy
@@ -311,9 +310,7 @@ def forward(self, student_output, teacher_output):
311310
audio_student.float(), audio_teacher.float()
312311
)
313312
return loss
314-
return torch.nn.functional.mse_loss(
315-
student_output.float(), teacher_output.float()
316-
)
313+
return torch.nn.functional.mse_loss(student_output.float(), teacher_output.float())
317314

318315

319316
# ─── QAD Trainer ──────────────────────────────────────────────────────────────
@@ -423,9 +420,7 @@ def calibration_forward_loop(model):
423420
if text_encoder is not None and "conditions" in batch:
424421
apply_connectors(batch, text_encoder)
425422

426-
model_inputs = strategy.prepare_training_inputs(
427-
batch, timestep_sampler
428-
)
423+
model_inputs = strategy.prepare_training_inputs(batch, timestep_sampler)
429424
model(
430425
video=model_inputs.video,
431426
audio=model_inputs.audio,
@@ -507,9 +502,7 @@ def _training_step(self, batch):
507502
audio=model_inputs.audio,
508503
perturbations=None,
509504
)
510-
hard_loss = self._training_strategy.compute_loss(
511-
video_pred, audio_pred, model_inputs
512-
)
505+
hard_loss = self._training_strategy.compute_loss(video_pred, audio_pred, model_inputs)
513506

514507
unwrapped = self._accelerator.unwrap_model(self._transformer)
515508
if isinstance(unwrapped, DistillationModel) and unwrapped.training:

0 commit comments

Comments
 (0)