@@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False):
969969
970970
971971class BucketBatchSampler (BatchSampler ):
972- def __init__ (self , dataset : DreamBoothDataset , batch_size : int , drop_last : bool = False ):
972+ def __init__ (
973+ self ,
974+ dataset : DreamBoothDataset ,
975+ batch_size : int ,
976+ drop_last : bool = False ,
977+ shuffle_batches_each_epoch : bool = True ,
978+ ):
973979 if not isinstance (batch_size , int ) or batch_size <= 0 :
974980 raise ValueError ("batch_size should be a positive integer value, but got batch_size={}" .format (batch_size ))
975981 if not isinstance (drop_last , bool ):
@@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
978984 self .dataset = dataset
979985 self .batch_size = batch_size
980986 self .drop_last = drop_last
987+ self .shuffle_batches_each_epoch = shuffle_batches_each_epoch
981988
982989 # Group indices by bucket
983990 self .bucket_indices = [[] for _ in range (len (self .dataset .buckets ))]
@@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9991006 self .batches .append (batch )
10001007 self .sampler_len += 1 # Count the number of batches
10011008
1009+ if not self .shuffle_batches_each_epoch :
1010+ # Shuffle the precomputed batches once to mix buckets while keeping
1011+ # the order stable across epochs for step-indexed caches.
1012+ random .shuffle (self .batches )
1013+
10021014 def __iter__ (self ):
1003- # Shuffle the order of the batches each epoch
1004- random .shuffle (self .batches )
1015+ if self . shuffle_batches_each_epoch :
1016+ random .shuffle (self .batches )
10051017 for batch in self .batches :
10061018 yield batch
10071019
@@ -1461,7 +1473,13 @@ def load_model_hook(models, input_dir):
14611473 center_crop = args .center_crop ,
14621474 buckets = buckets ,
14631475 )
1464- batch_sampler = BucketBatchSampler (train_dataset , batch_size = args .train_batch_size , drop_last = True )
1476+ has_step_indexed_caches = precompute_latents = args .cache_latents or train_dataset .custom_instance_prompts
1477+ batch_sampler = BucketBatchSampler (
1478+ train_dataset ,
1479+ batch_size = args .train_batch_size ,
1480+ drop_last = True ,
1481+ shuffle_batches_each_epoch = not has_step_indexed_caches ,
1482+ )
14651483 train_dataloader = torch .utils .data .DataLoader (
14661484 train_dataset ,
14671485 batch_sampler = batch_sampler ,
@@ -1528,7 +1546,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15281546 # if cache_latents is set to True, we encode images to latents and store them.
15291547 # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15301548 # we encode them in advance as well.
1531- precompute_latents = args .cache_latents or train_dataset .custom_instance_prompts
15321549 if precompute_latents :
15331550 prompt_embeds_cache = []
15341551 text_ids_cache = []
0 commit comments