
Commit 9a8bbe1

[bugfix]: fix SP deadlock in negative prompt encoding during training (#1178)
1 parent ea25441 · commit 9a8bbe1

2 files changed: 108 additions & 96 deletions

File tree

fastvideo/train/models/wan/wan.py
fastvideo/train/utils/negative_prompt.py (new file)

fastvideo/train/models/wan/wan.py

Lines changed: 9 additions & 96 deletions
```diff
@@ -4,7 +4,6 @@
 from __future__ import annotations
 
 import copy
-import gc
 from typing import Any, Literal, TYPE_CHECKING
 
 import torch
@@ -19,10 +18,6 @@
 from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import (
     FlowMatchEulerDiscreteScheduler, )
 from fastvideo.pipelines import TrainingBatch
-from fastvideo.pipelines.basic.wan.wan_pipeline import (
-    WanPipeline, )
-from fastvideo.pipelines.pipeline_batch_info import (
-    ForwardBatch, )
 from fastvideo.training.activation_checkpoint import (
     apply_activation_checkpointing, )
 from fastvideo.training.training_utils import (
@@ -41,6 +36,7 @@
     apply_trainable, )
 from fastvideo.train.utils.moduleloader import (
     load_module_from_path, )
+from fastvideo.train.utils.negative_prompt import encode_negative_prompt
 
 if TYPE_CHECKING:
     from fastvideo.train.utils.training_config import (
@@ -341,98 +337,15 @@ def ensure_negative_conditioning(self) -> None:
 
         assert self.training_config is not None
         tc = self.training_config
-        world_group = self.world_group
-        device = self.device
-        dtype = self._get_training_dtype()
-
-        from fastvideo.train.utils.moduleloader import (
-            make_inference_args, )
-
-        neg_embeds: torch.Tensor | None = None
-        neg_mask: torch.Tensor | None = None
-
-        if world_group.rank_in_group == 0:
-            sampling_param = SamplingParam.from_pretrained(tc.model_path)
-            negative_prompt = sampling_param.negative_prompt
-
-            inference_args = make_inference_args(tc, model_path=tc.model_path)
-
-            prompt_pipeline = WanPipeline.from_pretrained(
-                tc.model_path,
-                args=inference_args,
-                inference_mode=True,
-                loaded_modules={"transformer": self.transformer},
-                tp_size=tc.distributed.tp_size,
-                sp_size=tc.distributed.sp_size,
-                num_gpus=tc.distributed.num_gpus,
-                pin_cpu_memory=(tc.distributed.pin_cpu_memory),
-                dit_cpu_offload=True,
-            )
-
-            batch_negative = ForwardBatch(
-                data_type="video",
-                prompt=negative_prompt,
-                prompt_embeds=[],
-                prompt_attention_mask=[],
-            )
-            result_batch = prompt_pipeline.prompt_encoding_stage(  # type: ignore[attr-defined]
-                batch_negative,
-                inference_args,
-            )
-
-            neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype)
-            neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype))
-
-            del prompt_pipeline
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-        meta = torch.zeros((2, ), device=device, dtype=torch.int64)
-        if world_group.rank_in_group == 0:
-            assert neg_embeds is not None
-            assert neg_mask is not None
-            meta[0] = neg_embeds.ndim
-            meta[1] = neg_mask.ndim
-        world_group.broadcast(meta, src=0)
-        embed_ndim, mask_ndim = (
-            int(meta[0].item()),
-            int(meta[1].item()),
+        sampling_param = SamplingParam.from_pretrained(tc.model_path)
+        embeds, mask = encode_negative_prompt(
+            tc,
+            prompt=sampling_param.negative_prompt,
+            device=self.device,
+            dtype=self._get_training_dtype(),
         )
-
-        max_ndim = 8
-        embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
-        mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64)
-        if world_group.rank_in_group == 0:
-            assert neg_embeds is not None
-            assert neg_mask is not None
-            embed_shape[:embed_ndim] = torch.tensor(
-                list(neg_embeds.shape),
-                device=device,
-                dtype=torch.int64,
-            )
-            mask_shape[:mask_ndim] = torch.tensor(
-                list(neg_mask.shape),
-                device=device,
-                dtype=torch.int64,
-            )
-        world_group.broadcast(embed_shape, src=0)
-        world_group.broadcast(mask_shape, src=0)
-
-        embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist())
-        mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist())
-
-        if world_group.rank_in_group != 0:
-            neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype)
-            neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype)
-        assert neg_embeds is not None
-        assert neg_mask is not None
-
-        world_group.broadcast(neg_embeds, src=0)
-        world_group.broadcast(neg_mask, src=0)
-
-        self.negative_prompt_embeds = neg_embeds
-        self.negative_prompt_attention_mask = neg_mask
+        self.negative_prompt_embeds = embeds
+        self.negative_prompt_attention_mask = mask
 
     def _sample_timesteps(
         self,
```
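
For context on the failure mode named in the commit message: with `torch.distributed`, every collective must be entered by all ranks in the group, and the old rank-0-gated pipeline load violated that whenever loading itself fired a collective. A minimal, self-contained sketch of the two shapes follows; it is illustrative only, not FastVideo code, and `fake_collective_heavy_load` is a hypothetical stand-in for `Pipeline.from_pretrained`:

```python
# Illustrative sketch, not FastVideo code. fake_collective_heavy_load is a
# hypothetical stand-in for a loader that fires collectives internally
# (FSDP device-mesh init, weight broadcast, ...).
import torch
import torch.distributed as dist


def fake_collective_heavy_load() -> torch.Tensor:
    w = torch.zeros(4)
    dist.broadcast(w, src=0)  # a collective: every rank must enter it
    return w


def old_shape() -> torch.Tensor:
    """Asymmetric (the bug): only rank 0 runs the inner collective, so the
    collective sequences on the ranks no longer line up and some rank ends
    up blocked in a broadcast with no matching peer."""
    out = torch.zeros(4)
    if dist.get_rank() == 0:
        out = fake_collective_heavy_load()
    dist.broadcast(out, src=0)
    return out


def new_shape() -> torch.Tensor:
    """Symmetric (the fix): all ranks run the same load, so every collective
    inside it is matched and no result broadcast is needed afterwards."""
    return fake_collective_heavy_load()


if __name__ == "__main__":
    # Run with: torchrun --nproc_per_node=2 deadlock_sketch.py
    dist.init_process_group("gloo")
    print(new_shape())  # fine; calling old_shape() instead would hang
    dist.destroy_process_group()
```

This is exactly the trade the commit makes in wan.py above: the per-rank load does redundant work, but the collective schedule becomes identical on every rank by construction.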
fastvideo/train/utils/negative_prompt.py

Lines changed: 99 additions & 0 deletions

```diff
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Per-rank negative-prompt encoding shared by training model plugins.
+
+Encoding the negative prompt only on rank 0 and broadcasting (the
+previous Wan path) ran ``Pipeline.from_pretrained`` asymmetrically across
+ranks, which deadlocked on any collective fired during text-encoder load
+(FSDP device-mesh init, weight broadcast, etc.). The text encoder is
+small and only loaded once at startup, so loading it on every rank
+sidesteps the deadlock entirely.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import TYPE_CHECKING
+
+import torch
+from transformers import AutoTokenizer
+
+from fastvideo.forward_context import set_forward_context
+from fastvideo.models.loader.component_loader import TextEncoderLoader
+from fastvideo.train.utils.moduleloader import make_inference_args
+from fastvideo.utils import maybe_download_model
+
+if TYPE_CHECKING:
+    from fastvideo.train.utils.training_config import TrainingConfig
+
+
+def encode_negative_prompt(
+    training_config: TrainingConfig,
+    *,
+    prompt: str,
+    device: torch.device,
+    dtype: torch.dtype,
+    encoder_index: int = 0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Per-rank encode of ``prompt`` using encoder ``encoder_index``.
+
+    Reads ``pipeline_config.text_encoder_configs[encoder_index]`` so the
+    encoder class (e.g. UMT5 for Wan) and tokenizer kwargs match the
+    inference path, and applies the matching ``postprocess_text_funcs``
+    entry. Returns ``(embeds, mask)`` on ``device`` cast to ``dtype``.
+    """
+    tc = training_config
+    pipeline_config = tc.pipeline_config
+    if pipeline_config is None:
+        raise ValueError("training_config.pipeline_config is required for negative "
+                         "prompt encoding")
+
+    encoder_configs = pipeline_config.text_encoder_configs
+    postprocess_funcs = pipeline_config.postprocess_text_funcs
+    preprocess_funcs = getattr(pipeline_config, "preprocess_text_funcs", None)
+
+    if encoder_index < 0 or encoder_index >= len(encoder_configs):
+        raise IndexError(f"encoder_index {encoder_index} out of range for "
+                         f"text_encoder_configs (len={len(encoder_configs)})")
+    encoder_config = encoder_configs[encoder_index]
+    postprocess_text = postprocess_funcs[encoder_index]
+    preprocess_text = (preprocess_funcs[encoder_index] if preprocess_funcs is not None else None)
+
+    # HF convention: text_encoder / tokenizer for index 0,
+    # text_encoder_2 / tokenizer_2 for index 1, etc.
+    suffix = "" if encoder_index == 0 else f"_{encoder_index + 1}"
+    encoder_subdir = f"text_encoder{suffix}"
+    tokenizer_subdir = f"tokenizer{suffix}"
+
+    model_path = maybe_download_model(tc.model_path)
+    inference_args = make_inference_args(tc, model_path=model_path)
+    # Keep the encoder on-device; CPU offload would init an FSDP device
+    # mesh and reintroduce the collective at load time.
+    inference_args.text_encoder_cpu_offload = False
+
+    loader = TextEncoderLoader()
+    text_encoder = loader.load(
+        os.path.join(model_path, encoder_subdir),
+        inference_args,
+    ).to(device).eval()
+    tokenizer = AutoTokenizer.from_pretrained(os.path.join(model_path, tokenizer_subdir))
+
+    tok_kwargs = dict(encoder_config.tokenizer_kwargs)
+    text = preprocess_text(prompt) if preprocess_text is not None else prompt
+
+    with torch.no_grad(), set_forward_context(
+            current_timestep=0,
+            attn_metadata=None,
+    ):
+        text_inputs = tokenizer(text, **tok_kwargs).to(device)
+        outputs = text_encoder(
+            input_ids=text_inputs.input_ids,
+            attention_mask=text_inputs.attention_mask,
+        )
+        # Mirror TextEncodingStage: postprocess reads outputs.attention_mask.
+        outputs.attention_mask = text_inputs["attention_mask"]
+        embeds = postprocess_text(outputs).to(device=device, dtype=dtype)
+        mask = text_inputs["attention_mask"].to(device=device, dtype=dtype)
+
+    del text_encoder, tokenizer
+
+    return embeds, mask
```
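
For reference, a hypothetical call site mirroring how the new wan.py path above consumes the helper. Only `encode_negative_prompt` comes from this commit; the device and dtype choices, the `tc` object, and the surrounding function are assumptions for illustration:

```python
# Sketch of a call site, assuming `tc` is a TrainingConfig whose
# pipeline_config is populated; device/dtype here are illustrative.
import torch

from fastvideo.train.utils.negative_prompt import encode_negative_prompt


def build_negative_conditioning(tc, negative_prompt: str):
    # Every rank executes this identically -- no rank-0 gating and no
    # broadcast afterwards, hence no chance of mismatched collectives.
    embeds, mask = encode_negative_prompt(
        tc,
        prompt=negative_prompt,
        device=torch.device("cuda", torch.cuda.current_device()),
        dtype=torch.bfloat16,  # match the dtype used for training
    )
    # Both tensors come back already on-device in the training dtype,
    # ready to cache as negative conditioning for CFG-style training.
    return embeds, mask
```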
