Skip to content

Commit 04c6304

Browse files
committed
Scope stable bucket ordering to cached DreamBooth batches
1 parent e31704f commit 04c6304

5 files changed

Lines changed: 90 additions & 40 deletions

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -974,7 +974,9 @@ def collate_fn(examples, with_prior_preservation=False):
974974

975975

976976
class BucketBatchSampler(BatchSampler):
977-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
977+
def __init__(
978+
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
979+
):
978980
if not isinstance(batch_size, int) or batch_size <= 0:
979981
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
980982
if not isinstance(drop_last, bool):
@@ -983,6 +985,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
983985
self.dataset = dataset
984986
self.batch_size = batch_size
985987
self.drop_last = drop_last
988+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
986989

987990
# Group indices by bucket
988991
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -1004,12 +1007,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
10041007
self.batches.append(batch)
10051008
self.sampler_len += 1 # Count the number of batches
10061009

1007-
# Shuffle the precomputed batches once to mix buckets while keeping
1008-
# the order stable across epochs for step-indexed caches.
1009-
random.shuffle(self.batches)
1010+
if not self.shuffle_batches_each_epoch:
1011+
# Shuffle the precomputed batches once to mix buckets while keeping
1012+
# the order stable across epochs for step-indexed caches.
1013+
random.shuffle(self.batches)
10101014

10111015
def __iter__(self):
1012-
# Keep the precomputed batch order stable so step-indexed caches stay aligned.
1016+
if self.shuffle_batches_each_epoch:
1017+
random.shuffle(self.batches)
10131018
for batch in self.batches:
10141019
yield batch
10151020

