|
4 | 4 | from __future__ import annotations |
5 | 5 |
|
6 | 6 | import copy |
7 | | -import gc |
8 | 7 | from typing import Any, Literal, TYPE_CHECKING |
9 | 8 |
|
10 | 9 | import torch |
|
19 | 18 | from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( |
20 | 19 | FlowMatchEulerDiscreteScheduler, ) |
21 | 20 | from fastvideo.pipelines import TrainingBatch |
22 | | -from fastvideo.pipelines.basic.wan.wan_pipeline import ( |
23 | | - WanPipeline, ) |
24 | | -from fastvideo.pipelines.pipeline_batch_info import ( |
25 | | - ForwardBatch, ) |
26 | 21 | from fastvideo.training.activation_checkpoint import ( |
27 | 22 | apply_activation_checkpointing, ) |
28 | 23 | from fastvideo.training.training_utils import ( |
|
41 | 36 | apply_trainable, ) |
42 | 37 | from fastvideo.train.utils.moduleloader import ( |
43 | 38 | load_module_from_path, ) |
| 39 | +from fastvideo.train.utils.negative_prompt import encode_negative_prompt |
44 | 40 |
|
45 | 41 | if TYPE_CHECKING: |
46 | 42 | from fastvideo.train.utils.training_config import ( |
@@ -341,98 +337,15 @@ def ensure_negative_conditioning(self) -> None: |
341 | 337 |
|
342 | 338 | assert self.training_config is not None |
343 | 339 | tc = self.training_config |
344 | | - world_group = self.world_group |
345 | | - device = self.device |
346 | | - dtype = self._get_training_dtype() |
347 | | - |
348 | | - from fastvideo.train.utils.moduleloader import ( |
349 | | - make_inference_args, ) |
350 | | - |
351 | | - neg_embeds: torch.Tensor | None = None |
352 | | - neg_mask: torch.Tensor | None = None |
353 | | - |
354 | | - if world_group.rank_in_group == 0: |
355 | | - sampling_param = SamplingParam.from_pretrained(tc.model_path) |
356 | | - negative_prompt = sampling_param.negative_prompt |
357 | | - |
358 | | - inference_args = make_inference_args(tc, model_path=tc.model_path) |
359 | | - |
360 | | - prompt_pipeline = WanPipeline.from_pretrained( |
361 | | - tc.model_path, |
362 | | - args=inference_args, |
363 | | - inference_mode=True, |
364 | | - loaded_modules={"transformer": self.transformer}, |
365 | | - tp_size=tc.distributed.tp_size, |
366 | | - sp_size=tc.distributed.sp_size, |
367 | | - num_gpus=tc.distributed.num_gpus, |
368 | | - pin_cpu_memory=(tc.distributed.pin_cpu_memory), |
369 | | - dit_cpu_offload=True, |
370 | | - ) |
371 | | - |
372 | | - batch_negative = ForwardBatch( |
373 | | - data_type="video", |
374 | | - prompt=negative_prompt, |
375 | | - prompt_embeds=[], |
376 | | - prompt_attention_mask=[], |
377 | | - ) |
378 | | - result_batch = prompt_pipeline.prompt_encoding_stage( # type: ignore[attr-defined] |
379 | | - batch_negative, |
380 | | - inference_args, |
381 | | - ) |
382 | | - |
383 | | - neg_embeds = result_batch.prompt_embeds[0].to(device=device, dtype=dtype) |
384 | | - neg_mask = (result_batch.prompt_attention_mask[0].to(device=device, dtype=dtype)) |
385 | | - |
386 | | - del prompt_pipeline |
387 | | - gc.collect() |
388 | | - if torch.cuda.is_available(): |
389 | | - torch.cuda.empty_cache() |
390 | | - |
391 | | - meta = torch.zeros((2, ), device=device, dtype=torch.int64) |
392 | | - if world_group.rank_in_group == 0: |
393 | | - assert neg_embeds is not None |
394 | | - assert neg_mask is not None |
395 | | - meta[0] = neg_embeds.ndim |
396 | | - meta[1] = neg_mask.ndim |
397 | | - world_group.broadcast(meta, src=0) |
398 | | - embed_ndim, mask_ndim = ( |
399 | | - int(meta[0].item()), |
400 | | - int(meta[1].item()), |
| 340 | + sampling_param = SamplingParam.from_pretrained(tc.model_path) |
| 341 | + embeds, mask = encode_negative_prompt( |
| 342 | + tc, |
| 343 | + prompt=sampling_param.negative_prompt, |
| 344 | + device=self.device, |
| 345 | + dtype=self._get_training_dtype(), |
401 | 346 | ) |
402 | | - |
403 | | - max_ndim = 8 |
404 | | - embed_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64) |
405 | | - mask_shape = torch.full((max_ndim, ), -1, device=device, dtype=torch.int64) |
406 | | - if world_group.rank_in_group == 0: |
407 | | - assert neg_embeds is not None |
408 | | - assert neg_mask is not None |
409 | | - embed_shape[:embed_ndim] = torch.tensor( |
410 | | - list(neg_embeds.shape), |
411 | | - device=device, |
412 | | - dtype=torch.int64, |
413 | | - ) |
414 | | - mask_shape[:mask_ndim] = torch.tensor( |
415 | | - list(neg_mask.shape), |
416 | | - device=device, |
417 | | - dtype=torch.int64, |
418 | | - ) |
419 | | - world_group.broadcast(embed_shape, src=0) |
420 | | - world_group.broadcast(mask_shape, src=0) |
421 | | - |
422 | | - embed_sizes = tuple(int(x) for x in embed_shape[:embed_ndim].tolist()) |
423 | | - mask_sizes = tuple(int(x) for x in mask_shape[:mask_ndim].tolist()) |
424 | | - |
425 | | - if world_group.rank_in_group != 0: |
426 | | - neg_embeds = torch.empty(embed_sizes, device=device, dtype=dtype) |
427 | | - neg_mask = torch.empty(mask_sizes, device=device, dtype=dtype) |
428 | | - assert neg_embeds is not None |
429 | | - assert neg_mask is not None |
430 | | - |
431 | | - world_group.broadcast(neg_embeds, src=0) |
432 | | - world_group.broadcast(neg_mask, src=0) |
433 | | - |
434 | | - self.negative_prompt_embeds = neg_embeds |
435 | | - self.negative_prompt_attention_mask = neg_mask |
| 347 | + self.negative_prompt_embeds = embeds |
| 348 | + self.negative_prompt_attention_mask = mask |
436 | 349 |
|
437 | 350 | def _sample_timesteps( |
438 | 351 | self, |
|
0 commit comments