105 changes: 9 additions & 96 deletions fastvideo/train/models/wan/wan.py
@@ -4,7 +4,6 @@
 from __future__ import annotations

 import copy
-import gc
 from typing import Any, Literal, TYPE_CHECKING

 import torch
@@ -19,10 +18,6 @@
 from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
     FlowMatchEulerDiscreteScheduler, )
 from fastvideo.pipelines import TrainingBatch
-from fastvideo.pipelines.basic.wan.wan_pipeline import (
-    WanPipeline, )
-from fastvideo.pipelines.pipeline_batch_info import (
-    ForwardBatch, )
 from fastvideo.training.activation_checkpoint import (
     apply_activation_checkpointing, )
 from fastvideo.training.training_utils import (
@@ -41,6 +36,7 @@
     apply_trainable, )
 from fastvideo.train.utils.moduleloader import (
     load_module_from_path, )
+from fastvideo.train.utils.negative_prompt import encode_negative_prompt

 if TYPE_CHECKING:
     from fastvideo.train.utils.training_config import (
@@ -341,98 +337,15 @@ def ensure_negative_conditioning(self) -> None:

         assert self.training_config is not None
         tc = self.training_config
-        world_group = self.world_group
-        device = self.device
-        dtype = self._get_training_dtype()
-
-        from fastvideo.train.utils.moduleloader import (
-            make_inference_args, )
-
-        neg_embeds: torch.Tensor | None = None
-        neg_mask: torch.Tensor | None = None
-
-        if world_group.rank_in_group == 0:
-            sampling_param = SamplingParam.from_pretrained(tc.model_path)
-            negative_prompt = sampling_param.negative_prompt
-
-            inference_args = make_inference_args(tc, model_path=tc.model_path)
-
-            prompt_pipeline = WanPipeline.from_pretrained(
-                tc.model_path,
-                args=inference_args,
-                inference_mode=True,
-                loaded_modules={"transformer": self.transformer},
-                tp_size=tc.distributed.tp_size,
-                sp_size=tc.distributed.sp_size,
-                num_gpus=tc.distributed.num_gpus,
-                pin_cpu_memory=(tc.distributed.pin_cpu_memory),
-                dit_cpu_offload=True,
-            )
-
-            batch_negative = ForwardBatch(
-                data_type="video",
-                prompt=negative_prompt,
-                prompt_embeds=[],
-                prompt_attention_mask=[],
-            )
-            result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
-                batch_negative,
-                inference_args,
-            )
-
-            neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
-            neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype))
-
-            del prompt_pipeline
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-        meta = torch.zeros((2, ), device=device, dtype=torch.int64)
-        if world_group.rank_in_group == 0:
-            assert neg_embeds is not None
-            assert neg_mask is not None
-            meta[0] = neg_embeds.ndim
-            meta[1] = neg_mask.ndim
-        world_group.broadcast(meta, src=0)
-        embed_ndim, mask_ndim = (
-            int(meta[0].item()),
-            int(meta[1].item()),
+        sampling_param = SamplingParam.from_pretrained(tc.model_path)
+        embeds, mask = encode_negative_prompt(
+            tc,
+            prompt=sampling_param.negative_prompt,
+            device=self.device,
+            dtype=self._get_training_dtype(),
         )
-
-        max_ndim = 8
-        embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
-        mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
-        if world_group.rank_in_group == 0:
-            assert neg_embeds is not None
-            assert neg_mask is not None
-            embed_shape[:embed_ndim] = torch.tensor(
-                list(neg_embeds.shape),
-                device=device,
-                dtype=torch.int64,
-            )
-            mask_shape[:mask_ndim] = torch.tensor(
-                list(neg_mask.shape),
-                device=device,
-                dtype=torch.int64,
-            )
-        world_group.broadcast(embed_shape, src=0)
-        world_group.broadcast(mask_shape, src=0)
-
-        embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
-        mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())
-
-        if world_group.rank_in_group != 0:
-            neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
-            neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
-        assert neg_embeds is not None
-        assert neg_mask is not None
-
-        world_group.broadcast(neg_embeds, src=0)
-        world_group.broadcast(neg_mask, src=0)
-
-        self.negative_prompt_embeds = neg_embeds
-        self.negative_prompt_attention_mask = neg_mask
+        self.negative_prompt_embeds = embeds
+        self.negative_prompt_attention_mask = mask

     def _sample_timesteps(
         self,
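Note: the deleted rank-0 path above is a general multi-process hazard, not a Wan-specific bug. Below is a minimal sketch of the failure mode, with hypothetical names and dist.barrier() standing in for the collective that the FSDP device-mesh init fired during text-encoder load (launch with torchrun --nproc_per_node=2):

# deadlock_sketch.py -- illustrative only, not part of this PR.
import torch
import torch.distributed as dist


def load_encoder_with_hidden_collective() -> torch.Tensor:
    # Stand-in for Pipeline.from_pretrained: loading itself fires a
    # collective (here a barrier; in the real code, FSDP mesh init).
    dist.barrier()
    return torch.randn(1, 512, 4096)


def rank0_encode_then_broadcast() -> torch.Tensor:
    # The removed pattern: only rank 0 loads, then everyone broadcasts.
    if dist.get_rank() == 0:
        embeds = load_encoder_with_hidden_collective()  # rank 0 waits in barrier
    else:
        embeds = torch.empty(1, 512, 4096)
    dist.broadcast(embeds, src=0)  # other ranks wait here: collectives mismatch, all ranks hang
    return embeds


def per_rank_encode() -> torch.Tensor:
    # The new pattern: every rank loads, so every collective is symmetric.
    return load_encoder_with_hidden_collective()


if __name__ == "__main__":
    dist.init_process_group("gloo")
    per_rank_encode()                 # fine: all ranks enter the barrier together
    # rank0_encode_then_broadcast()   # uncomment to reproduce the hang
    dist.destroy_process_group()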
99 changes: 99 additions & 0 deletions fastvideo/train/utils/negative_prompt.py
@@ -0,0 +1,99 @@
# SPDX-License-Identifier: Apache-2.0
"""Per-rank negative-prompt encoding shared by training model plugins.

Encoding the negative prompt only on rank 0 and broadcasting (the
previous Wan path) ran ``Pipeline.from_pretrained`` asymmetrically across
ranks, which deadlocked on any collective fired during text-encoder load
(FSDP device-mesh init, weight broadcast, etc.). The text encoder is
small and only loaded once at startup, so loading it on every rank
sidesteps the deadlock entirely.
"""

from __future__ import annotations

import os
from typing import TYPE_CHECKING

import torch
from transformers import AutoTokenizer

from fastvideo.forward_context import set_forward_context
from fastvideo.models.loader.component_loader import TextEncoderLoader
from fastvideo.train.utils.moduleloader import make_inference_args
from fastvideo.utils import maybe_download_model

if TYPE_CHECKING:
    from fastvideo.train.utils.training_config import TrainingConfig


def encode_negative_prompt(
    training_config: TrainingConfig,
    *,
    prompt: str,
    device: torch.device,
    dtype: torch.dtype,
    encoder_index: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-rank encode of ``prompt`` using encoder ``encoder_index``.

    Reads ``pipeline_config.text_encoder_configs[encoder_index]`` so the
    encoder class (e.g. UMT5 for Wan) and tokenizer kwargs match the
    inference path, and applies the matching ``postprocess_text_funcs``
    entry. Returns ``(embeds, mask)`` on ``device``, cast to ``dtype``.
    """
    tc = training_config
    pipeline_config = tc.pipeline_config
    if pipeline_config is None:
        raise ValueError("training_config.pipeline_config is required for negative "
                         "prompt encoding")

    encoder_configs = pipeline_config.text_encoder_configs
    postprocess_funcs = pipeline_config.postprocess_text_funcs
    preprocess_funcs = getattr(pipeline_config, "preprocess_text_funcs", None)

    if encoder_index < 0 or encoder_index >= len(encoder_configs):
        raise IndexError(f"encoder_index {encoder_index} out of range for "
                         f"text_encoder_configs (len={len(encoder_configs)})")
    encoder_config = encoder_configs[encoder_index]
    postprocess_text = postprocess_funcs[encoder_index]
    preprocess_text = (preprocess_funcs[encoder_index]
                       if preprocess_funcs is not None else None)

    # HF convention: text_encoder / tokenizer for index 0,
    # text_encoder_2 / tokenizer_2 for index 1, etc.
    suffix = "" if encoder_index == 0 else f"_{encoder_index + 1}"
    encoder_subdir = f"text_encoder{suffix}"
    tokenizer_subdir = f"tokenizer{suffix}"

    model_path = maybe_download_model(tc.model_path)
    inference_args = make_inference_args(tc, model_path=model_path)
    # Keep the encoder on-device; CPU offload would init an FSDP device
    # mesh and reintroduce the collective at load time.
    inference_args.text_encoder_cpu_offload = False

    loader = TextEncoderLoader()
    text_encoder = loader.load(
        os.path.join(model_path, encoder_subdir),
        inference_args,
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, tokenizer_subdir))

    tok_kwargs = dict(encoder_config.tokenizer_kwargs)
    text = preprocess_text(prompt) if preprocess_text is not None else prompt

    with torch.no_grad(), set_forward_context(
            current_timestep=0,
            attn_metadata=None,
    ):
        text_inputs = tokenizer(text, **tok_kwargs).to(device)
        outputs = text_encoder(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
        )
        # Mirror TextEncodingStage: postprocess reads outputs.attention_mask.
        outputs.attention_mask = text_inputs["attention_mask"]
        embeds = postprocess_text(outputs).to(device=device, dtype=dtype)
        mask = text_inputs["attention_mask"].to(device=device, dtype=dtype)

    del text_encoder, tokenizer

    return embeds, mask
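For reference, a hedged usage sketch of the new helper, mirroring the Wan call site above; SamplingParam's import path, the "cuda" device, and bfloat16 are assumptions, not taken from this PR:

# usage_sketch.py -- illustrative only.
import torch

from fastvideo.configs.sample import SamplingParam  # import path assumed
from fastvideo.train.utils.negative_prompt import encode_negative_prompt


def build_negative_conditioning(training_config):
    sampling_param = SamplingParam.from_pretrained(training_config.model_path)
    # Runs identically on every rank: the encoder is in eval mode and the
    # input prompt is the same everywhere, so no broadcast is needed to
    # keep ranks in sync.
    return encode_negative_prompt(
        training_config,
        prompt=sampling_param.negative_prompt,
        device=torch.device("cuda"),
        dtype=torch.bfloat16,
    )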