Skip to content

Commit 6acdc1c

Browse files
author
Ting-Yun Chang
committed
remove the get_latent_shape_cthw method and fix style
1 parent 49f5b35 commit 6acdc1c

3 files changed

Lines changed: 16 additions & 14 deletions

File tree

examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,12 @@ def check_video_safety(self, video):
143143
pipe.fuse_lora(lora_scale=1.0)
144144
print(f"Loaded LoRA weights from {args.lora_dir}")
145145

146-
latent_shape = pipe.get_latent_shape_cthw(args.height, args.width, args.num_output_frames)
146+
latent_shape = (
147+
pipe.vae.config.z_dim,
148+
(args.num_output_frames - 1) // pipe.vae_scale_factor_temporal + 1,
149+
args.height // pipe.vae_scale_factor_spatial,
150+
args.width // pipe.vae_scale_factor_spatial,
151+
)
147152
noises = arch_invariant_rand(
148153
(args.batch_size, *latent_shape), dtype=torch.float32, device=args.device, seed=args.seed
149154
)

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,12 @@ def save_model_hook(models, weights, output_dir):
614614
)
615615

616616
padding_mask = torch.zeros(1, 1, args.height, args.width, dtype=dit_dtype, device=device)
617-
latent_shape = pipe.get_latent_shape_cthw(args.height, args.width, args.num_frames)
617+
latent_shape = (
618+
pipe.vae.config.z_dim,
619+
(args.num_frames - 1) // pipe.vae_scale_factor_temporal + 1,
620+
args.height // pipe.vae_scale_factor_spatial,
621+
args.width // pipe.vae_scale_factor_spatial,
622+
)
618623
latents_mean = pipe.latents_mean.float().to(device)
619624
latents_std = pipe.latents_std.float().to(device) # 1/σ
620625
# Start training

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -236,21 +236,11 @@ def __init__(
236236
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
237237
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear")
238238

239-
assert getattr(self.vae.config, "latents_mean", None), "VAE configuration must define `latents_mean`."
240-
assert getattr(self.vae.config, "latents_std", None), "VAE configuration must define `latents_std`."
241-
242239
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
243240
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
244241
self.latents_mean = latents_mean
245242
self.latents_std = 1.0 / latents_std
246243

247-
def get_latent_shape_cthw(self, height: int, width: int, num_frames: int):
248-
C = self.vae.config.z_dim
249-
T = (num_frames - 1) // self.vae_scale_factor_temporal + 1
250-
H = height // self.vae_scale_factor_spatial
251-
W = width // self.vae_scale_factor_spatial
252-
return (C, T, H, W)
253-
254244
def create_condition_mask(self, latent_shape, device, dtype, num_cond_latent_frames):
255245
bsz, C, T, H, W = latent_shape
256246
cond_indicator = torch.zeros(bsz, 1, T, 1, 1, dtype=dtype, device=device)
@@ -438,9 +428,11 @@ def prepare_latents(
438428
)
439429

440430
B = batch_size
441-
C, T, H, W = self.get_latent_shape_cthw(height, width, num_frames_out)
431+
C = num_channels_latents
432+
T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
433+
H = height // self.vae_scale_factor_spatial
434+
W = width // self.vae_scale_factor_spatial
442435
shape = (B, C, T, H, W)
443-
assert C == num_channels_latents, f"Expected number of channels to be {num_channels_latents}, but got {C}."
444436

445437
if num_frames_in == 0:
446438
if latents is None:

0 commit comments

Comments
 (0)