Skip to content

Commit d4b27f6

Browse files
author
Ting-Yun Chang
committed
Simplify the eval script to make it more user-friendly
1 parent 6acdc1c commit d4b27f6

1 file changed

Lines changed: 0 additions & 17 deletions

File tree

examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import argparse
66
import os
77

8-
import numpy as np
98
import torch
109
from torch.utils.data import DataLoader, Dataset
1110
from tqdm import tqdm
@@ -67,12 +66,6 @@ def collate_fn(batch):
6766
}
6867

6968

70-
def arch_invariant_rand(shape, dtype, device, seed=None):
71-
rng = np.random.RandomState(seed)
72-
random_array = rng.standard_normal(shape).astype(np.float32)
73-
return torch.from_numpy(random_array).to(dtype=dtype, device=device)
74-
75-
7669
def parse_args():
7770
parser = argparse.ArgumentParser(description="Eval Cosmos Predict 2.5 with optional LoRA weights.")
7871

@@ -143,15 +136,6 @@ def check_video_safety(self, video):
143136
pipe.fuse_lora(lora_scale=1.0)
144137
print(f"Loaded LoRA weights from {args.lora_dir}")
145138

146-
latent_shape = (
147-
pipe.vae.config.z_dim,
148-
(args.num_output_frames - 1) // pipe.vae_scale_factor_temporal + 1,
149-
args.height // pipe.vae_scale_factor_spatial,
150-
args.width // pipe.vae_scale_factor_spatial,
151-
)
152-
noises = arch_invariant_rand(
153-
(args.batch_size, *latent_shape), dtype=torch.float32, device=args.device, seed=args.seed
154-
)
155139
progress = tqdm(total=len(dataset), desc="Generating")
156140
for batch in dataloader:
157141
images = batch["images"]
@@ -167,7 +151,6 @@ def check_video_safety(self, video):
167151
num_inference_steps=args.num_steps,
168152
height=args.height,
169153
width=args.width,
170-
latents=noises,
171154
).frames[0] # NOTE: batch_size == 1
172155

173156
out_path = os.path.join(args.output_dir, f"{stem}.mp4")

0 commit comments

Comments
 (0)