Skip to content

Commit 95810d0

Browse files
committed
make style
1 parent c356980 commit 95810d0

6 files changed

Lines changed: 53 additions & 54 deletions

File tree

examples/cosmos/create_prompts_for_gr1_dataset.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from tqdm import tqdm
2020

21+
2122
"""example command
2223
python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1
2324
"""

examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77

88
import numpy as np
99
import torch
10-
from tqdm import tqdm
1110
from torch.utils.data import DataLoader, Dataset
11+
from tqdm import tqdm
1212

1313
from diffusers import Cosmos2_5_PredictBasePipeline
1414
from diffusers.utils import export_to_video, load_image

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 27 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import math
55
import os
66
import random
7-
import shutil
8-
from contextlib import nullcontext
97
from pathlib import Path
108
from typing import Any, Optional
119

@@ -19,28 +17,22 @@
1917
from accelerate.logging import get_logger
2018
from accelerate.utils import ProjectConfiguration, set_seed
2119
from peft import LoraConfig
22-
from peft.utils import get_peft_model_state_dict, set_peft_model_state_dict
20+
from peft.utils import get_peft_model_state_dict
2321
from torch.utils.data import DataLoader, Dataset
2422
from tqdm.auto import tqdm
2523

2624
import diffusers
2725
from diffusers import Cosmos2_5_PredictBasePipeline
28-
from diffusers.optimization import get_linear_schedule_with_warmup, get_scheduler
26+
from diffusers.optimization import get_linear_schedule_with_warmup
2927
from diffusers.training_utils import cast_training_params
30-
from diffusers.utils.torch_utils import is_compiled_module
3128
from diffusers.utils import (
3229
convert_state_dict_to_diffusers,
33-
is_wandb_available,
34-
load_video,
3530
export_to_video,
31+
load_video,
3632
)
3733
from diffusers.video_processor import VideoProcessor
3834

3935

40-
if is_wandb_available():
41-
import wandb
42-
43-
4436
logger = get_logger(__name__, log_level="INFO")
4537

4638

