Commit 4ca8633

terarachang and Ting-Yun Chang authored
Add LoRA support for Cosmos Predict 2.5 and fix pipeline to match official Cosmos repo (#13664)
* support LoRA for Cosmos 2.5
* Fix inconsistencies with the official Cosmos repo in VAE encoding, text encoder attention implementation, and timestep scaling
* Support f_min and f_max in linear_scheduler warmup
* Add requirements and dataset preprocessing scripts to run examples
* Add LoRA training scripts
* Add LoRA eval scripts
* Add assets for blogpost
* Fix(scheduler): device mismatch from upstream b114620 - move rk and b to device before torch.stack
* Always upcast to fp32
* Directly inherit from LoraBaseMixin
* Remove flash-attn2
* Use _keep_in_fp32_modules instead of autocast
* Remove the get_latent_shape_cthw method and fix style
* Simplify the eval script to make it more user-friendly
* Overwrite scheduling_unipc_multistep.py with main's version
* Remove network_alphas and add # Copied from
* Remove figures and assets
* Revert scheduler
* Revert fp32 upcast and support bs > 1

---------

Co-authored-by: Ting-Yun Chang <tingyunc@nvidia.com>
1 parent 1030249 commit 4ca8633

16 files changed

Lines changed: 1495 additions & 87 deletions

docs/source/en/api/loaders/lora.md

Lines changed: 4 additions & 0 deletions
@@ -132,6 +132,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
 
 [[autodoc]] loaders.lora_pipeline.ZImageLoraLoaderMixin
 
+## CosmosLoraLoaderMixin
+
+[[autodoc]] loaders.lora_pipeline.CosmosLoraLoaderMixin
+
 ## KandinskyLoraLoaderMixin
 [[autodoc]] loaders.lora_pipeline.KandinskyLoraLoaderMixin
 
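For context, this mixin is exercised by the eval script later in this diff. A minimal sketch of loading an adapter through it, mirroring the calls used there (the adapter path is a placeholder; depending on your environment you may also need to pass a `safety_checker`, as the eval script does):

```python
import torch

from diffusers import Cosmos2_5_PredictBasePipeline

# Load the base pipeline (revision as used in this example).
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
    "nvidia/Cosmos-Predict2.5-2B",
    revision="diffusers/base/post-trained",
    torch_dtype=torch.bfloat16,
)

# CosmosLoraLoaderMixin adds these methods to the pipeline.
pipe.load_lora_weights("lora-output")  # placeholder: your adapter directory
pipe.fuse_lora(lora_scale=1.0)
```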

examples/cosmos/README.md

Lines changed: 97 additions & 0 deletions
# LoRA fine-tuning for Cosmos Predict 2.5

This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset.

## Requirements

Install the library from source and the example-specific dependencies:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e ".[dev]"
cd examples/cosmos
pip install -r requirements.txt
```

## Data preparation

The training script expects a dataset directory with the following layout:

```
<dataset_dir>/
├── videos/   # .mp4 files
└── metas/    # one .txt prompt file per video (same stem)
    ├── 0.txt
    ├── 1.txt
    └── ...
```
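Before training, it is worth checking that every video has a matching prompt file. A minimal sketch of such a check (not part of the example scripts; the path is a placeholder):

```python
import os

dataset_dir = "gr1_dataset/train"  # placeholder

# Collect file stems on each side of the expected layout.
videos = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(dataset_dir, "videos")) if f.endswith(".mp4")}
prompts = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(dataset_dir, "metas")) if f.endswith(".txt")}

missing = sorted(videos - prompts)
if missing:
    print(f"{len(missing)} videos lack a prompt file, e.g. {missing[:3]}")
else:
    print(f"OK: {len(videos)} video/prompt pairs")
```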
### GR1 dataset (quick start)

The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files.

```bash
bash download_and_preprocess_datasets.sh
```

This produces:

- `gr1_dataset/train/` — training videos + prompts
- `gr1_dataset/test/` — evaluation images + prompts
## Training

Launch LoRA training with `accelerate`:

```bash
export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B"
export DATA_DIR="gr1_dataset/train"
export OUT_DIR="lora-output"

accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --revision diffusers/base/post-trained \
  --train_data_dir=$DATA_DIR \
  --output_dir=$OUT_DIR \
  --train_batch_size=1 \
  --num_train_epochs=500 \
  --checkpointing_epochs=100 \
  --seed=0 \
  --height 432 --width 768 \
  --allow_tf32 \
  --gradient_checkpointing \
  --lora_rank 32 --lora_alpha 32 \
  --report_to=wandb
```

Or use the provided shell script:

```bash
bash train_lora.sh
```
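For intuition, `--lora_rank` and `--lora_alpha` typically map to a `peft` `LoraConfig` attached to the transformer. A rough sketch under that assumption (the exact `target_modules` used by `train_cosmos_predict25_lora.py` are not shown in this diff):

```python
from peft import LoraConfig

# Hypothetical mapping of the CLI flags above; target_modules is an assumption.
transformer_lora_config = LoraConfig(
    r=32,  # --lora_rank
    lora_alpha=32,  # --lora_alpha
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

# In diffusers training scripts the adapter is usually attached with:
# transformer.add_adapter(transformer_lora_config)
```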
## Evaluation

Run inference with the trained LoRA adapter:

```bash
export DATA_DIR="gr1_dataset/test"
export LORA_DIR="lora-output"
export OUT_DIR="eval-output"

python eval_cosmos_predict25_lora.py \
  --data_dir $DATA_DIR \
  --output_dir $OUT_DIR \
  --lora_dir $LORA_DIR \
  --revision diffusers/base/post-trained \
  --height 432 --width 768 \
  --num_output_frames 93 \
  --num_steps 36 \
  --seed 0
```

Or use the provided shell script:

```bash
bash eval_lora.sh
```
examples/cosmos/create_prompts_for_gr1_dataset.py

Lines changed: 63 additions & 0 deletions
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

from tqdm import tqdm


"""example command
python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1
"""


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Create text prompts for GR1 dataset")
    parser.add_argument(
        "--dataset_path", type=str, default="datasets/benchmark_train/gr1", help="Root path to the dataset"
    )
    parser.add_argument(
        "--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt"
    )
    parser.add_argument(
        "--meta_csv", type=str, default=None, help="Metadata csv file (defaults to <dataset_path>/metadata.csv)"
    )
    return parser.parse_args()


def main(args) -> None:
    meta_csv = args.meta_csv or os.path.join(args.dataset_path, "metadata.csv")
    meta_lines = open(meta_csv).readlines()[1:]  # skip the csv header row
    meta_txt_dir = os.path.join(args.dataset_path, "metas")
    os.makedirs(meta_txt_dir, exist_ok=True)

    for meta_line in tqdm(meta_lines):
        video_filename, prompt = meta_line.split(",", 1)
        prompt = prompt.strip("\n")
        if prompt.startswith('"') and prompt.endswith('"'):
            # Remove the surrounding quotes
            prompt = prompt[1:-1]
        prompt = args.prompt_prefix + prompt
        meta_txt_filename = os.path.join(meta_txt_dir, os.path.basename(video_filename).replace(".mp4", ".txt"))
        with open(meta_txt_filename, "w") as fp:
            fp.write(prompt)

    print(f"encoding prompt: {prompt}")


if __name__ == "__main__":
    args = parse_args()
    main(args)
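Concretely, with the default `--prompt_prefix`, a hypothetical `metadata.csv` row is transformed like this:

```
# metadata.csv row (hypothetical):
0.mp4,"The robot picks up a cup."

# resulting metas/0.txt:
The robot arm is performing a task. The robot picks up a cup.
```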
examples/cosmos/download_and_preprocess_datasets.sh

Lines changed: 25 additions & 0 deletions
dataset_dir='gr1_dataset'
train_dir=$dataset_dir/train
test_dir=$dataset_dir/test

# Download and preprocess the training dataset
hf download nvidia/GR1-100 --repo-type dataset --local-dir datasets/benchmark_train/hf_gr1/ && \
mkdir -p datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/gr1/*mp4 datasets/benchmark_train/gr1/videos && \
mv datasets/benchmark_train/hf_gr1/metadata.csv datasets/benchmark_train/gr1/

python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1

# Download the eval dataset
hf download nvidia/EVAL-175 --repo-type dataset --local-dir dream_gen_benchmark

# Move the datasets into the final directory layout
mkdir -p $dataset_dir
mv datasets/benchmark_train/gr1 $train_dir
mv dream_gen_benchmark/gr1_object $test_dir
echo "Downloaded training data to $train_dir"
echo "Downloaded test data to $test_dir"

# Clean up staging directories
rm -rf datasets/ dream_gen_benchmark/
examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 164 additions & 0 deletions
#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import os

import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from diffusers import Cosmos2_5_PredictBasePipeline
from diffusers.utils import export_to_video, load_image


IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


class ImageDataset(Dataset):
    """Dataset that loads images and their corresponding text prompts.

    Expects a directory with:
        <filename>.jpg / .jpeg / .png — the conditioning image
        <filename>.txt — the prompt text
    """

    def __init__(self, data_dir: str):
        self.data_dir = data_dir
        self.samples = []

        for filename in sorted(os.listdir(data_dir)):
            stem, ext = os.path.splitext(filename)
            if ext.lower() not in IMAGE_EXTENSIONS:
                continue
            img_path = os.path.join(data_dir, filename)
            txt_path = os.path.join(data_dir, stem + ".txt")
            if not os.path.exists(txt_path):
                print(f"WARNING: no prompt file found for {img_path}, skipping.")
                continue
            self.samples.append((img_path, txt_path, stem))

        if len(self.samples) == 0:
            raise ValueError(f"No valid image/prompt pairs found in {data_dir}")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, txt_path, stem = self.samples[idx]
        image = load_image(img_path)
        with open(txt_path) as f:
            prompt = f.read().strip()
        return {
            "image": image,
            "prompt": prompt,
            "stem": stem,
        }


def collate_fn(batch):
    """Keep images as a list (PIL images can't be stacked into a tensor)."""
    return {
        "images": [item["image"] for item in batch],
        "prompts": [item["prompt"] for item in batch],
        "stems": [item["stem"] for item in batch],
    }


def parse_args():
    parser = argparse.ArgumentParser(description="Eval Cosmos Predict 2.5 with optional LoRA weights.")

    parser.add_argument("--data_dir", type=str, required=True, help="Directory with image/prompt pairs.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save generated outputs.")
    parser.add_argument(
        "--model_id", type=str, default="nvidia/Cosmos-Predict2.5-2B", help="Hugging Face model repository."
    )
    parser.add_argument(
        "--revision",
        type=str,
        default="diffusers/base/post-trained",
        choices=["diffusers/base/post-trained", "diffusers/base/pre-trained"],
    )
    parser.add_argument("--lora_dir", type=str, default=None, help="Path to LoRA weights directory.")
    parser.add_argument("--num_output_frames", type=int, default=93, help="1 for image output, 93 for video output.")
    parser.add_argument("--num_steps", type=int, default=36, help="Number of inference steps.")
    parser.add_argument("--height", type=int, default=704, help="Output height in pixels (must be divisible by 16).")
    parser.add_argument("--width", type=int, default=1280, help="Output width in pixels (must be divisible by 16).")
    parser.add_argument("--seed", type=int, default=0, help="Random seed.")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use.")
    parser.add_argument("--batch_size", type=int, default=1, help="Number of samples per batch.")
    parser.add_argument("--num_workers", type=int, default=4, help="DataLoader worker processes.")
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="Negative prompt. Defaults to the pipeline's built-in negative prompt.",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)  # apply the --seed flag (otherwise unused)

    dataset = ImageDataset(args.data_dir)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    print(f"Found {len(dataset)} examples.")

    class MockSafetyChecker:
        def to(self, *args, **kwargs):
            return self

        def check_text_safety(self, *args, **kwargs):
            return True

        def check_video_safety(self, video):
            return video

    pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
        args.model_id,
        revision=args.revision,
        device_map=args.device,
        torch_dtype=torch.bfloat16,
        safety_checker=MockSafetyChecker(),
    )

    if args.lora_dir is not None:
        pipe.load_lora_weights(args.lora_dir)
        pipe.fuse_lora(lora_scale=1.0)
        print(f"Loaded LoRA weights from {args.lora_dir}")

    progress = tqdm(total=len(dataset), desc="Generating")
    for batch in dataloader:
        images = batch["images"]
        prompts = batch["prompts"]
        stems = batch["stems"]

        for image, prompt, stem in zip(images, prompts, stems):
            frames = pipe(
                image=image,
                prompt=prompt,
                negative_prompt=args.negative_prompt,
                num_frames=args.num_output_frames,
                num_inference_steps=args.num_steps,
                height=args.height,
                width=args.width,
            ).frames[0]  # NOTE: batch_size == 1

            out_path = os.path.join(args.output_dir, f"{stem}.mp4")
            export_to_video(frames, out_path, fps=16)

            tqdm.write(f"  Saved to: {out_path}")
            progress.update(1)


if __name__ == "__main__":
    main()
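To compare against the base model in the same process after an eval run, the adapter can be unfused and unloaded again; a minimal sketch using the standard diffusers LoRA API:

```python
# Undo the in-place fuse, then drop the adapter weights entirely.
pipe.unfuse_lora()
pipe.unload_lora_weights()
```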

examples/cosmos/eval_lora.sh

Lines changed: 15 additions & 0 deletions
export DATA_DIR="gr1_dataset/test"
export LORA_DIR=YOUR_ADAPTER_DIR
export OUT_DIR=YOUR_EVAL_OUTPUT_DIR
revision="post-trained"

export TOKENIZERS_PARALLELISM=false
python eval_cosmos_predict25_lora.py \
  --data_dir $DATA_DIR \
  --output_dir $OUT_DIR \
  --lora_dir $LORA_DIR \
  --revision diffusers/base/$revision \
  --height 432 --width 768 \
  --num_output_frames 93 \
  --num_steps 36 \
  --seed 0
