Skip to content

Commit 6cf74c0

Browse files
committed
Allow bucket reshuffling with DreamBooth caches
1 parent 48f39c2 commit 6cf74c0

2 files changed

Lines changed: 157 additions & 101 deletions

File tree

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,7 @@ def __len__(self):
905905
def __getitem__(self, index):
906906
example = {}
907907
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
908+
example["index"] = index
908909
example["instance_images"] = instance_image
909910
example["bucket_idx"] = bucket_idx
910911
if self.custom_instance_prompts:
@@ -957,7 +958,9 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip
957958

958959

959960
def collate_fn(examples, with_prior_preservation=False):
961+
indices = [example["index"] for example in examples]
960962
pixel_values = [example["instance_images"] for example in examples]
963+
instance_prompts = [example["instance_prompt"] for example in examples]
961964
prompts = [example["instance_prompt"] for example in examples]
962965

963966
# Concat class and instance examples for prior preservation.
@@ -969,18 +972,17 @@ def collate_fn(examples, with_prior_preservation=False):
969972
pixel_values = torch.stack(pixel_values)
970973
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
971974

972-
batch = {"pixel_values": pixel_values, "prompts": prompts}
975+
batch = {
976+
"indices": indices,
977+
"pixel_values": pixel_values,
978+
"instance_prompts": instance_prompts,
979+
"prompts": prompts,
980+
}
973981
return batch
974982

975983

976984
class BucketBatchSampler(BatchSampler):
977-
def __init__(
978-
self,
979-
dataset: DreamBoothDataset,
980-
batch_size: int,
981-
drop_last: bool = False,
982-
shuffle_batches_each_epoch: bool = True,
983-
):
985+
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
984986
if not isinstance(batch_size, int) or batch_size <= 0:
985987
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
986988
if not isinstance(drop_last, bool):
@@ -989,37 +991,32 @@ def __init__(
989991
self.dataset = dataset
990992
self.batch_size = batch_size
991993
self.drop_last = drop_last
992-
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
993994

994995
# Group indices by bucket
995996
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
996997
for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
997998
self.bucket_indices[bucket_idx].append(idx)
998999

9991000
self.sampler_len = 0
1000-
self.batches = []
1001+
for indices_in_bucket in self.bucket_indices:
1002+
num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size)
1003+
self.sampler_len += num_batches
1004+
if remainder > 0 and not self.drop_last:
1005+
self.sampler_len += 1
10011006

1002-
# Pre-generate batches for each bucket
1007+
def __iter__(self):
1008+
batches = []
10031009
for indices_in_bucket in self.bucket_indices:
1004-
# Shuffle indices within the bucket
1005-
random.shuffle(indices_in_bucket)
1006-
# Create batches
1007-
for i in range(0, len(indices_in_bucket), self.batch_size):
1008-
batch = indices_in_bucket[i : i + self.batch_size]
1010+
shuffled_indices = indices_in_bucket.copy()
1011+
random.shuffle(shuffled_indices)
1012+
for i in range(0, len(shuffled_indices), self.batch_size):
1013+
batch = shuffled_indices[i : i + self.batch_size]
10091014
if len(batch) < self.batch_size and self.drop_last:
1010-
continue # Skip partial batch if drop_last is True
1011-
self.batches.append(batch)
1012-
self.sampler_len += 1 # Count the number of batches
1015+
continue
1016+
batches.append(batch)
10131017

1014-
if not self.shuffle_batches_each_epoch:
1015-
# Shuffle the precomputed batches once to mix buckets while keeping
1016-
# the order stable across epochs for step-indexed caches.
1017-
random.shuffle(self.batches)
1018-
1019-
def __iter__(self):
1020-
if self.shuffle_batches_each_epoch:
1021-
random.shuffle(self.batches)
1022-
for batch in self.batches:
1018+
random.shuffle(batches)
1019+
for batch in batches:
10231020
yield batch
10241021

