| 1 | +# LoRA fine-tuning for Cosmos Predict 2.5 |
| 2 | + |
| 3 | +This example shows how to fine-tune [Cosmos Predict 2.5](https://huggingface.co/nvidia/Cosmos-Predict2.5-2B) using LoRA on a custom video dataset. |
| 4 | + |
| 5 | +## Requirements |
| 6 | + |
| 7 | +Install `diffusers` from source, then install the example-specific dependencies: |
| 8 | + |
| 9 | +```bash |
| 10 | +git clone https://github.com/huggingface/diffusers |
| 11 | +cd diffusers |
| 12 | +pip install -e ".[dev]" |
| 13 | +cd examples/cosmos |
| 14 | +pip install -r requirements.txt |
| 15 | +``` |
| 16 | + |
| 17 | +> [!NOTE] |
| 18 | +> `flash-attn` is required for the default `flash_attention_2` text encoder attention implementation and must be installed separately after PyTorch: |
| 19 | +> ```bash |
| 20 | +> pip install flash-attn --no-build-isolation |
| 21 | +> ``` |
| 22 | +> If your hardware does not support it, pass `--text_encoder_attn_implementation sdpa` to the training and eval scripts instead. |
| 23 | + |
| 24 | +## Data preparation |
| 25 | + |
| 26 | +The training script expects a dataset directory with the following layout: |
| 27 | + |
| 28 | +``` |
| 29 | +<dataset_dir>/ |
| 30 | +├── videos/ # .mp4 files |
| 31 | +└── metas/ # one .txt prompt file per video (same stem) |
| 32 | + ├── 0.txt |
| 33 | + ├── 1.txt |
| 34 | + └── ... |
| 35 | +``` |
| 36 | + |
| 37 | +### GR1 dataset (quick start) |
| 38 | + |
| 39 | +The `download_and_preprocess_datasets.sh` script downloads the GR1-100 training set and the EVAL-175 test set, then runs the preprocessing script to create the per-video prompt files. |
| 40 | + |
| 41 | +```bash |
| 42 | +bash download_and_preprocess_datasets.sh |
| 43 | +``` |
| 44 | + |
| 45 | +This produces: |
| 46 | +- `gr1_dataset/train/` — training videos + prompts |
| 47 | +- `gr1_dataset/test/` — evaluation images + prompts |
| 48 | + |
| 49 | +## Training |
| 50 | + |
| 51 | +Launch LoRA training with `accelerate`: |
| 52 | + |
| 53 | +```bash |
| 54 | +export MODEL_NAME="nvidia/Cosmos-Predict2.5-2B" |
| 55 | +export DATA_DIR="gr1_dataset/train" |
| 56 | +export OUT_DIR="lora-output" |
| 57 | + |
| 58 | +accelerate launch --mixed_precision="bf16" train_cosmos_predict25_lora.py \ |
| 59 | + --pretrained_model_name_or_path=$MODEL_NAME \ |
| 60 | + --revision diffusers/base/post-trained \ |
| 61 | + --train_data_dir=$DATA_DIR \ |
| 62 | + --output_dir=$OUT_DIR \ |
| 63 | + --train_batch_size=1 \ |
| 64 | + --num_train_epochs=500 \ |
| 65 | + --checkpointing_epochs=100 \ |
| 66 | + --seed=0 \ |
| 67 | + --height 432 --width 768 \ |
| 68 | + --allow_tf32 \ |
| 69 | + --gradient_checkpointing \ |
| 70 | + --lora_rank 32 --lora_alpha 32 \ |
| 71 | + --report_to=wandb |
| 72 | +``` |
| 73 | + |
| 74 | +Or use the provided shell script: |
| 75 | + |
| 76 | +```bash |
| 77 | +bash train_lora.sh |
| 78 | +``` |
| 79 | + |
| 80 | +## Evaluation |
| 81 | + |
| 82 | +Run inference with the trained LoRA adapter: |
| 83 | + |
| 84 | +```bash |
| 85 | +export DATA_DIR="gr1_dataset/test" |
| 86 | +export LORA_DIR="lora-output" |
| 87 | +export OUT_DIR="eval-output" |
| 88 | + |
| 89 | +python eval_cosmos_predict25_lora.py \ |
| 90 | + --data_dir $DATA_DIR \ |
| 91 | + --output_dir $OUT_DIR \ |
| 92 | + --lora_dir $LORA_DIR \ |
| 93 | + --revision diffusers/base/post-trained \ |
| 94 | + --height 432 --width 768 \ |
| 95 | + --num_output_frames 93 \ |
| 96 | + --num_steps 36 \ |
| 97 | + --seed 0 |
| 98 | +``` |
| 99 | + |
| 100 | +Or use the provided shell script: |
| 101 | + |
| 102 | +```bash |
| 103 | +bash eval_lora.sh |
| 104 | +``` |
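
The layout described under "Data preparation" (a `videos/` folder of `.mp4` files and a `metas/` folder with one `.txt` prompt per video, matched by file stem) can be sanity-checked before launching training. The sketch below is not part of the example scripts: the helper name `find_unpaired` is hypothetical, and it assumes only the directory layout stated in the README, not anything about how the training script loads data.

```python
from pathlib import Path


def find_unpaired(dataset_dir):
    """Return (video stems with no prompt, prompt stems with no video).

    Hypothetical helper: it assumes the README layout, i.e.
    <dataset_dir>/videos/*.mp4 and <dataset_dir>/metas/*.txt,
    where a video and its prompt share the same file stem.
    """
    root = Path(dataset_dir)
    video_stems = {p.stem for p in (root / "videos").glob("*.mp4")}
    prompt_stems = {p.stem for p in (root / "metas").glob("*.txt")}
    missing_prompts = sorted(video_stems - prompt_stems)
    missing_videos = sorted(prompt_stems - video_stems)
    return missing_prompts, missing_videos


if __name__ == "__main__":
    import sys

    # Default to the quick-start training directory named in the README.
    target = sys.argv[1] if len(sys.argv) > 1 else "gr1_dataset/train"
    no_prompt, no_video = find_unpaired(target)
    print(f"videos without a prompt file: {len(no_prompt)}")
    print(f"prompt files without a video: {len(no_video)}")
```

Both lists being empty only confirms that the file names pair up; it does not check that the videos decode or that the prompts are non-empty.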