Skip to content

Commit be1c322

Browse files
committed
fix(deps): bump diffusers to b9feed8, lock bitsandbytes==0.39.1
Merge upstream changes into train_dreambooth.py.
1 parent 7c175e2 commit be1c322

File tree

2 files changed

+43
-14
lines changed

2 files changed

+43
-14
lines changed

Dockerfile

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,8 @@ WORKDIR /api
3737
ADD requirements.txt requirements.txt
3838
RUN pip install -r requirements.txt
3939

40-
# [9965cb5] [Community Pipelines] Update lpw_stable_diffusion pipeline (#3197)
41-
# Above was reverted shortly afterwards for not being backwards compatible.
42-
43-
# [ce55049] Update pipeline_flax_stable_diffusion_controlnet.py (#3306)
44-
ARG DIFFUSERS_VERSION="ce5504934ac484fca39a1a5434ecfae09eabdf41"
40+
# [b9feed8] move to 0.19.0dev (#4048)
41+
ARG DIFFUSERS_VERSION="b9feed87958c27074b0618cc543696c05f58e2c9"
4542
ENV DIFFUSERS_VERSION=${DIFFUSERS_VERSION}
4643

4744
RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout ${DIFFUSERS_VERSION}
@@ -69,7 +66,8 @@ ENV USE_DREAMBOOTH=${USE_DREAMBOOTH}
6966
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \
7067
# By specifying the same torch version as conda, it won't download again.
7168
# Without this, it will upgrade torch, break xformers, make bigger image.
72-
pip install -r diffusers/examples/dreambooth/requirements.txt bitsandbytes torch==1.12.1 ; \
69+
# bitsandbytes==0.40.0.post4 had failed CUDA detection on the dreambooth test.
70+
pip install -r diffusers/examples/dreambooth/requirements.txt bitsandbytes==0.39.1 torch==1.12.1 ; \
7371
fi
7472
RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then apt-get install git-lfs ; fi
7573

api/train_dreambooth.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Based on https://github.com/huggingface/diffusers/commits/main/examples/dreambooth/train_dreambooth.py
2-
# Synced to commit c42f6ee43e0408c5fe8a1d3dc3cdeb9eb3a02fa6 on 2023-06-14
2+
# Synced to commit b9feed87958c27074b0618cc543696c05f58e2c9 on 2023-07-12
33

44
# Reasons for not using that file directly:
55
#
@@ -29,6 +29,7 @@
2929
import logging
3030
import math
3131
import os
32+
import shutil
3233
import warnings
3334
from pathlib import Path
3435

@@ -257,7 +258,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opt
257258
import wandb
258259

259260
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
260-
check_min_version("0.17.0")
261+
check_min_version("0.19.0.dev0")
261262

262263
logger = get_logger(__name__)
263264

@@ -653,9 +654,7 @@ def main(args, init_pipeline, send_opts):
653654
logging_dir = Path(args.output_dir, args.logging_dir)
654655

655656
accelerator_project_config = ProjectConfiguration(
656-
total_limit=args.checkpoints_total_limit,
657-
project_dir=args.output_dir,
658-
logging_dir=logging_dir,
657+
project_dir=args.output_dir, logging_dir=logging_dir
659658
)
660659

661660
accelerator = Accelerator(
@@ -1055,8 +1054,8 @@ def compute_text_embeddings(prompt):
10551054
unet, optimizer, train_dataloader, lr_scheduler
10561055
)
10571056

1058-
# For mixed precision training we cast the text_encoder and vae weights to half-precision
1059-
# as these models are only used for inference, keeping weights in full precision is not required.
1057+
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
1058+
# as these weights are only used for inference, keeping weights in full precision is not required.
10601059
weight_dtype = torch.float32
10611060
if accelerator.mixed_precision == "fp16":
10621061
weight_dtype = torch.float16
@@ -1277,14 +1276,46 @@ def compute_text_embeddings(prompt):
12771276
global_step += 1
12781277

12791278
if accelerator.is_main_process:
1280-
images = []
12811279
if global_step % args.checkpointing_steps == 0:
1280+
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1281+
if args.checkpoints_total_limit is not None:
1282+
checkpoints = os.listdir(args.output_dir)
1283+
checkpoints = [
1284+
d for d in checkpoints if d.startswith("checkpoint")
1285+
]
1286+
checkpoints = sorted(
1287+
checkpoints, key=lambda x: int(x.split("-")[1])
1288+
)
1289+
1290+
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1291+
if len(checkpoints) >= args.checkpoints_total_limit:
1292+
num_to_remove = (
1293+
len(checkpoints) - args.checkpoints_total_limit + 1
1294+
)
1295+
removing_checkpoints = checkpoints[0:num_to_remove]
1296+
1297+
logger.info(
1298+
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1299+
)
1300+
logger.info(
1301+
f"removing checkpoints: {', '.join(removing_checkpoints)}"
1302+
)
1303+
1304+
for removing_checkpoint in removing_checkpoints:
1305+
removing_checkpoint = os.path.join(
1306+
args.output_dir, removing_checkpoint
1307+
)
1308+
shutil.rmtree(removing_checkpoint)
1309+
12821310
save_path = os.path.join(
12831311
args.output_dir, f"checkpoint-{global_step}"
12841312
)
12851313
pipeline.save_pretrained(save_path)
12861314
accelerator.save_state(save_path)
12871315
logger.info(f"Saved state to {save_path}")
1316+
1317+
images = []
1318+
12881319
if (
12891320
args.validation_prompt is not None
12901321
and global_step % args.validation_steps == 0

0 commit comments

Comments
 (0)