Commit c9cf85e

[bugfix]: route Wan negative prompt encoding through TextEncoderLoader
Addresses review feedback on #1178: the previous fix loaded the negative prompt encoder via transformers' T5EncoderModel, but Wan's text_encoder is UMT5 (per-layer relative position bias, not shared). Loading UMT5 weights into a T5 architecture silently produces wrong embeddings for the negative prompt and diverges the training-time CFG from inference.

Switch to TextEncoderLoader so the encoder class is resolved from pipeline_config (UMT5EncoderModel for Wan) and the postprocess_text function is reused instead of imported by name.

This keeps the fix to the original SP deadlock (every rank encodes independently; no full WanPipeline construction, no NCCL collectives) while staying inside the existing prompt-encoding abstraction. text_encoder_cpu_offload is forced off for this short-lived load to avoid initializing an FSDP device mesh, which would re-introduce collectives.
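
The mismatch described above can be illustrated with a toy sketch. It assumes the T5-style architecture only declares a relative position bias table on its first layer, while a UMT5-style checkpoint stores one per layer. Plain dicts stand in for state dicts, and `NUM_LAYERS`, `bias_key`, `expected_keys`, and `load_non_strict` are hypothetical names for illustration, not transformers or FastVideo APIs.

```python
from __future__ import annotations

# Toy illustration only: plain dicts stand in for model state dicts.
# Nothing here calls transformers or FastVideo.

NUM_LAYERS = 3  # hypothetical tiny encoder


def bias_key(layer: int) -> str:
    """Parameter name for one layer's relative position bias table."""
    return f"encoder.block.{layer}.layer.0.SelfAttention.relative_attention_bias.weight"


def expected_keys(per_layer_bias: bool) -> set[str]:
    """Keys an architecture declares: every layer (UMT5-style) or layer 0 only (T5-style)."""
    layers = range(NUM_LAYERS) if per_layer_bias else range(1)
    return {bias_key(i) for i in layers}


def load_non_strict(checkpoint: dict[str, float], expected: set[str]) -> tuple[dict[str, float], list[str]]:
    """Keep only keys the architecture knows about; report the rest as unused."""
    loaded = {k: v for k, v in checkpoint.items() if k in expected}
    unused = sorted(k for k in checkpoint if k not in expected)
    return loaded, unused


# A UMT5-style checkpoint carries a distinct bias table for every layer.
umt5_checkpoint = {bias_key(i): float(i + 1) for i in range(NUM_LAYERS)}

# Loading it into a T5-style layout keeps layer 0 and drops the others.
loaded, unused = load_non_strict(umt5_checkpoint, expected_keys(per_layer_bias=False))
print(f"loaded={len(loaded)} unused={len(unused)}")
```

In this toy, the dropped tables never raise an error. The load appears to succeed, but every layer after the first would use the wrong bias.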
1 parent 9780ea6 commit c9cf85e

1 file changed

Lines changed: 30 additions & 14 deletions

fastvideo/train/models/wan/wan.py

@@ -339,37 +339,53 @@ def ensure_negative_conditioning(self) -> None:
         device = self.device
         dtype = self._get_training_dtype()

-        # Every rank encodes the negative prompt independently.
-        # This avoids NCCL collectives that would deadlock when
-        # only a subset of ranks creates an inference pipeline.
+        # Every rank encodes the negative prompt independently. This avoids
+        # the NCCL deadlock that occurred when only rank 0 constructed the
+        # full inference pipeline. We go through TextEncoderLoader so the
+        # encoder class is resolved from pipeline_config (i.e. UMT5 for Wan,
+        # not vanilla T5) and the same tokenizer / postprocess_text used at
+        # inference time are reused.
         import os

-        from transformers import AutoTokenizer, T5EncoderModel
+        from transformers import AutoTokenizer

-        from fastvideo.configs.pipelines.wan import (
-            t5_postprocess_text, )
-        from fastvideo.utils import PRECISION_TO_TYPE, maybe_download_model
+        from fastvideo.models.loader.component_loader import TextEncoderLoader
+        from fastvideo.train.utils.moduleloader import make_inference_args
+        from fastvideo.utils import maybe_download_model

         model_path = maybe_download_model(tc.model_path)

         sampling_param = SamplingParam.from_pretrained(model_path)
         negative_prompt = sampling_param.negative_prompt

         encoder_config = tc.pipeline_config.text_encoder_configs[0]
+        postprocess_text = tc.pipeline_config.postprocess_text_funcs[0]
         tok_kwargs = dict(encoder_config.tokenizer_kwargs)

-        text_enc_dtype = PRECISION_TO_TYPE[tc.pipeline_config.text_encoder_precisions[0]]
-        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
-        text_encoder = T5EncoderModel.from_pretrained(
+        inference_args = make_inference_args(tc, model_path=model_path)
+        # The negative-prompt encoder is small and only used once at startup;
+        # keep it on-device and skip CPU offload to avoid initializing FSDP
+        # device meshes (which would re-introduce collective communication).
+        inference_args.text_encoder_cpu_offload = False
+
+        loader = TextEncoderLoader()
+        text_encoder = loader.load(
             os.path.join(model_path, "text_encoder"),
-            torch_dtype=text_enc_dtype,
+            inference_args,
         ).to(device).eval()
+        tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))

-        with torch.no_grad():
+        with torch.no_grad(), set_forward_context(current_timestep=0, attn_metadata=None):
             text_inputs = tokenizer(negative_prompt, **tok_kwargs).to(device)
-            outputs = text_encoder(**text_inputs)
+            outputs = text_encoder(
+                input_ids=text_inputs.input_ids,
+                attention_mask=text_inputs.attention_mask,
+            )
+            # postprocess_text reads outputs.attention_mask; the FastVideo
+            # encoders already set it, but be explicit to match the inference
+            # path (where TextEncodingStage assigns it).
             outputs.attention_mask = text_inputs["attention_mask"]
-            neg_embeds = t5_postprocess_text(outputs).to(device=device, dtype=dtype)
+            neg_embeds = postprocess_text(outputs).to(device=device, dtype=dtype)
             neg_mask = text_inputs["attention_mask"].to(device=device, dtype=dtype)

         del text_encoder, tokenizer
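
The deadlock that the code comments refer to comes from a general property of collectives: every rank must enter the call, or the ranks that did enter wait forever. The following is a minimal stand-in for that behavior, assuming `threading.Barrier` is an acceptable analogue for an NCCL collective. `WORLD_SIZE` and `run` are hypothetical names, and the timeout exists only so the example terminates; a real collective would typically hang instead.

```python
from __future__ import annotations

import threading

WORLD_SIZE = 2  # hypothetical number of ranks


def run(participating_ranks: set[int], timeout: float = 0.5) -> dict[int, str]:
    """Simulate ranks where only some of them enter a collective (the barrier)."""
    barrier = threading.Barrier(WORLD_SIZE)
    results: dict[int, str] = {}

    def worker(rank: int) -> None:
        if rank not in participating_ranks:
            # This rank never reaches the collective.
            results[rank] = "skipped"
            return
        try:
            barrier.wait(timeout=timeout)
            results[rank] = "ok"
        except threading.BrokenBarrierError:
            # Stand-in for a hang: the peers never arrived.
            results[rank] = "timed out"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


# Only rank 0 enters the collective, as when only rank 0 built the pipeline.
only_rank0 = run({0})
# Every rank enters, so the barrier releases.
all_ranks = run({0, 1})
print(only_rank0, all_ranks)
```

Encoding the negative prompt independently on every rank, as the diff does, sidesteps the problem by not entering any collective during this step.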
