|
1 | | -# Based on https://github.com/huggingface/diffusers/blob/180841bbde4b200be43350164eef80c93a68983a/examples/dreambooth/train_dreambooth.py |
| 1 | +# Based on https://github.com/huggingface/diffusers/blob/946d1cb200a875f694818be37c9c9f7547e9db45/examples/dreambooth/train_dreambooth.py |
2 | 2 |
|
3 | 3 | # Reasons for not using that file directly: |
4 | 4 | # |
@@ -127,7 +127,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs): |
127 | 127 | # The integration to report the results and logs to. Supported platforms are `"tensorboard"` |
128 | 128 | # (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations. |
129 | 129 | "report_to": "tensorboard", |
130 | | - "mixed_precision": None, # DDA, was: None XXX fp16 |
| 130 | + "mixed_precision": "fp16", # DDA, was: None |
131 | 131 | # Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= |
132 | 132 | # 1.10 and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32. |
133 | 133 | "prior_generation_precision": None, # "no", "fp32", "fp16", "bf16" |
@@ -582,6 +582,26 @@ def main(args, init_pipeline): |
582 | 582 | if args.train_text_encoder: |
583 | 583 | text_encoder.gradient_checkpointing_enable() |
584 | 584 |
|
| 585 | + # Check that all trainable models are in full precision |
| 586 | + low_precision_error_string = ( |
| 587 | + "Please make sure to always have all model weights in full float32 precision when starting training - even if" |
| 588 | +        " doing mixed precision training. A copy of the weights should still be float32." |
| 589 | + ) |
| 590 | + |
| 591 | + if accelerator.unwrap_model(unet).dtype != torch.float32: |
| 592 | + raise ValueError( |
| 593 | + f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" |
| 594 | + ) |
| 595 | + |
| 596 | + if ( |
| 597 | + args.train_text_encoder |
| 598 | + and accelerator.unwrap_model(text_encoder).dtype != torch.float32 |
| 599 | + ): |
| 600 | + raise ValueError( |
| 601 | + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." |
| 602 | + f" {low_precision_error_string}" |
| 603 | + ) |
| 604 | + |
585 | 605 | # Enable TF32 for faster training on Ampere GPUs, |
586 | 606 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices |
587 | 607 | if args.allow_tf32: |
@@ -688,24 +708,6 @@ def main(args, init_pipeline): |
688 | 708 | if not args.train_text_encoder: |
689 | 709 | text_encoder.to(accelerator.device, dtype=weight_dtype) |
690 | 710 |
|
691 | | - low_precision_error_string = ( |
692 | | - "Please make sure to always have all model weights in full float32 precision when starting training - even if" |
693 | | - " doing mixed precision training. copy of the weights should still be float32." |
694 | | - ) |
695 | | - |
696 | | - if accelerator.unwrap_model(unet).dtype != torch.float32: |
697 | | - raise ValueError( |
698 | | - f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" |
699 | | - ) |
700 | | - if ( |
701 | | - args.train_text_encoder |
702 | | - and accelerator.unwrap_model(text_encoder).dtype != torch.float32 |
703 | | - ): |
704 | | - raise ValueError( |
705 | | - f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." |
706 | | - f" {low_precision_error_string}" |
707 | | - ) |
708 | | - |
709 | 711 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
710 | 712 | num_update_steps_per_epoch = math.ceil( |
711 | 713 | len(train_dataloader) / args.gradient_accumulation_steps |
|
0 commit comments