Commit e377c0a

Fix fp16 LoRA unscale crash after validation in train_dreambooth_lora.py (#13895)
When training with `--mixed_precision="fp16"` and `--validation_prompt`, the first optimizer step after a validation run fails with `ValueError: Attempting to unscale FP16 gradients`.

Under fp16, `cast_training_params` keeps the trainable LoRA params in fp32. The in-loop validation pipeline is built with the same live `unet` object, and `log_validation` then calls `pipeline.to(device, dtype=torch_dtype)`, which downcasts those fp32 LoRA params back to fp16. The next backward therefore produces fp16 grads and `GradScaler.unscale_` raises.

Drop the dtype cast from that `.to(...)` so the shared `unet` keeps its fp32 LoRA params. This matches train_dreambooth_lora_sdxl.py, which moves the validation pipeline with `.to(accelerator.device)` only.

Fixes #13124

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
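The in-place downcast described in the commit message can be reproduced with plain PyTorch modules. The sketch below is only an illustration: `TinyLoRALayer` is a hypothetical stand-in for the shared `unet`, the manual `.to(dtype=torch.float32)` on the `lora` submodule stands in for what `cast_training_params` does, and it assumes the pipeline's `.to(...)` forwards the device and dtype to each module's own `.to(...)`.

```python
import torch
from torch import nn


class TinyLoRALayer(nn.Module):
    """Hypothetical stand-in for the shared `unet` (not part of the script)."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4, bias=False)  # frozen base weight
        self.lora = nn.Linear(4, 4, bias=False)  # trainable adapter weight
        self.base.requires_grad_(False)


def trainable_dtypes(module):
    """Collect the dtypes of all parameters that will receive gradients."""
    return {p.dtype for p in module.parameters() if p.requires_grad}


# Whole model in fp16, then only the trainable params upcast to fp32.
layer = TinyLoRALayer().to(dtype=torch.float16)
layer.lora.to(dtype=torch.float32)
before = trainable_dtypes(layer)

# Device-only move, as in the fixed call: dtypes are left alone.
layer.to(torch.device("cpu"))
after_device_move = trainable_dtypes(layer)

# Device + dtype move, as in the old call: nn.Module.to mutates the module
# in place, so the same live object now holds fp16 trainable params.
layer.to(torch.device("cpu"), dtype=torch.float16)
after_dtype_cast = trainable_dtypes(layer)

print(before, after_device_move, after_dtype_cast)
```

Because `nn.Module.to` modifies the module in place, any other holder of the same object (here, the training loop) sees the downcast parameters on its next step.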
1 parent 3759fab commit e377c0a

1 file changed

Lines changed: 5 additions & 1 deletion

examples/dreambooth/train_dreambooth_lora.py

@@ -147,7 +147,11 @@ def log_validation(
 
     pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
 
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    # Don't pass `dtype` here: under fp16 the trainable LoRA params are kept in fp32 (see
+    # `cast_training_params` above) and the validation pipeline shares the training `unet`, so casting it
+    # to fp16 would break the next optimizer step ("Attempting to unscale FP16 gradients"). Matches the
+    # SDXL script.
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
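
One way to guard against a regression of this kind is to check the dtypes of the trainable parameters after validation returns. The helper below is a minimal sketch and not part of this commit: `find_non_fp32_trainable` is a hypothetical name, and the `nn.Linear` model is only a placeholder for the training `unet`.

```python
from typing import List

import torch
from torch import nn


def find_non_fp32_trainable(model: nn.Module) -> List[str]:
    """Return the names of trainable parameters that are not float32.

    Hypothetical helper: an empty list means every parameter that will
    receive gradients is still in fp32.
    """
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.dtype != torch.float32
    ]


# Placeholder model standing in for the training `unet`.
model = nn.Linear(2, 2)
clean = find_non_fp32_trainable(model)

# Simulate the old behaviour: a dtype cast applied to the shared module.
model.to(dtype=torch.float16)
offending = find_non_fp32_trainable(model)

print(clean, offending)
```

In a training loop, a check like this could run right after the validation call and fail early with the offending parameter names, rather than surfacing later as an unscale error in the optimizer step.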
