|
1 | 1 | # Based on https://github.com/huggingface/diffusers/commits/main/examples/dreambooth/train_dreambooth.py |
2 | | -# Synced to commit c42f6ee43e0408c5fe8a1d3dc3cdeb9eb3a02fa6 on 2023-06-14 |
| 2 | +# Synced to commit b9feed87958c27074b0618cc543696c05f58e2c9 on 2023-07-12 |
3 | 3 |
|
4 | 4 | # Reasons for not using that file directly: |
5 | 5 | # |
|
29 | 29 | import logging |
30 | 30 | import math |
31 | 31 | import os |
| 32 | +import shutil |
32 | 33 | import warnings |
33 | 34 | from pathlib import Path |
34 | 35 |
|
@@ -257,7 +258,7 @@ def TrainDreamBooth(model_id: str, pipeline, model_inputs, call_inputs, send_opt |
257 | 258 | import wandb |
258 | 259 |
|
259 | 260 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. |
260 | | -check_min_version("0.17.0") |
| 261 | +check_min_version("0.19.0.dev0") |
261 | 262 |
|
262 | 263 | logger = get_logger(__name__) |
263 | 264 |
|
@@ -653,9 +654,7 @@ def main(args, init_pipeline, send_opts): |
653 | 654 | logging_dir = Path(args.output_dir, args.logging_dir) |
654 | 655 |
|
655 | 656 | accelerator_project_config = ProjectConfiguration( |
656 | | - total_limit=args.checkpoints_total_limit, |
657 | | - project_dir=args.output_dir, |
658 | | - logging_dir=logging_dir, |
| 657 | + project_dir=args.output_dir, logging_dir=logging_dir |
659 | 658 | ) |
660 | 659 |
|
661 | 660 | accelerator = Accelerator( |
@@ -1055,8 +1054,8 @@ def compute_text_embeddings(prompt): |
1055 | 1054 | unet, optimizer, train_dataloader, lr_scheduler |
1056 | 1055 | ) |
1057 | 1056 |
|
1058 | | - # For mixed precision training we cast the text_encoder and vae weights to half-precision |
1059 | | - # as these models are only used for inference, keeping weights in full precision is not required. |
| 1057 | + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
| 1058 | + # as these weights are only used for inference, keeping weights in full precision is not required. |
1060 | 1059 | weight_dtype = torch.float32 |
1061 | 1060 | if accelerator.mixed_precision == "fp16": |
1062 | 1061 | weight_dtype = torch.float16 |
@@ -1277,14 +1276,46 @@ def compute_text_embeddings(prompt): |
1277 | 1276 | global_step += 1 |
1278 | 1277 |
|
1279 | 1278 | if accelerator.is_main_process: |
1280 | | - images = [] |
1281 | 1279 | if global_step % args.checkpointing_steps == 0: |
| 1280 | + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` |
| 1281 | + if args.checkpoints_total_limit is not None: |
| 1282 | + checkpoints = os.listdir(args.output_dir) |
| 1283 | + checkpoints = [ |
| 1284 | + d for d in checkpoints if d.startswith("checkpoint") |
| 1285 | + ] |
| 1286 | + checkpoints = sorted( |
| 1287 | + checkpoints, key=lambda x: int(x.split("-")[1]) |
| 1288 | + ) |
| 1289 | + |
| 1290 | + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints |
| 1291 | + if len(checkpoints) >= args.checkpoints_total_limit: |
| 1292 | + num_to_remove = ( |
| 1293 | + len(checkpoints) - args.checkpoints_total_limit + 1 |
| 1294 | + ) |
| 1295 | + removing_checkpoints = checkpoints[0:num_to_remove] |
| 1296 | + |
| 1297 | + logger.info( |
| 1298 | + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" |
| 1299 | + ) |
| 1300 | + logger.info( |
| 1301 | + f"removing checkpoints: {', '.join(removing_checkpoints)}" |
| 1302 | + ) |
| 1303 | + |
| 1304 | + for removing_checkpoint in removing_checkpoints: |
| 1305 | + removing_checkpoint = os.path.join( |
| 1306 | + args.output_dir, removing_checkpoint |
| 1307 | + ) |
| 1308 | + shutil.rmtree(removing_checkpoint) |
| 1309 | + |
1282 | 1310 | save_path = os.path.join( |
1283 | 1311 | args.output_dir, f"checkpoint-{global_step}" |
1284 | 1312 | ) |
1285 | 1313 | pipeline.save_pretrained(save_path) |
1286 | 1314 | accelerator.save_state(save_path) |
1287 | 1315 | logger.info(f"Saved state to {save_path}") |
| 1316 | + |
| 1317 | + images = [] |
| 1318 | + |
1288 | 1319 | if ( |
1289 | 1320 | args.validation_prompt is not None |
1290 | 1321 | and global_step % args.validation_steps == 0 |
|
0 commit comments