examples/dreambooth/train_dreambooth_lora_flux2.py (131 changes: 80 additions & 51 deletions)
@@ -905,6 +905,7 @@ def __len__(self):
     def __getitem__(self, index):
         example = {}
         instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+        example["index"] = index
         example["instance_images"] = instance_image
         example["bucket_idx"] = bucket_idx
         if self.custom_instance_prompts:
@@ -957,7 +958,9 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip


 def collate_fn(examples, with_prior_preservation=False):
+    indices = [example["index"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
+    instance_prompts = [example["instance_prompt"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]

     # Concat class and instance examples for prior preservation.
@@ -969,18 +972,17 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "indices": indices,
+        "pixel_values": pixel_values,
+        "instance_prompts": instance_prompts,
+        "prompts": prompts,
+    }
     return batch


 class BucketBatchSampler(BatchSampler):
-    def __init__(
-        self,
-        dataset: DreamBoothDataset,
-        batch_size: int,
-        drop_last: bool = False,
-        shuffle_batches_each_epoch: bool = True,
-    ):
+    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
         if not isinstance(batch_size, int) or batch_size <= 0:
             raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
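The collate_fn change above is the heart of this diff: each example now carries its dataset index into the batch, so the downstream caches can be keyed per sample instead of per step. A minimal self-contained sketch of why index keying matters (toy tensors and names, not the script's own):

import torch

# Fill a per-sample cache in one visiting order, gather in another: dataset
# indices keep lookups correct under any shuffling, while a step-keyed cache
# would only match the exact batch sequence it was built with.
num_samples = 4
features = [torch.full((1, 3), float(i)) for i in range(num_samples)]

cache = [None] * num_samples
for idx in [2, 0, 3, 1]:  # arbitrary caching order
    cache[idx] = features[idx]

batch_indices = [1, 3]  # a later batch, in a different order
batch = torch.cat([cache[i] for i in batch_indices], dim=0)
assert torch.equal(batch, torch.cat([features[1], features[3]], dim=0))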
@@ -989,37 +991,32 @@ def __init__(
         self.dataset = dataset
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch

         # Group indices by bucket
         self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
         for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
             self.bucket_indices[bucket_idx].append(idx)

         self.sampler_len = 0
-        self.batches = []
+        for indices_in_bucket in self.bucket_indices:
+            num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size)
+            self.sampler_len += num_batches
+            if remainder > 0 and not self.drop_last:
+                self.sampler_len += 1

-        # Pre-generate batches for each bucket
+    def __iter__(self):
+        batches = []
         for indices_in_bucket in self.bucket_indices:
-            # Shuffle indices within the bucket
-            random.shuffle(indices_in_bucket)
-            # Create batches
-            for i in range(0, len(indices_in_bucket), self.batch_size):
-                batch = indices_in_bucket[i : i + self.batch_size]
+            shuffled_indices = indices_in_bucket.copy()
+            random.shuffle(shuffled_indices)
+            for i in range(0, len(shuffled_indices), self.batch_size):
+                batch = shuffled_indices[i : i + self.batch_size]
                 if len(batch) < self.batch_size and self.drop_last:
-                    continue  # Skip partial batch if drop_last is True
-                self.batches.append(batch)
-                self.sampler_len += 1  # Count the number of batches
+                    continue
+                batches.append(batch)

-        if not self.shuffle_batches_each_epoch:
-            # Shuffle the precomputed batches once to mix buckets while keeping
-            # the order stable across epochs for step-indexed caches.
-            random.shuffle(self.batches)
-
-    def __iter__(self):
-        if self.shuffle_batches_each_epoch:
-            random.shuffle(self.batches)
-        for batch in self.batches:
+        random.shuffle(batches)
+        for batch in batches:
             yield batch

     def __len__(self):
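The rewritten sampler drops the shuffle_batches_each_epoch knob entirely: __len__ is now computed arithmetically in __init__, and __iter__ regenerates freshly shuffled batches on every epoch, which is safe once the caches are index-keyed. A standalone sketch of the bucketing idea itself (toy bucket assignments, illustrative names only):

import random

# Batch only within a resolution bucket, then shuffle whole batches, so every
# batch can be stacked into a single uniformly shaped tensor.
bucket_of = [0, 0, 1, 1, 1, 0]  # bucket id for each of six samples
batch_size, drop_last = 2, False

buckets = {}
for idx, b in enumerate(bucket_of):
    buckets.setdefault(b, []).append(idx)

batches = []
for indices in buckets.values():
    shuffled = indices.copy()
    random.shuffle(shuffled)
    for i in range(0, len(shuffled), batch_size):
        batch = shuffled[i : i + batch_size]
        if len(batch) < batch_size and drop_last:
            continue
        batches.append(batch)

random.shuffle(batches)  # mix buckets across the epoch
for batch in batches:
    assert len({bucket_of[i] for i in batch}) == 1  # no mixed-shape batches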
@@ -1480,13 +1477,8 @@ def load_model_hook(models, input_dir):
         center_crop=args.center_crop,
         buckets=buckets,
     )
-    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
-    batch_sampler = BucketBatchSampler(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        drop_last=True,
-        shuffle_batches_each_epoch=not has_step_indexed_caches,
-    )
+    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
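Note the asymmetry this sets up: the training sampler keeps drop_last=True, while the caching pass in the next hunk rebuilds a sampler with drop_last=False. If the one-shot caching pass also dropped partial batches, some indices would never receive a cache entry and the train-time lookups would hit None. Illustrative arithmetic:

# With 10 samples and batch size 4, a drop_last=True pass visits only 8
# samples, leaving 2 cache slots unfilled; drop_last=False visits all 10.
num_samples, batch_size = 10, 4
visited = (num_samples // batch_size) * batch_size  # drop_last=True
assert visited == 8 and num_samples - visited == 2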
@@ -1599,32 +1591,58 @@ def _encode_single(prompt: str):
         if args.with_prior_preservation:
             prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
             text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+        static_prompt_embeds = prompt_embeds
+        static_text_ids = text_ids

     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
     # we encode them in advance as well.
     if args.cache_latents:
+        instance_latents_cache = [None] * train_dataset.num_instance_images
+        class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None
     if train_dataset.custom_instance_prompts:
+        prompt_embeds_cache = [None] * train_dataset.num_instance_images
+        text_ids_cache = [None] * train_dataset.num_instance_images
+    if precompute_latents:
-        prompt_embeds_cache = []
-        text_ids_cache = []
-        latents_cache = []
-        for batch in tqdm(train_dataloader, desc="Caching latents"):
+        cache_batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
+        cache_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=cache_batch_sampler,
+            collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+            num_workers=args.dataloader_num_workers,
+        )
+        for batch in tqdm(cache_dataloader, desc="Caching latents"):
             with torch.no_grad():
+                sample_indices = batch["indices"]
+                if args.cache_latents:
                     with offload_models(vae, device=accelerator.device, offload=args.offload):
                         batch["pixel_values"] = batch["pixel_values"].to(
                             accelerator.device, non_blocking=True, dtype=vae.dtype
                         )
-                    latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                        latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
+                    if args.with_prior_preservation:
+                        instance_latents, class_latents = torch.chunk(latents, 2, dim=0)
+                    else:
+                        instance_latents = latents
+                    for i, idx in enumerate(sample_indices):
+                        instance_latents_cache[idx] = instance_latents[i : i + 1]
+                        if args.with_prior_preservation:
+                            class_latents_cache[idx] = class_latents[i : i + 1]
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
-                        prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                        prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"])
                     elif args.fsdp_text_encoder:
-                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+                        prompt_embeds, text_ids = compute_text_embeddings(
+                            batch["instance_prompts"], text_encoding_pipeline
+                        )
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-                            prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
-                prompt_embeds_cache.append(prompt_embeds)
-                text_ids_cache.append(text_ids)
+                            prompt_embeds, text_ids = compute_text_embeddings(
+                                batch["instance_prompts"], text_encoding_pipeline
+                            )
+                    for i, idx in enumerate(sample_indices):
+                        prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
+                        text_ids_cache[idx] = text_ids[i : i + 1]

     # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
     if args.cache_latents:
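Within the caching loop, prior-preservation batches are split back into their instance and class halves before being stored per index, because collate_fn stacks them as [inst_1..inst_B, class_1..class_B] along dim 0. A toy version of that bookkeeping (illustrative shapes, not the script's tensors):

import torch

B, C = 2, 3
latents = torch.arange(2 * B * C, dtype=torch.float32).reshape(2 * B, C)

# First half of dim 0 holds the instance images, second half the class images.
instance_latents, class_latents = torch.chunk(latents, 2, dim=0)

instance_cache, class_cache = [None] * B, [None] * B
sample_indices = [1, 0]  # would come from batch["indices"] in the script
for i, idx in enumerate(sample_indices):
    instance_cache[idx] = instance_latents[i : i + 1]  # keep a batch dim of 1
    class_cache[idx] = class_latents[i : i + 1]

assert instance_cache[0].shape == (1, C)
assert torch.equal(instance_cache[1], latents[0:1])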
@@ -1748,25 +1766,36 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()

-        for step, batch in enumerate(train_dataloader):
+        for batch in train_dataloader:
             models_to_accumulate = [transformer]
+            sample_indices = batch["indices"]
             prompts = batch["prompts"]

             with accelerator.accumulate(models_to_accumulate):
                 if train_dataset.custom_instance_prompts:
-                    prompt_embeds = prompt_embeds_cache[step]
-                    text_ids = text_ids_cache[step]
+                    prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in sample_indices], dim=0)
+                    text_ids = torch.cat([text_ids_cache[idx] for idx in sample_indices], dim=0)
+                    if args.with_prior_preservation:
+                        prompt_embeds = torch.cat(
+                            [prompt_embeds, class_prompt_hidden_states.repeat(len(sample_indices), 1, 1)], dim=0
+                        )
+                        text_ids = torch.cat([text_ids, class_text_ids.repeat(len(sample_indices), 1, 1)], dim=0)
                 else:
                     # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
                     # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
                     # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
                     num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
-                    prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
-                    text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
+                    prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
+                    text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0)

                 # Convert images to latent space
                 if args.cache_latents:
-                    model_input = latents_cache[step].mode()
+                    model_input = torch.cat([instance_latents_cache[idx] for idx in sample_indices], dim=0)
+                    if args.with_prior_preservation:
+                        model_input = torch.cat(
+                            [model_input, torch.cat([class_latents_cache[idx] for idx in sample_indices], dim=0)],
+                            dim=0,
+                        )
                 else:
                     with offload_models(vae, device=accelerator.device, offload=args.offload):
                         pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
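At train time the step counter no longer indexes the caches; the batch's own sample indices do, and the gathered instance tensors are concatenated ahead of the class tensors to reproduce collate_fn's [instance..., class...] layout. The single-prompt path relies on repeat_interleave for the same grouping, as the comment in the hunk explains: it yields [inst, inst, ..., cls, cls, ...] rather than an interleaved [inst, cls, inst, cls, ...]. A toy check (illustrative shapes only):

import torch

B = 2
inst = torch.zeros(1, 4)  # stand-in for the pre-encoded instance prompt embed
cls = torch.ones(1, 4)    # stand-in for the pre-encoded class prompt embed
static = torch.cat([inst, cls], dim=0)

grouped = static.repeat_interleave(B, dim=0)  # [inst, inst, cls, cls]
interleaved = static.repeat(B, 1)             # [inst, cls, inst, cls]

assert torch.equal(grouped[:B], inst.expand(B, 4))
assert torch.equal(grouped[B:], cls.expand(B, 4))
assert not torch.equal(grouped, interleaved)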