Skip to content

Commit 215056b

Browse files
committed
minor fix
Signed-off-by: Meng Xin <mxin@nvidia.com>
1 parent acaa7f9 commit 215056b

4 files changed

Lines changed: 23 additions & 33 deletions

File tree

examples/diffusers/distillation/src/models/ltx2/pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ def unload_text_encoder(self) -> None:
184184
del te.feature_extractor_linear
185185
te.feature_extractor_linear = None
186186
te.tokenizer = None
187-
# Keep connectors on GPU (they're small)
188-
te.embeddings_connector.to("cuda")
189-
te.audio_embeddings_connector.to("cuda")
187+
# Keep connectors on current GPU (they're small)
188+
device = torch.device("cuda", torch.cuda.current_device())
189+
te.embeddings_connector.to(device)
190+
te.audio_embeddings_connector.to(device)
190191
free_gpu_memory()
191192
logger.info("Text encoder unloaded (connectors kept for training/inference)")
192193

examples/diffusers/distillation/src/models/wan/adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def prepare_inputs(
8181
)
8282

8383
def forward_model(self, model: nn.Module, inputs: BackboneInputs) -> Tensor:
84-
# WanModel's internal norms promote to float32; autocast keeps
85-
# linear ops in bf16 to match the model weights.
84+
# WanModel norms promote to float32 internally; autocast keeps
85+
# matmuls in bf16 to match the original Wan inference code.
8686
with torch.amp.autocast("cuda", dtype=torch.bfloat16):
8787
output_list = model(**inputs.forward_kwargs)
8888
return torch.stack(output_list)

examples/diffusers/distillation/src/models/wan/pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,18 @@ def load_components(self, model_config, device: str, dtype: torch.dtype) -> None
6666
self._config = self._var["config"]()
6767

6868
t5_path = os.path.join(path, self._config.t5_checkpoint)
69+
# Prefer local tokenizer dir (avoids HuggingFace network calls).
70+
# Wan ships tokenizer files under <model_root>/google/umt5-xxl/.
71+
tokenizer_path = self._config.t5_tokenizer
72+
local_tokenizer = os.path.join(path, tokenizer_path)
73+
if os.path.isdir(local_tokenizer):
74+
tokenizer_path = local_tokenizer
6975
self._text_encoder = T5EncoderModel(
7076
text_len=self._config.text_len,
7177
dtype=dtype,
7278
device=torch.device("cpu"),
7379
checkpoint_path=t5_path,
74-
tokenizer_path=self._config.t5_tokenizer,
80+
tokenizer_path=tokenizer_path,
7581
)
7682

7783
vae_mod = importlib.import_module(self._var["vae_module"])

examples/diffusers/distillation/src/trainer.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
self._inference_pipeline = inference_pipeline
103103

104104
self._global_step = 0
105+
self._data_epoch = 0
105106
self._wandb_run = None
106107

107108
set_seed(config.seed)
@@ -671,30 +672,6 @@ def _training_step(self, batch: dict[str, Tensor]) -> Tensor:
671672

672673
return total_loss
673674

674-
def _compute_distillation_loss(
675-
self, student_pred: Tensor, teacher_pred: Tensor, loss_mask: Tensor
676-
) -> Tensor:
677-
loss_type = self._config.distillation.distillation_loss_type
678-
679-
if loss_type == "mse":
680-
loss = torch.nn.functional.mse_loss(student_pred, teacher_pred, reduction="none")
681-
elif loss_type == "cosine":
682-
s_flat = student_pred.flatten(start_dim=2)
683-
t_flat = teacher_pred.flatten(start_dim=2)
684-
cos_sim = torch.nn.functional.cosine_similarity(s_flat, t_flat, dim=-1)
685-
loss = 1.0 - cos_sim # [B, T]
686-
else:
687-
raise ValueError(f"Unknown distillation loss type: {loss_type}")
688-
689-
if loss_mask is not None and loss_mask.numel() > 0:
690-
# Expand mask to match loss dimensions
691-
while loss_mask.dim() < loss.dim():
692-
loss_mask = loss_mask.unsqueeze(-1)
693-
mask = loss_mask.float()
694-
loss = loss.mul(mask).div(mask.mean())
695-
696-
return loss.mean()
697-
698675
def _compute_layer_distillation_loss(self) -> Tensor:
699676
"""Compute distillation loss across hooked intermediate layers."""
700677
assert self._student_extractor is not None
@@ -1078,10 +1055,12 @@ def train(self) -> dict:
10781055
f"batch_size={cfg.optimization.batch_size}"
10791056
)
10801057

1058+
start_micro = self._global_step * grad_accum
1059+
total_micro = total_steps * grad_accum
10811060
pbar = tqdm(
1082-
range(self._global_step, total_steps * grad_accum),
1083-
initial=self._global_step * grad_accum,
1084-
total=total_steps * grad_accum,
1061+
range(start_micro, total_micro),
1062+
initial=start_micro,
1063+
total=total_micro,
10851064
desc="Training",
10861065
disable=not _is_global_rank0(),
10871066
)
@@ -1091,6 +1070,10 @@ def train(self) -> dict:
10911070
try:
10921071
batch = next(data_iter)
10931072
except StopIteration:
1073+
self._data_epoch += 1
1074+
sampler = getattr(self._dataloader, "sampler", None)
1075+
if sampler is not None and hasattr(sampler, "set_epoch"):
1076+
sampler.set_epoch(self._data_epoch)
10941077
data_iter = iter(self._dataloader)
10951078
batch = next(data_iter)
10961079

0 commit comments

Comments
 (0)