Commit e377c0a
Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py (#13895)
When training with `--mixed_precision="fp16"` and `--validation_prompt`,
the first optimizer step after a validation run fails with
`ValueError: Attempting to unscale FP16 gradients`.
Under fp16, `cast_training_params` keeps the trainable LoRA params in
fp32. The in-loop validation pipeline is built with the same live `unet`
object, and `log_validation` then calls `pipeline.to(device, dtype=torch_dtype)`,
which downcasts those fp32 LoRA params back to fp16. The next backward
therefore produces fp16 grads and `GradScaler.unscale_` raises.
Drop the dtype cast from that `.to(...)` so the shared `unet` keeps its
fp32 LoRA params. This matches train_dreambooth_lora_sdxl.py, which moves
the validation pipeline with `.to(accelerator.device)` only.
Fixes #13124
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>1 parent 3759fab commit e377c0a
1 file changed
Lines changed: 5 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
147 | 147 | | |
148 | 148 | | |
149 | 149 | | |
150 | | - | |
| 150 | + | |
| 151 | + | |
| 152 | + | |
| 153 | + | |
| 154 | + | |
151 | 155 | | |
152 | 156 | | |
153 | 157 | | |
| |||
0 commit comments