@@ -1465,7 +1470,13 @@ def load_model_hook(models, input_dir):
14651470
center_crop=args.center_crop,
14661471
buckets=buckets,
14671472
)
1468-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1473+
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1474+
batch_sampler = BucketBatchSampler(
1475+
train_dataset,
1476+
batch_size=args.train_batch_size,
1477+
drop_last=True,
1478+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1479+
)
14691480
train_dataloader = torch.utils.data.DataLoader(
14701481
train_dataset,
14711482
batch_sampler=batch_sampler,
@@ -1582,8 +1593,7 @@ def _encode_single(prompt: str):
15821593
# if cache_latents is set to True, we encode images to latents and store them.
15831594
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15841595
# we encode them in advance as well.
1585-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1586-
if precompute_latents:
1596+
if has_step_indexed_caches:
15871597
prompt_embeds_cache = []
15881598
text_ids_cache = []
15891599
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -972,7 +972,9 @@ def collate_fn(examples):
972972

973973

974974
class BucketBatchSampler(BatchSampler):
975-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
975+
def __init__(
976+
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
977+
):
976978
if not isinstance(batch_size, int) or batch_size <= 0:
977979
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
978980
if not isinstance(drop_last, bool):
@@ -981,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
981983
self.dataset = dataset
982984
self.batch_size = batch_size
983985
self.drop_last = drop_last
986+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
984987

985988
# Group indices by bucket
986989
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -1002,12 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
10021005
self.batches.append(batch)
10031006
self.sampler_len += 1 # Count the number of batches
10041007

1005-
# Shuffle the precomputed batches once to mix buckets while keeping
1006-
# the order stable across epochs for step-indexed caches.
1007-
random.shuffle(self.batches)
1008+
if not self.shuffle_batches_each_epoch:
1009+
# Shuffle the precomputed batches once to mix buckets while keeping
1010+
# the order stable across epochs for step-indexed caches.
1011+
random.shuffle(self.batches)
10081012

10091013
def __iter__(self):
1010-
# Keep the precomputed batch order stable so step-indexed caches stay aligned.
1014+
if self.shuffle_batches_each_epoch:
1015+
random.shuffle(self.batches)
10111016
for batch in self.batches:
10121017
yield batch
10131018

@@ -1412,7 +1417,13 @@ def load_model_hook(models, input_dir):
14121417
center_crop=args.center_crop,
14131418
buckets=buckets,
14141419
)
1415-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1420+
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1421+
batch_sampler = BucketBatchSampler(
1422+
train_dataset,
1423+
batch_size=args.train_batch_size,
1424+
drop_last=True,
1425+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1426+
)
14161427
train_dataloader = torch.utils.data.DataLoader(
14171428
train_dataset,
14181429
batch_sampler=batch_sampler,
@@ -1515,8 +1526,7 @@ def _encode_single(prompt: str):
15151526
# if cache_latents is set to True, we encode images to latents and store them.
15161527
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15171528
# we encode them in advance as well.
1518-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1519-
if precompute_latents:
1529+
if has_step_indexed_caches:
15201530
prompt_embeds_cache = []
15211531
text_ids_cache = []
15221532
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,9 @@ def collate_fn(examples, with_prior_preservation=False):
969969

970970

971971
class BucketBatchSampler(BatchSampler):
972-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
972+
def __init__(
973+
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
974+
):
973975
if not isinstance(batch_size, int) or batch_size <= 0:
974976
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
975977
if not isinstance(drop_last, bool):
@@ -978,6 +980,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
978980
self.dataset = dataset
979981
self.batch_size = batch_size
980982
self.drop_last = drop_last
983+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
981984

982985
# Group indices by bucket
983986
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -999,12 +1002,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9991002
self.batches.append(batch)
10001003
self.sampler_len += 1 # Count the number of batches
10011004

1002-
# Shuffle the precomputed batches once to mix buckets while keeping
1003-
# the order stable across epochs for step-indexed caches.
1004-
random.shuffle(self.batches)
1005+
if not self.shuffle_batches_each_epoch:
1006+
# Shuffle the precomputed batches once to mix buckets while keeping
1007+
# the order stable across epochs for step-indexed caches.
1008+
random.shuffle(self.batches)
10051009

10061010
def __iter__(self):
1007-
# Keep the precomputed batch order stable so step-indexed caches stay aligned.
1011+
if self.shuffle_batches_each_epoch:
1012+
random.shuffle(self.batches)
10081013
for batch in self.batches:
10091014
yield batch
10101015

@@ -1458,7 +1463,13 @@ def load_model_hook(models, input_dir):
14581463
center_crop=args.center_crop,
14591464
buckets=buckets,
14601465
)
1461-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1466+
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1467+
batch_sampler = BucketBatchSampler(
1468+
train_dataset,
1469+
batch_size=args.train_batch_size,
1470+
drop_last=True,
1471+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1472+
)
14621473
train_dataloader = torch.utils.data.DataLoader(
14631474
train_dataset,
14641475
batch_sampler=batch_sampler,
@@ -1525,8 +1536,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15251536
# if cache_latents is set to True, we encode images to latents and store them.
15261537
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15271538
# we encode them in advance as well.
1528-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1529-
if precompute_latents:
1539+
if has_step_indexed_caches:
15301540
prompt_embeds_cache = []
15311541
text_ids_cache = []
15321542
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,9 @@ def collate_fn(examples):
968968

969969

970970
class BucketBatchSampler(BatchSampler):
971-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
971+
def __init__(
972+
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
973+
):
972974
if not isinstance(batch_size, int) or batch_size <= 0:
973975
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
974976
if not isinstance(drop_last, bool):
@@ -977,6 +979,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
977979
self.dataset = dataset
978980
self.batch_size = batch_size
979981
self.drop_last = drop_last
982+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
980983

981984
# Group indices by bucket
982985
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -998,12 +1001,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9981001
self.batches.append(batch)
9991002
self.sampler_len += 1 # Count the number of batches
10001003

1001-
# Shuffle the precomputed batches once to mix buckets while keeping
1002-
# the order stable across epochs for step-indexed caches.
1003-
random.shuffle(self.batches)
1004+
if not self.shuffle_batches_each_epoch:
1005+
# Shuffle the precomputed batches once to mix buckets while keeping
1006+
# the order stable across epochs for step-indexed caches.
1007+
random.shuffle(self.batches)
10041008

10051009
def __iter__(self):
1006-
# Keep the precomputed batch order stable so step-indexed caches stay aligned.
1010+
if self.shuffle_batches_each_epoch:
1011+
random.shuffle(self.batches)
10071012
for batch in self.batches:
10081013
yield batch
10091014

@@ -1406,7 +1411,13 @@ def load_model_hook(models, input_dir):
14061411
center_crop=args.center_crop,
14071412
buckets=buckets,
14081413
)
1409-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1414+
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1415+
batch_sampler = BucketBatchSampler(
1416+
train_dataset,
1417+
batch_size=args.train_batch_size,
1418+
drop_last=True,
1419+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1420+
)
14101421
train_dataloader = torch.utils.data.DataLoader(
14111422
train_dataset,
14121423
batch_sampler=batch_sampler,
@@ -1466,8 +1477,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
14661477
# if cache_latents is set to True, we encode images to latents and store them.
14671478
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
14681479
# we encode them in advance as well.
1469-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1470-
if precompute_latents:
1480+
if has_step_indexed_caches:
14711481
prompt_embeds_cache = []
14721482
text_ids_cache = []
14731483
latents_cache = []

examples/dreambooth/train_dreambooth_lora_z_image.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,9 @@ def collate_fn(examples, with_prior_preservation=False):
963963

964964

965965
class BucketBatchSampler(BatchSampler):
966-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
966+
def __init__(
967+
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
968+
):
967969
if not isinstance(batch_size, int) or batch_size <= 0:
968970
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
969971
if not isinstance(drop_last, bool):
@@ -972,6 +974,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
972974
self.dataset = dataset
973975
self.batch_size = batch_size
974976
self.drop_last = drop_last
977+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
975978

976979
# Group indices by bucket
977980
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -993,12 +996,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
993996
self.batches.append(batch)
994997
self.sampler_len += 1 # Count the number of batches
995998

996-
# Shuffle the precomputed batches once to mix buckets while keeping
997-
# the order stable across epochs for step-indexed caches.
998-
random.shuffle(self.batches)
999+
if not self.shuffle_batches_each_epoch:
1000+
# Shuffle the precomputed batches once to mix buckets while keeping
1001+
# the order stable across epochs for step-indexed caches.
1002+
random.shuffle(self.batches)
9991003

10001004
def __iter__(self):
1001-
# Keep the precomputed batch order stable so step-indexed caches stay aligned.
1005+
if self.shuffle_batches_each_epoch:
1006+
random.shuffle(self.batches)
10021007
for batch in self.batches:
10031008
yield batch
10041009

@@ -1452,7 +1457,13 @@ def load_model_hook(models, input_dir):
14521457
center_crop=args.center_crop,
14531458
buckets=buckets,
14541459
)
1455-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1460+
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1461+
batch_sampler = BucketBatchSampler(
1462+
train_dataset,
1463+
batch_size=args.train_batch_size,
1464+
drop_last=True,
1465+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1466+
)
14561467
train_dataloader = torch.utils.data.DataLoader(
14571468
train_dataset,
14581469
batch_sampler=batch_sampler,
@@ -1512,8 +1523,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15121523
# if cache_latents is set to True, we encode images to latents and store them.
15131524
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15141525
# we encode them in advance as well.
1515-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1516-
if precompute_latents:
1526+
if has_step_indexed_caches:
15171527
prompt_embeds_cache = []
15181528
latents_cache = []
15191529
for batch in tqdm(train_dataloader, desc="Caching latents"):

0 commit comments

Comments
 (0)