Skip to content

Commit 431aecd

Browse files
committed
add readme, requirements, and make flash attn optional
1 parent 3049d45 commit 431aecd

6 files changed

Lines changed: 131 additions & 10 deletions

examples/cosmos/README.md

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# LoRA fine-tuning for Cosmos Predict 2.5
2+
3+
This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset.
4+
5+
## Requirements
6+
7+
Install the library from source and the example-specific dependencies:
8+
9+
```bash
10+
git clone https://github.com/huggingface/diffusers
11+
cd diffusers
12+
pip install -e ".[dev]"
13+
cd examples/cosmos
14+
pip install -r requirements.txt
15+
```
16+
17+
> [!NOTE]
18+
> `flash-attn` is required for the default `flash_attention_2` text encoder attention implementation and must be installed separately after PyTorch:
19+
> ```bash
20+
> pip install flash-attn --no-build-isolation
21+
> ```
22+
> If your hardware does not support it, pass `--text_encoder_attn_implementation sdpa` to the training and eval scripts instead.
23+
24+
## Data preparation
25+
26+
The training script expects a dataset directory with the following layout:
27+
28+
```
29+
<dataset_dir>/
30+
├── videos/ # .mp4 files
31+
└── metas/ # one .txt prompt file per video (same stem)
32+
├── 0.txt
33+
├── 1.txt
34+
└── ...
35+
```
36+
37+
### GR1 dataset (quick start)
38+
39+
The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files.
40+
41+
```bash
42+
bash download_and_preprocess_datasets.sh
43+
```
44+
45+
This produces:
46+
- `gr1_dataset/train/` — training videos + prompts
47+
- `gr1_dataset/test/` — evaluation images + prompts
48+
49+
## Training
50+
51+
Launch LoRA training with `accelerate`:
52+
53+
```bash
54+
export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B"
55+
export DATA_DIR="gr1_dataset/train"
56+
export OUT_DIR="lora-output"
57+
58+
accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \
59+
--pretrained_model_name_or_path=$MODEL_NAME \
60+
--revision diffusers/base/post-trained \
61+
--train_data_dir=$DATA_DIR \
62+
--output_dir=$OUT_DIR \
63+
--train_batch_size=1 \
64+
--num_train_epochs=500 \
65+
--checkpointing_epochs=100 \
66+
--seed=0 \
67+
--height 432 --width 768 \
68+
--allow_tf32 \
69+
--gradient_checkpointing \
70+
--lora_rank 32 --lora_alpha 32 \
71+
--report_to=wandb
72+
```
73+
74+
Or use the provided shell script:
75+
76+
```bash
77+
bash train_lora.sh
78+
```
79+
80+
## Evaluation
81+
82+
Run inference with the trained LoRA adapter:
83+
84+
```bash
85+
export DATA_DIR="gr1_dataset/test"
86+
export LORA_DIR="lora-output"
87+
export OUT_DIR="eval-output"
88+
89+
python eval_cosmos_predict25_lora.py \
90+
--data_dir $DATA_DIR \
91+
--output_dir $OUT_DIR \
92+
--lora_dir $LORA_DIR \
93+
--revision diffusers/base/post-trained \
94+
--height 432 --width 768 \
95+
--num_output_frames 93 \
96+
--num_steps 36 \
97+
--seed 0
98+
```
99+
100+
Or use the provided shell script:
101+
102+
```bash
103+
bash eval_lora.sh
104+
```

examples/cosmos/create_prompts_for_gr1_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from tqdm import tqdm
2020

2121
"""example command
22-
python -m scripts.create_prompts_for_gr1_dataset --dataset_path datasets/benchmark_train/gr1
22+
python create_prompts_for_gr1_dataset.py --dataset_path datasets/benchmark_train/gr1
2323
"""
2424

2525

@@ -32,13 +32,13 @@ def parse_args() -> argparse.ArgumentParser:
3232
"--prompt_prefix", type=str, default="The robot arm is performing a task. ", help="Prefix of the prompt"
3333
)
3434
parser.add_argument(
35-
"--meta_csv", type=str, default="datasets/benchmark_train/gr1/metadata.csv", help="Metadata csv file"
35+
"--meta_csv", type=str, default=None, help="Metadata csv file (defaults to <dataset_path>/metadata.csv)"
3636
)
3737
return parser.parse_args()
3838

3939

4040
def main(args) -> None:
41-
meta_csv = args.meta_csv
41+
meta_csv = args.meta_csv or os.path.join(args.dataset_path, "metadata.csv")
4242
meta_lines = open(meta_csv).readlines()[1:]
4343
meta_txt_dir = os.path.join(args.dataset_path, "metas")
4444
os.makedirs(meta_txt_dir, exist_ok=True)

examples/cosmos/download_and_preprocess_datasets.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,6 @@ mv datasets/benchmark_train/gr1 $train_dir
2020
mv dream_gen_benchmark/gr1_object $test_dir
2121
echo Download training data to $train_dir
2222
echo Download test data to $test_dir
23+
24+
# Clean up staging directories
25+
rm -rf datasets/ dream_gen_benchmark/

examples/cosmos/eval_cosmos_predict25_lora.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,13 @@ def parse_args():
102102
default=None,
103103
help="Negative prompt. Defaults to the pipeline's built-in negative prompt.",
104104
)
105+
parser.add_argument(
106+
"--text_encoder_attn_implementation",
107+
type=str,
108+
default="flash_attention_2",
109+
choices=["eager", "sdpa", "flash_attention_2"],
110+
help="The attention implementation to use for the text encoder (Qwen2.5 VL).",
111+
)
105112

106113
return parser.parse_args()
107114

@@ -137,6 +144,7 @@ def check_video_safety(self, video):
137144
device_map=args.device,
138145
torch_dtype=torch.bfloat16,
139146
safety_checker=MockSafetyChecker(),
147+
text_encoder_attn_implementation=args.text_encoder_attn_implementation,
140148
)
141149

142150
if args.lora_dir is not None:

examples/cosmos/requirements.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
accelerate>=0.31.0
2+
huggingface_hub
3+
imageio
4+
imageio-ffmpeg
5+
transformers>=4.41.2
6+
peft>=0.11.1
7+
datasets
8+
numpy
9+
tqdm
10+
sentencepiece
11+
tensorboard
12+
wandb

examples/cosmos/train_cosmos_predict25_lora.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -380,13 +380,7 @@ def _load_json_caption(self, json_path: Path) -> str:
380380
"""Load caption from JSON file with prompt type selection."""
381381
try:
382382
with open(json_path, "r") as f:
383-
content = f.read()
384-
# Handle JSON that might not have top-level object
385-
if not content.strip().startswith("{"):
386-
# Wrap in object if needed
387-
data = json.loads("{" + content + "}")
388-
else:
389-
data = json.loads(content)
383+
data = json.load(f)
390384

391385
# Get the first model's captions (e.g., "qwen3_vl_30b_a3b")
392386
model_key = next(iter(data.keys()))

0 commit comments

Comments
 (0)