Skip to content

Commit 0f5d5ff

Browse files
committed
fix(dreambooth): bump diffusers, fixes fp16 mixed precision training
1 parent aebcf65 commit 0f5d5ff

2 files changed

Lines changed: 24 additions & 22 deletions

File tree

Dockerfile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ WORKDIR /api
3737
ADD requirements.txt requirements.txt
3838
RUN pip install -r requirements.txt
3939

40-
# 2023-01-25 Release: v0.12.0
41-
RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout 180841bbde4b200be43350164eef80c93a68983a
40+
# [dreambooth] check the low-precision guard before preparing model (#2102)
41+
RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout 946d1cb200a875f694818be37c9c9f7547e9db45
4242
WORKDIR /api
4343
RUN pip install -e diffusers
4444

api/train_dreambooth.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Based on https://github.com/huggingface/diffusers/blob/180841bbde4b200be43350164eef80c93a68983a/examples/dreambooth/train_dreambooth.py
1+
# Based on https://github.com/huggingface/diffusers/blob/946d1cb200a875f694818be37c9c9f7547e9db45/examples/dreambooth/train_dreambooth.py
22

33
# Reasons for not using that file directly:
44
#
@@ -127,7 +127,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs):
127127
# The integration to report the results and logs to. Supported platforms are `"tensorboard"`
128128
# (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.
129129
"report_to": "tensorboard",
130-
"mixed_precision": None, # DDA, was: None XXX fp16
130+
"mixed_precision": "fp16", # DDA, was: None
131131
# Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=
132132
# 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32.
133133
"prior_generation_precision": None, # "no", "fp32", "fp16", "bf16"
@@ -582,6 +582,26 @@ def main(args, init_pipeline):
582582
if args.train_text_encoder:
583583
text_encoder.gradient_checkpointing_enable()
584584

585+
# Check that all trainable models are in full precision
586+
low_precision_error_string = (
587+
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
588+
" doing mixed precision training. copy of the weights should still be float32."
589+
)
590+
591+
if accelerator.unwrap_model(unet).dtype != torch.float32:
592+
raise ValueError(
593+
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
594+
)
595+
596+
if (
597+
args.train_text_encoder
598+
and accelerator.unwrap_model(text_encoder).dtype != torch.float32
599+
):
600+
raise ValueError(
601+
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
602+
f" {low_precision_error_string}"
603+
)
604+
585605
# Enable TF32 for faster training on Ampere GPUs,
586606
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
587607
if args.allow_tf32:
@@ -688,24 +708,6 @@ def main(args, init_pipeline):
688708
if not args.train_text_encoder:
689709
text_encoder.to(accelerator.device, dtype=weight_dtype)
690710

691-
low_precision_error_string = (
692-
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
693-
" doing mixed precision training. copy of the weights should still be float32."
694-
)
695-
696-
if accelerator.unwrap_model(unet).dtype != torch.float32:
697-
raise ValueError(
698-
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
699-
)
700-
if (
701-
args.train_text_encoder
702-
and accelerator.unwrap_model(text_encoder).dtype != torch.float32
703-
):
704-
raise ValueError(
705-
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
706-
f" {low_precision_error_string}"
707-
)
708-
709711
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
710712
num_update_steps_per_epoch = math.ceil(
711713
len(train_dataloader) / args.gradient_accumulation_steps

0 commit comments

Comments (0)