@@ -287,7 +279,7 @@ def __init__(
287279
caption_format: str = "auto", # "text", "json", or "auto"
288280
video_paths: Optional[list[str]] = None,
289281
) -> None:
290-
282+
291283
super().__init__()
292284
self.dataset_dir = dataset_dir
293285
self.num_frames = num_frames
@@ -307,7 +299,7 @@ def __init__(
307299
logger.info(f"{len(self.video_paths)} videos in total", main_process_only=True)
308300

309301
self.video_size = video_size
310-
self.video_processor = VideoProcessor(vae_scale_factor=8, resample='bilinear')
302+
self.video_processor = VideoProcessor(vae_scale_factor=8, resample="bilinear")
311303
self.num_failed_loads = 0
312304

313305
def __str__(self) -> str:
@@ -326,7 +318,7 @@ def _load_video(self, video_path: str) -> list:
326318

327319
# randomly sample a consecutive window of frames
328320
max_start_idx = total_frames - self.num_frames
329-
start_frame = np.random.randint(0, max_start_idx+1)
321+
start_frame = np.random.randint(0, max_start_idx + 1)
330322
return frames[start_frame : start_frame + self.num_frames]
331323

332324
def _setup_caption_format(self) -> None:
@@ -401,7 +393,7 @@ def _get_frames(self, video_path: str) -> torch.Tensor:
401393

402394
def __getitem__(self, index: int) -> dict | Any:
403395
try:
404-
data = dict()
396+
data = {}
405397
video = self._get_frames(self.video_paths[index]) # [C, T, H, W]
406398

407399
# Load caption based on format
@@ -463,7 +455,7 @@ def sample_train_sigma_t(batch_size, distribution, device, dtype=torch.float32,
463455
t = torch.sigmoid(torch.randn((batch_size,))).to(device=device, dtype=dtype)
464456
else:
465457
raise NotImplementedError(f"Time distribution {distribution} is not implemented.")
466-
sigma_t = shift * t / (1 + (shift - 1) * t) # 0.0 <= sigma_t <= 1.0
458+
sigma_t = shift * t / (1 + (shift - 1) * t) # 0.0 <= sigma_t <= 1.0
467459
return sigma_t.view(batch_size, 1, 1, 1, 1)
468460

469461

@@ -516,9 +508,9 @@ def main():
516508
if args.output_dir is not None:
517509
os.makedirs(args.output_dir, exist_ok=True)
518510

519-
print('-'*100)
511+
print("-" * 100)
520512
print(args)
521-
print('-'*100)
513+
print("-" * 100)
522514

523515
# Initialize models
524516
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
@@ -538,7 +530,7 @@ def main():
538530
vae.requires_grad_(False)
539531
text_encoder.requires_grad_(False)
540532

541-
target_modules_list = ['to_q', 'to_k', 'to_v', 'to_out.0', 'ff.net.0.proj', 'ff.net.2']
533+
target_modules_list = ["to_q", "to_k", "to_v", "to_out.0", "ff.net.0.proj", "ff.net.2"]
542534
dit_lora_config = LoraConfig(
543535
r=args.lora_rank,
544536
lora_alpha=args.lora_alpha,
@@ -600,7 +592,7 @@ def save_model_hook(models, weights, output_dir):
600592
transformer_lora_layers=dit_lora_state_dict,
601593
safe_serialization=True,
602594
)
603-
595+
604596
accelerator.register_save_state_pre_hook(save_model_hook)
605597

606598
if accelerator.is_main_process:
@@ -634,7 +626,7 @@ def save_model_hook(models, weights, output_dir):
634626
padding_mask = torch.zeros(1, 1, args.height, args.width, dtype=dit_dtype, device=device)
635627
latent_shape = pipe.get_latent_shape_cthw(args.height, args.width, args.num_frames)
636628
latents_mean = pipe.latents_mean.float().to(device)
637-
latents_std = pipe.latents_std.float().to(device) # 1/σ
629+
latents_std = pipe.latents_std.float().to(device) # 1/σ
638630
# Start training
639631
torch.set_grad_enabled(True) # re-enable grad disabled by Cosmos2_5_PredictBasePipeline
640632
for epoch in range(first_epoch, args.num_train_epochs):
@@ -647,15 +639,15 @@ def save_model_hook(models, weights, output_dir):
647639
raw_state = batch["video"].to(device=device, dtype=vae.dtype)
648640
mu = vae.encode(raw_state).latent_dist.mean # deterministic
649641
clean_latent = ((mu - latents_mean) * latents_std).contiguous().float()
650-
assert clean_latent.requires_grad == False
642+
assert not clean_latent.requires_grad
651643
torch.cuda.empty_cache()
652644

653645
# Encode text to text embeddings
654646
prompt_embeds = pipe._get_prompt_embeds(
655647
prompt=batch["caption"],
656648
device=device,
657649
)
658-
assert prompt_embeds.requires_grad == False
650+
assert not prompt_embeds.requires_grad
659651

660652
# CFG dropout: independently zero out text conditioning per sample
661653
bsz = clean_latent.shape[0]
@@ -667,18 +659,21 @@ def save_model_hook(models, weights, output_dir):
667659
weights = list(args.conditional_frames_probs.values())
668660
num_conditional_frames = random.choices(frames_options, weights=weights, k=bsz)
669661
cond_indicator, cond_mask = pipe.create_condition_mask(
670-
(bsz, *latent_shape), device=device, dtype=torch.float32, num_cond_latent_frames=num_conditional_frames
662+
(bsz, *latent_shape),
663+
device=device,
664+
dtype=torch.float32,
665+
num_cond_latent_frames=num_conditional_frames,
671666
)
672667

673668
# Sample a random timestep
674-
sigma_t = sample_train_sigma_t(bsz, distribution='logitnormal', device=device)
669+
sigma_t = sample_train_sigma_t(bsz, distribution="logitnormal", device=device)
675670
# 1. Sample noise 2. Get the target velocity 3. Get xt by interpolation between noise and clean
676671
xt_B_C_T_H_W, target_velocity = get_flow_xt_and_target_v(clean_latent, sigma_t, cond_mask)
677-
672+
678673
# Denoise
679674
if args.conditional_frame_timestep >= 0:
680675
in_timestep = cond_indicator * args.conditional_frame_timestep + (1 - cond_indicator) * sigma_t
681-
676+
682677
pred_velocity = dit(
683678
hidden_states=xt_B_C_T_H_W,
684679
condition_mask=cond_mask,
@@ -717,7 +712,7 @@ def save_model_hook(models, weights, output_dir):
717712
if global_step >= max_train_steps:
718713
break
719714

720-
if (epoch+1) % args.checkpointing_epochs == 0 and (epoch+1) < args.num_train_epochs:
715+
if (epoch + 1) % args.checkpointing_epochs == 0 and (epoch + 1) < args.num_train_epochs:
721716
if accelerator.is_main_process:
722717
save_path = os.path.join(args.output_dir, f"checkpoint-{epoch}")
723718
accelerator.save_state(save_path)
@@ -738,7 +733,7 @@ def save_model_hook(models, weights, output_dir):
738733
if args.do_final_eval:
739734
noises = arch_invariant_rand((1, *latent_shape), dtype=torch.float32, device=device, seed=args.seed)
740735
inputs = train_dataloader.dataset[0]
741-
736+
742737
pipe.transformer.eval()
743738
with torch.inference_mode():
744739
frames = pipe(
@@ -747,14 +742,15 @@ def save_model_hook(models, weights, output_dir):
747742
prompt=inputs["caption"],
748743
num_frames=args.num_frames,
749744
num_inference_steps=args.num_inference_steps,
750-
latents=noises, # ensure architecture invariant generation
745+
latents=noises, # ensure architecture invariant generation
751746
height=args.height,
752747
width=args.width,
753748
).frames[0]
754-
749+
755750
export_to_video(frames, os.path.join(args.output_dir, "eval_output.mp4"), fps=16)
756751

757752
accelerator.end_training()
758753

754+
759755
if __name__ == "__main__":
760756
main()

src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2263,16 +2263,14 @@ def _prepare_outputs(state_dict, metadata, alphas=None, return_alphas=False, ret
22632263

22642264
class CosmosLoraLoaderMixin(FluxLoraLoaderMixin):
22652265
r"""
2266-
Load LoRA layers into [`CosmosTransformer3DModel`],
2267-
Specific to [`Cosmos2_5_PredictBasePipeline`].
2266+
Load LoRA layers into [`CosmosTransformer3DModel`], Specific to [`Cosmos2_5_PredictBasePipeline`].
22682267
"""
22692268

22702269
_lora_loadable_modules = ["transformer"]
22712270
transformer_name = TRANSFORMER_NAME
22722271
text_encoder_name = TEXT_ENCODER_NAME
22732272
_control_lora_supported_norm_keys = ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
22742273

2275-
22762274
def load_lora_weights(
22772275
self,
22782276
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
@@ -2312,11 +2310,6 @@ def load_lora_weights(
23122310
if not (has_lora_keys or has_norm_keys):
23132311
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
23142312

2315-
transformer_lora_state_dict = {
2316-
k: state_dict.get(k)
2317-
for k in list(state_dict.keys())
2318-
if k.startswith(f"{self.transformer_name}.") and "lora" in k
2319-
}
23202313
transformer_norm_state_dict = {
23212314
k: state_dict.pop(k)
23222315
for k in list(state_dict.keys())

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ def __call__(
194194
original_dtype = query.dtype
195195
with torch.amp.autocast("cuda", enabled=self.autocast_fp32, dtype=torch.float32):
196196
target_dtype = torch.float32 if self.autocast_fp32 else original_dtype
197-
query = apply_rotary_emb(query.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
197+
query = apply_rotary_emb(
198+
query.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2
199+
)
198200
key = apply_rotary_emb(key.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
199201
query = query.to(original_dtype)
200202
key = key.to(original_dtype)
@@ -267,7 +269,9 @@ def __call__(
267269
original_dtype = query.dtype
268270
with torch.amp.autocast("cuda", enabled=self.autocast_fp32, dtype=torch.float32):
269271
target_dtype = torch.float32 if self.autocast_fp32 else original_dtype
270-
query = apply_rotary_emb(query.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
272+
query = apply_rotary_emb(
273+
query.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2
274+
)
271275
key = apply_rotary_emb(key.to(target_dtype), image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
272276
query = query.to(original_dtype)
273277
key = key.to(original_dtype)

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...image_processor import PipelineImageInput
24+
from ...loaders import CosmosLoraLoaderMixin
2425
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
2526
from ...schedulers import UniPCMultistepScheduler
2627
from ...utils import (
@@ -33,7 +34,6 @@
3334
from ...utils.torch_utils import randn_tensor
3435
from ...video_processor import VideoProcessor
3536
from ..pipeline_utils import DiffusionPipeline
36-
from ...loaders import CosmosLoraLoaderMixin
3737
from .pipeline_output import CosmosPipelineOutput
3838

3939

@@ -239,11 +239,11 @@ def __init__(
239239

240240
self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
241241
self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
242-
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample='bilinear')
242+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial, resample="bilinear")
243243

244244
assert getattr(self.vae.config, "latents_mean", None), "VAE configuration must define `latents_mean`."
245245
assert getattr(self.vae.config, "latents_std", None), "VAE configuration must define `latents_std`."
246-
246+
247247
latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, self.vae.config.z_dim, 1, 1, 1).float()
248248
latents_std = torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).float()
249249
self.latents_mean = latents_mean
@@ -259,7 +259,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
259259
"torch_dtype": kwargs.get("torch_dtype", None),
260260
"attn_implementation": text_encoder_attn_implementation,
261261
}
262-
262+
263263
if os.path.isdir(pretrained_model_name_or_path):
264264
text_encoder_path = os.path.join(pretrained_model_name_or_path, "text_encoder")
265265
else:
@@ -270,21 +270,21 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
270270
)
271271

272272
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
273-
273+
274274
def get_latent_shape_cthw(self, height: int, width: int, num_frames: int):
275275
C = self.vae.config.z_dim
276276
T = (num_frames - 1) // self.vae_scale_factor_temporal + 1
277277
H = height // self.vae_scale_factor_spatial
278278
W = width // self.vae_scale_factor_spatial
279279
return (C, T, H, W)
280-
280+
281281
def create_condition_mask(self, latent_shape, device, dtype, num_cond_latent_frames):
282282
bsz, C, T, H, W = latent_shape
283283
cond_indicator = torch.zeros(bsz, 1, T, 1, 1, dtype=dtype, device=device)
284284
if isinstance(num_cond_latent_frames, int):
285285
num_cond_latent_frames = [num_cond_latent_frames] * bsz
286286
for idx in range(bsz):
287-
cond_indicator[idx, :, :num_cond_latent_frames[idx], :, :] = 1.0
287+
cond_indicator[idx, :, : num_cond_latent_frames[idx], :, :] = 1.0
288288
cond_mask = cond_indicator.expand(-1, -1, -1, H, W)
289289
return cond_indicator, cond_mask
290290

@@ -493,11 +493,16 @@ def prepare_latents(
493493

494494
if isinstance(generator, list):
495495
cond_latents = [
496-
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator=generator[i], sample_mode="argmax")
496+
retrieve_latents(
497+
self.vae.encode(video[i].unsqueeze(0)), generator=generator[i], sample_mode="argmax"
498+
)
497499
for i in range(batch_size)
498500
]
499501
else:
500-
cond_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator, sample_mode="argmax") for vid in video]
502+
cond_latents = [
503+
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator, sample_mode="argmax")
504+
for vid in video
505+
]
501506

502507
cond_latents = torch.cat(cond_latents, dim=0).to(dtype)
503508

@@ -760,8 +765,8 @@ def __call__(
760765
raise ValueError(
761766
f"Input video has only {total_input_frames} frames but Video2World requires at least "
762767
f"{frames_to_extract} frames for conditioning."
763-
)
764-
768+
)
769+
765770
video = video[:, :, -frames_to_extract:, :, :]
766771
if video.shape[2] < num_frames:
767772
n_pad_frames = num_frames - video.shape[2]
@@ -807,7 +812,7 @@ def __call__(
807812
continue
808813

809814
self._current_timestep = t.cpu().item()
810-
815+
811816
# NOTE: assumes sigma(t) \in [0, 1]
812817
sigma_t = self.scheduler.sigmas[i].expand(batch_size).to(device=device, dtype=torch.float32)
813818
if conditional_frame_timestep >= 0:

0 commit comments

Comments
 (0)