@@ -1232,22 +1232,49 @@ def load_model_hook(models, input_dir):
12321232 id_token = args .id_token ,
12331233 )
12341234
def encode_video(video):
    """Encode a single video tensor into the VAE's latent distribution.

    Reads ``accelerator`` and ``vae`` from the enclosing scope.

    Args:
        video: Video tensor without a batch axis; after batching, axes 1 and 2
            are swapped to give the [B, C, F, H, W] layout passed to the VAE.

    Returns:
        The ``latent_dist`` attribute of ``vae.encode(...)`` for this video.
    """
    # Move onto the training device in the VAE's dtype and add a leading batch axis.
    batched = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
    # Swap axes 1 and 2 so the tensor is laid out as [B, C, F, H, W].
    batched = batched.permute(0, 2, 1, 3, 4)
    return vae.encode(batched).latent_dist
12411240
# Shard the VAE encoding across processes: the process with index `rank` encodes
# only the videos whose position i satisfies i % world_size == rank.
num_videos = len(train_dataset.instance_videos)
world_size = accelerator.num_processes
# process_index is passed as `src` to torch.distributed below, so it is treated as
# the global rank of this process (not a per-node local rank).
rank = accelerator.process_index
local_count = len(range(rank, num_videos, world_size))

# The bar is advanced manually, once per locally encoded video.
progress_encode_bar = tqdm(
    total=local_count,
    desc="Encoding videos",
    disable=not accelerator.is_local_main_process,
)

encoded_videos = [None] * num_videos
for i, video in enumerate(train_dataset.instance_videos):
    if i % world_size == rank:
        encoded_videos[i] = encode_video(video)
        progress_encode_bar.update(1)
progress_encode_bar.close()

# Share the encoded latent distributions so every process ends up with the full set.
if world_size > 1:
    import torch.distributed as dist

    from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

    for i in range(num_videos):
        src = i % world_size
        is_owner = src == rank

        if is_owner:
            params = encoded_videos[i].parameters.contiguous()
            meta = [(tuple(params.shape), params.dtype)]
        else:
            meta = [None]

        # Send shape/dtype first. Receivers must not guess the buffer layout from
        # one of their own latents: a process may own no video at all (fewer videos
        # than processes), and videos of different length or resolution yield
        # differently shaped latents.
        dist.broadcast_object_list(meta, src=src, device=accelerator.device)

        if not is_owner:
            shape, dtype = meta[0]
            params = torch.empty(shape, dtype=dtype, device=accelerator.device)

        dist.broadcast(params, src=src)
        # Rebuild the distribution on every process (owner included) from the
        # broadcast parameters so all processes hold equivalent objects.
        encoded_videos[i] = DiagonalGaussianDistribution(params)

train_dataset.instance_videos = encoded_videos
1277+
12511278 def collate_fn (examples ):
12521279 videos = [example ["instance_video" ].sample () * vae .config .scaling_factor for example in examples ]
12531280 prompts = [example ["instance_prompt" ] for example in examples ]
0 commit comments