Skip to content

Commit 8879824

Browse files
cogvideo example: Distribute VAE video encoding across processes in CogVideoX LoRA training (#13207)
* Distribute VAE video encoding across processes in CogVideoX LoRA training Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Apply style fixes --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 4a2833c commit 8879824

1 file changed

Lines changed: 34 additions & 7 deletions

File tree

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 34 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1232,22 +1232,49 @@ def load_model_hook(models, input_dir):
12321232
id_token=args.id_token,
12331233
)
12341234

1235-
def encode_video(video, bar):
1236-
bar.update(1)
1235+
def encode_video(video):
12371236
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
12381237
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
12391238
latent_dist = vae.encode(video).latent_dist
12401239
return latent_dist
12411240

1241+
# Distribute video encoding across processes: each process only encodes its own shard
1242+
num_videos = len(train_dataset.instance_videos)
1243+
num_procs = accelerator.num_processes
1244+
local_rank = accelerator.process_index
1245+
local_count = len(range(local_rank, num_videos, num_procs))
1246+
12421247
progress_encode_bar = tqdm(
1243-
range(0, len(train_dataset.instance_videos)),
1244-
desc="Loading Encode videos",
1248+
range(local_count),
1249+
desc="Encoding videos",
1250+
disable=not accelerator.is_local_main_process,
12451251
)
1246-
train_dataset.instance_videos = [
1247-
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
1248-
]
1252+
1253+
encoded_videos = [None] * num_videos
1254+
for i, video in enumerate(train_dataset.instance_videos):
1255+
if i % num_procs == local_rank:
1256+
encoded_videos[i] = encode_video(video)
1257+
progress_encode_bar.update(1)
12491258
progress_encode_bar.close()
12501259

1260+
# Broadcast encoded latent distributions so every process has the full set
1261+
if num_procs > 1:
1262+
import torch.distributed as dist
1263+
1264+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
1265+
1266+
ref_params = next(v for v in encoded_videos if v is not None).parameters
1267+
for i in range(num_videos):
1268+
src = i % num_procs
1269+
if encoded_videos[i] is not None:
1270+
params = encoded_videos[i].parameters.contiguous()
1271+
else:
1272+
params = torch.empty_like(ref_params)
1273+
dist.broadcast(params, src=src)
1274+
encoded_videos[i] = DiagonalGaussianDistribution(params)
1275+
1276+
train_dataset.instance_videos = encoded_videos
1277+
12511278
def collate_fn(examples):
12521279
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
12531280
prompts = [example["instance_prompt"] for example in examples]

0 commit comments

Comments (0)