10251022
def __len__(self):
@@ -1480,13 +1477,8 @@ def load_model_hook(models, input_dir):
14801477
center_crop=args.center_crop,
14811478
buckets=buckets,
14821479
)
1483-
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1484-
batch_sampler = BucketBatchSampler(
1485-
train_dataset,
1486-
batch_size=args.train_batch_size,
1487-
drop_last=True,
1488-
shuffle_batches_each_epoch=not has_step_indexed_caches,
1489-
)
1480+
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1481+
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
14901482
train_dataloader = torch.utils.data.DataLoader(
14911483
train_dataset,
14921484
batch_sampler=batch_sampler,
@@ -1599,32 +1591,58 @@ def _encode_single(prompt: str):
15991591
if args.with_prior_preservation:
16001592
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
16011593
text_ids = torch.cat([text_ids, class_text_ids], dim=0)
1594+
static_prompt_embeds = prompt_embeds
1595+
static_text_ids = text_ids
16021596

16031597
# if cache_latents is set to True, we encode images to latents and store them.
16041598
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
16051599
# we encode them in advance as well.
1600+
if args.cache_latents:
1601+
instance_latents_cache = [None] * train_dataset.num_instance_images
1602+
class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None
1603+
if train_dataset.custom_instance_prompts:
1604+
prompt_embeds_cache = [None] * train_dataset.num_instance_images
1605+
text_ids_cache = [None] * train_dataset.num_instance_images
16061606
if precompute_latents:
1607-
prompt_embeds_cache = []
1608-
text_ids_cache = []
1609-
latents_cache = []
1610-
for batch in tqdm(train_dataloader, desc="Caching latents"):
1607+
cache_batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
1608+
cache_dataloader = torch.utils.data.DataLoader(
1609+
train_dataset,
1610+
batch_sampler=cache_batch_sampler,
1611+
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1612+
num_workers=args.dataloader_num_workers,
1613+
)
1614+
for batch in tqdm(cache_dataloader, desc="Caching latents"):
16111615
with torch.no_grad():
1616+
indices = batch["indices"]
16121617
if args.cache_latents:
16131618
with offload_models(vae, device=accelerator.device, offload=args.offload):
16141619
batch["pixel_values"] = batch["pixel_values"].to(
16151620
accelerator.device, non_blocking=True, dtype=vae.dtype
16161621
)
1617-
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
1622+
latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
1623+
if args.with_prior_preservation:
1624+
instance_latents, class_latents = torch.chunk(latents, 2, dim=0)
1625+
else:
1626+
instance_latents = latents
1627+
for i, idx in enumerate(indices):
1628+
instance_latents_cache[idx] = instance_latents[i : i + 1]
1629+
if args.with_prior_preservation:
1630+
class_latents_cache[idx] = class_latents[i : i + 1]
16181631
if train_dataset.custom_instance_prompts:
16191632
if args.remote_text_encoder:
1620-
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
1633+
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"])
16211634
elif args.fsdp_text_encoder:
1622-
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
1635+
prompt_embeds, text_ids = compute_text_embeddings(
1636+
batch["instance_prompts"], text_encoding_pipeline
1637+
)
16231638
else:
16241639
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
1625-
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
1626-
prompt_embeds_cache.append(prompt_embeds)
1627-
text_ids_cache.append(text_ids)
1640+
prompt_embeds, text_ids = compute_text_embeddings(
1641+
batch["instance_prompts"], text_encoding_pipeline
1642+
)
1643+
for i, idx in enumerate(indices):
1644+
prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
1645+
text_ids_cache[idx] = text_ids[i : i + 1]
16281646

16291647
# move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
16301648
if args.cache_latents:
@@ -1748,25 +1766,35 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
17481766
for epoch in range(first_epoch, args.num_train_epochs):
17491767
transformer.train()
17501768

1751-
for step, batch in enumerate(train_dataloader):
1769+
for batch in train_dataloader:
17521770
models_to_accumulate = [transformer]
1771+
indices = batch["indices"]
17531772
prompts = batch["prompts"]
17541773

17551774
with accelerator.accumulate(models_to_accumulate):
17561775
if train_dataset.custom_instance_prompts:
1757-
prompt_embeds = prompt_embeds_cache[step]
1758-
text_ids = text_ids_cache[step]
1776+
prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in indices], dim=0)
1777+
text_ids = torch.cat([text_ids_cache[idx] for idx in indices], dim=0)
1778+
if args.with_prior_preservation:
1779+
prompt_embeds = torch.cat(
1780+
[prompt_embeds, class_prompt_hidden_states.repeat(len(indices), 1, 1)], dim=0
1781+
)
1782+
text_ids = torch.cat([text_ids, class_text_ids.repeat(len(indices), 1, 1)], dim=0)
17591783
else:
17601784
# With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
17611785
# while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
17621786
# dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
17631787
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
1764-
prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1765-
text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
1788+
prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
1789+
text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0)
17661790

17671791
# Convert images to latent space
17681792
if args.cache_latents:
1769-
model_input = latents_cache[step].mode()
1793+
model_input = torch.cat([instance_latents_cache[idx] for idx in indices], dim=0)
1794+
if args.with_prior_preservation:
1795+
model_input = torch.cat(
1796+
[model_input, torch.cat([class_latents_cache[idx] for idx in indices], dim=0)], dim=0
1797+
)
17701798
else:
17711799
with offload_models(vae, device=accelerator.device, offload=args.offload):
17721800
pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)

0 commit comments

Comments (0)