44import math
55import os
66import random
7- import shutil
8- from contextlib import nullcontext
97from pathlib import Path
108from typing import Any , Optional
119
1917from accelerate .logging import get_logger
2018from accelerate .utils import ProjectConfiguration , set_seed
2119from peft import LoraConfig
22- from peft .utils import get_peft_model_state_dict , set_peft_model_state_dict
20+ from peft .utils import get_peft_model_state_dict
2321from torch .utils .data import DataLoader , Dataset
2422from tqdm .auto import tqdm
2523
2624import diffusers
2725from diffusers import Cosmos2_5_PredictBasePipeline
28- from diffusers .optimization import get_linear_schedule_with_warmup , get_scheduler
26+ from diffusers .optimization import get_linear_schedule_with_warmup
2927from diffusers .training_utils import cast_training_params
30- from diffusers .utils .torch_utils import is_compiled_module
3128from diffusers .utils import (
3229 convert_state_dict_to_diffusers ,
33- is_wandb_available ,
34- load_video ,
3530 export_to_video ,
31+ load_video ,
3632)
3733from diffusers .video_processor import VideoProcessor
3834
3935
40- if is_wandb_available ():
41- import wandb
42-
43-
4436logger = get_logger (__name__ , log_level = "INFO" )
4537
4638
@@ -287,7 +279,7 @@ def __init__(
287279 caption_format : str = "auto" , # "text", "json", or "auto"
288280 video_paths : Optional [list [str ]] = None ,
289281 ) -> None :
290-
282+
291283 super ().__init__ ()
292284 self .dataset_dir = dataset_dir
293285 self .num_frames = num_frames
@@ -307,7 +299,7 @@ def __init__(
307299 logger .info (f"{ len (self .video_paths )} videos in total" , main_process_only = True )
308300
309301 self .video_size = video_size
310- self .video_processor = VideoProcessor (vae_scale_factor = 8 , resample = ' bilinear' )
302+ self .video_processor = VideoProcessor (vae_scale_factor = 8 , resample = " bilinear" )
311303 self .num_failed_loads = 0
312304
313305 def __str__ (self ) -> str :
@@ -326,7 +318,7 @@ def _load_video(self, video_path: str) -> list:
326318
327319 # randomly sample a consecutive window of frames
328320 max_start_idx = total_frames - self .num_frames
329- start_frame = np .random .randint (0 , max_start_idx + 1 )
321+ start_frame = np .random .randint (0 , max_start_idx + 1 )
330322 return frames [start_frame : start_frame + self .num_frames ]
331323
332324 def _setup_caption_format (self ) -> None :
@@ -401,7 +393,7 @@ def _get_frames(self, video_path: str) -> torch.Tensor:
401393
402394 def __getitem__ (self , index : int ) -> dict | Any :
403395 try :
404- data = dict ()
396+ data = {}
405397 video = self ._get_frames (self .video_paths [index ]) # [C, T, H, W]
406398
407399 # Load caption based on format
@@ -463,7 +455,7 @@ def sample_train_sigma_t(batch_size, distribution, device, dtype=torch.float32,
463455 t = torch .sigmoid (torch .randn ((batch_size ,))).to (device = device , dtype = dtype )
464456 else :
465457 raise NotImplementedError (f"Time distribution { distribution } is not implemented." )
466- sigma_t = shift * t / (1 + (shift - 1 ) * t ) # 0.0 <= sigma_t <= 1.0
458+ sigma_t = shift * t / (1 + (shift - 1 ) * t ) # 0.0 <= sigma_t <= 1.0
467459 return sigma_t .view (batch_size , 1 , 1 , 1 , 1 )
468460
469461
@@ -516,9 +508,9 @@ def main():
516508 if args .output_dir is not None :
517509 os .makedirs (args .output_dir , exist_ok = True )
518510
519- print ('-' * 100 )
511+ print ("-" * 100 )
520512 print (args )
521- print ('-' * 100 )
513+ print ("-" * 100 )
522514
523515 # Initialize models
524516 pipe = Cosmos2_5_PredictBasePipeline .from_pretrained (
@@ -538,7 +530,7 @@ def main():
538530 vae .requires_grad_ (False )
539531 text_encoder .requires_grad_ (False )
540532
541- target_modules_list = [' to_q' , ' to_k' , ' to_v' , ' to_out.0' , ' ff.net.0.proj' , ' ff.net.2' ]
533+ target_modules_list = [" to_q" , " to_k" , " to_v" , " to_out.0" , " ff.net.0.proj" , " ff.net.2" ]
542534 dit_lora_config = LoraConfig (
543535 r = args .lora_rank ,
544536 lora_alpha = args .lora_alpha ,
@@ -600,7 +592,7 @@ def save_model_hook(models, weights, output_dir):
600592 transformer_lora_layers = dit_lora_state_dict ,
601593 safe_serialization = True ,
602594 )
603-
595+
604596 accelerator .register_save_state_pre_hook (save_model_hook )
605597
606598 if accelerator .is_main_process :
@@ -634,7 +626,7 @@ def save_model_hook(models, weights, output_dir):
634626 padding_mask = torch .zeros (1 , 1 , args .height , args .width , dtype = dit_dtype , device = device )
635627 latent_shape = pipe .get_latent_shape_cthw (args .height , args .width , args .num_frames )
636628 latents_mean = pipe .latents_mean .float ().to (device )
637- latents_std = pipe .latents_std .float ().to (device ) # 1/σ
629+ latents_std = pipe .latents_std .float ().to (device ) # 1/σ
638630 # Start training
639631 torch .set_grad_enabled (True ) # re-enable grad disabled by Cosmos2_5_PredictBasePipeline
640632 for epoch in range (first_epoch , args .num_train_epochs ):
@@ -647,15 +639,15 @@ def save_model_hook(models, weights, output_dir):
647639 raw_state = batch ["video" ].to (device = device , dtype = vae .dtype )
648640 mu = vae .encode (raw_state ).latent_dist .mean # deterministic
649641 clean_latent = ((mu - latents_mean ) * latents_std ).contiguous ().float ()
650- assert clean_latent .requires_grad == False
642+ assert not clean_latent .requires_grad
651643 torch .cuda .empty_cache ()
652644
653645 # Encode text to text embeddings
654646 prompt_embeds = pipe ._get_prompt_embeds (
655647 prompt = batch ["caption" ],
656648 device = device ,
657649 )
658- assert prompt_embeds .requires_grad == False
650+ assert not prompt_embeds .requires_grad
659651
660652 # CFG dropout: independently zero out text conditioning per sample
661653 bsz = clean_latent .shape [0 ]
@@ -667,18 +659,21 @@ def save_model_hook(models, weights, output_dir):
667659 weights = list (args .conditional_frames_probs .values ())
668660 num_conditional_frames = random .choices (frames_options , weights = weights , k = bsz )
669661 cond_indicator , cond_mask = pipe .create_condition_mask (
670- (bsz , * latent_shape ), device = device , dtype = torch .float32 , num_cond_latent_frames = num_conditional_frames
662+ (bsz , * latent_shape ),
663+ device = device ,
664+ dtype = torch .float32 ,
665+ num_cond_latent_frames = num_conditional_frames ,
671666 )
672667
673668 # Sample a random timestep
674- sigma_t = sample_train_sigma_t (bsz , distribution = ' logitnormal' , device = device )
669+ sigma_t = sample_train_sigma_t (bsz , distribution = " logitnormal" , device = device )
675670 # 1. Sample noise 2. Get the target velocity 3. Get xt by interpolation between noise and clean
676671 xt_B_C_T_H_W , target_velocity = get_flow_xt_and_target_v (clean_latent , sigma_t , cond_mask )
677-
672+
678673 # Denoise
679674 if args .conditional_frame_timestep >= 0 :
680675 in_timestep = cond_indicator * args .conditional_frame_timestep + (1 - cond_indicator ) * sigma_t
681-
676+
682677 pred_velocity = dit (
683678 hidden_states = xt_B_C_T_H_W ,
684679 condition_mask = cond_mask ,
@@ -717,7 +712,7 @@ def save_model_hook(models, weights, output_dir):
717712 if global_step >= max_train_steps :
718713 break
719714
720- if (epoch + 1 ) % args .checkpointing_epochs == 0 and (epoch + 1 ) < args .num_train_epochs :
715+ if (epoch + 1 ) % args .checkpointing_epochs == 0 and (epoch + 1 ) < args .num_train_epochs :
721716 if accelerator .is_main_process :
722717 save_path = os .path .join (args .output_dir , f"checkpoint-{ epoch } " )
723718 accelerator .save_state (save_path )
@@ -738,7 +733,7 @@ def save_model_hook(models, weights, output_dir):
738733 if args .do_final_eval :
739734 noises = arch_invariant_rand ((1 , * latent_shape ), dtype = torch .float32 , device = device , seed = args .seed )
740735 inputs = train_dataloader .dataset [0 ]
741-
736+
742737 pipe .transformer .eval ()
743738 with torch .inference_mode ():
744739 frames = pipe (
@@ -747,14 +742,15 @@ def save_model_hook(models, weights, output_dir):
747742 prompt = inputs ["caption" ],
748743 num_frames = args .num_frames ,
749744 num_inference_steps = args .num_inference_steps ,
750- latents = noises , # ensure architecture invariant generation
745+ latents = noises , # ensure architecture invariant generation
751746 height = args .height ,
752747 width = args .width ,
753748 ).frames [0 ]
754-
749+
755750 export_to_video (frames , os .path .join (args .output_dir , "eval_output.mp4" ), fps = 16 )
756751
757752 accelerator .end_training ()
758753
754+
759755if __name__ == "__main__" :
760756 main ()
0 commit comments