55import argparse
66import os
77
8- import numpy as np
98import torch
109from torch .utils .data import DataLoader , Dataset
1110from tqdm import tqdm
@@ -67,12 +66,6 @@ def collate_fn(batch):
6766 }
6867
6968
70- def arch_invariant_rand (shape , dtype , device , seed = None ):
71- rng = np .random .RandomState (seed )
72- random_array = rng .standard_normal (shape ).astype (np .float32 )
73- return torch .from_numpy (random_array ).to (dtype = dtype , device = device )
74-
75-
7669def parse_args ():
7770 parser = argparse .ArgumentParser (description = "Eval Cosmos Predict 2.5 with optional LoRA weights." )
7871
@@ -143,15 +136,6 @@ def check_video_safety(self, video):
143136 pipe .fuse_lora (lora_scale = 1.0 )
144137 print (f"Loaded LoRA weights from { args .lora_dir } " )
145138
146- latent_shape = (
147- pipe .vae .config .z_dim ,
148- (args .num_output_frames - 1 ) // pipe .vae_scale_factor_temporal + 1 ,
149- args .height // pipe .vae_scale_factor_spatial ,
150- args .width // pipe .vae_scale_factor_spatial ,
151- )
152- noises = arch_invariant_rand (
153- (args .batch_size , * latent_shape ), dtype = torch .float32 , device = args .device , seed = args .seed
154- )
155139 progress = tqdm (total = len (dataset ), desc = "Generating" )
156140 for batch in dataloader :
157141 images = batch ["images" ]
@@ -167,7 +151,6 @@ def check_video_safety(self, video):
167151 num_inference_steps = args .num_steps ,
168152 height = args .height ,
169153 width = args .width ,
170- latents = noises ,
171154 ).frames [0 ] # NOTE: batch_size == 1
172155
173156 out_path = os .path .join (args .output_dir , f"{ stem } .mp4" )
0 commit comments