@@ -905,6 +905,7 @@ def __len__(self):
     def __getitem__(self, index):
         example = {}
         instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
+        example["index"] = index
         example["instance_images"] = instance_image
         example["bucket_idx"] = bucket_idx
         if self.custom_instance_prompts:
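Note: the raw dataset `index` now travels with every example, so downstream stages can address per-sample caches regardless of batch order. A toy restatement of that contract (the `images` list and `getitem` helper are illustrative stand-ins, not the real `DreamBoothDataset`):

```python
# Toy stand-in for the __getitem__ contract above: each example carries the
# index it was fetched with, next to its (modulo-wrapped) instance image.
images = ["img_a", "img_b", "img_c"]

def getitem(index):
    return {"index": index, "instance_images": images[index % len(images)]}

example = getitem(1)
assert example["index"] == 1 and example["instance_images"] == "img_b"
```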
@@ -957,7 +958,9 @@ def train_transform(self, image, size=(224, 224), center_crop=False, random_flip
 
 
 def collate_fn(examples, with_prior_preservation=False):
+    indices = [example["index"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
+    instance_prompts = [example["instance_prompt"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
 
     # Concat class and instance examples for prior preservation.
@@ -969,18 +972,17 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "indices": indices,
+        "pixel_values": pixel_values,
+        "instance_prompts": instance_prompts,
+        "prompts": prompts,
+    }
     return batch
 
 
 class BucketBatchSampler(BatchSampler):
-    def __init__(
-        self,
-        dataset: DreamBoothDataset,
-        batch_size: int,
-        drop_last: bool = False,
-        shuffle_batches_each_epoch: bool = True,
-    ):
+    def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
         if not isinstance(batch_size, int) or batch_size <= 0:
             raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
         if not isinstance(drop_last, bool):
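Note: `collate_fn` now surfaces both the per-example `indices` and the raw `instance_prompts` (instance-only, unlike `prompts`, which gains class prompts under prior preservation). A minimal, self-contained sketch of the new contract, omitting the prior-preservation branch:

```python
import torch

def collate_fn(examples):
    # Trimmed restatement of the diff's collate contract (no prior preservation).
    indices = [example["index"] for example in examples]
    pixel_values = torch.stack([example["instance_images"] for example in examples]).float()
    prompts = [example["instance_prompt"] for example in examples]
    return {"indices": indices, "pixel_values": pixel_values, "instance_prompts": prompts, "prompts": prompts}

batch = collate_fn(
    [
        {"index": 3, "instance_images": torch.zeros(3, 64, 64), "instance_prompt": "a photo of sks dog"},
        {"index": 0, "instance_images": torch.zeros(3, 64, 64), "instance_prompt": "sks dog on a beach"},
    ]
)
assert batch["indices"] == [3, 0] and batch["pixel_values"].shape == (2, 3, 64, 64)
```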
@@ -989,37 +991,32 @@ def __init__(
         self.dataset = dataset
         self.batch_size = batch_size
         self.drop_last = drop_last
-        self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
 
         # Group indices by bucket
         self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
         for idx, (_, bucket_idx) in enumerate(self.dataset.pixel_values):
             self.bucket_indices[bucket_idx].append(idx)
 
         self.sampler_len = 0
-        self.batches = []
+        for indices_in_bucket in self.bucket_indices:
+            num_batches, remainder = divmod(len(indices_in_bucket), self.batch_size)
+            self.sampler_len += num_batches
+            if remainder > 0 and not self.drop_last:
+                self.sampler_len += 1
 
-        # Pre-generate batches for each bucket
+    def __iter__(self):
+        batches = []
         for indices_in_bucket in self.bucket_indices:
-            # Shuffle indices within the bucket
-            random.shuffle(indices_in_bucket)
-            # Create batches
-            for i in range(0, len(indices_in_bucket), self.batch_size):
-                batch = indices_in_bucket[i : i + self.batch_size]
+            shuffled_indices = indices_in_bucket.copy()
+            random.shuffle(shuffled_indices)
+            for i in range(0, len(shuffled_indices), self.batch_size):
+                batch = shuffled_indices[i : i + self.batch_size]
                 if len(batch) < self.batch_size and self.drop_last:
-                    continue  # Skip partial batch if drop_last is True
-                self.batches.append(batch)
-                self.sampler_len += 1  # Count the number of batches
+                    continue
+                batches.append(batch)
 
-        if not self.shuffle_batches_each_epoch:
-            # Shuffle the precomputed batches once to mix buckets while keeping
-            # the order stable across epochs for step-indexed caches.
-            random.shuffle(self.batches)
-
-    def __iter__(self):
-        if self.shuffle_batches_each_epoch:
-            random.shuffle(self.batches)
-        for batch in self.batches:
+        random.shuffle(batches)
+        for batch in batches:
             yield batch
 
     def __len__(self):
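Note: batch construction moves from `__init__` into `__iter__`, so every epoch draws a fresh shuffle while `__len__` is computed analytically with `divmod`. A standalone sketch of the same scheme over plain index lists (hypothetical `ToyBucketSampler`, not the real class):

```python
import random

class ToyBucketSampler:
    """Standalone sketch of the epoch-reshuffling bucket sampler in this diff."""

    def __init__(self, buckets, batch_size, drop_last=False):
        self.buckets, self.batch_size, self.drop_last = buckets, batch_size, drop_last

    def __iter__(self):
        batches = []
        for bucket in self.buckets:
            shuffled = bucket.copy()
            random.shuffle(shuffled)  # fresh within-bucket order every epoch
            for i in range(0, len(shuffled), self.batch_size):
                batch = shuffled[i : i + self.batch_size]
                if len(batch) == self.batch_size or not self.drop_last:
                    batches.append(batch)
        random.shuffle(batches)  # mix buckets across the epoch
        yield from batches

    def __len__(self):
        total = 0
        for bucket in self.buckets:
            full_batches, remainder = divmod(len(bucket), self.batch_size)
            total += full_batches + (1 if remainder and not self.drop_last else 0)
        return total

sampler = ToyBucketSampler([[0, 1, 2], [3, 4]], batch_size=2)
assert len(sampler) == len(list(sampler)) == 3  # 2 batches from bucket 0, 1 from bucket 1
```

Rebuilding batches inside `__iter__` also means a second sampler instance (as the caching pass below creates) never shares mutable state with the training sampler.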
@@ -1480,13 +1477,8 @@ def load_model_hook(models, input_dir):
         center_crop=args.center_crop,
         buckets=buckets,
     )
-    has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
-    batch_sampler = BucketBatchSampler(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        drop_last=True,
-        shuffle_batches_each_epoch=not has_step_indexed_caches,
-    )
+    precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
+    batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_sampler=batch_sampler,
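Note: the `shuffle_batches_each_epoch` workaround disappears because caches are no longer step-indexed. One PyTorch detail worth recalling here: when `batch_sampler` is passed, `DataLoader` must be left at its defaults for `batch_size`, `shuffle`, `sampler`, and `drop_last`, since the sampler owns batching entirely:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(6))
batch_sampler = [[0, 1], [2, 3], [4, 5]]  # any iterable of index lists works
loader = DataLoader(dataset, batch_sampler=batch_sampler)  # no batch_size/shuffle here
assert [batch[0].tolist() for batch in loader] == [[0, 1], [2, 3], [4, 5]]
```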
@@ -1599,32 +1591,58 @@ def _encode_single(prompt: str):
     if args.with_prior_preservation:
         prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
         text_ids = torch.cat([text_ids, class_text_ids], dim=0)
+    static_prompt_embeds = prompt_embeds
+    static_text_ids = text_ids
 
     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
     # we encode them in advance as well.
+    if args.cache_latents:
+        instance_latents_cache = [None] * train_dataset.num_instance_images
+        class_latents_cache = [None] * train_dataset.num_instance_images if args.with_prior_preservation else None
+    if train_dataset.custom_instance_prompts:
+        prompt_embeds_cache = [None] * train_dataset.num_instance_images
+        text_ids_cache = [None] * train_dataset.num_instance_images
     if precompute_latents:
-        prompt_embeds_cache = []
-        text_ids_cache = []
-        latents_cache = []
-        for batch in tqdm(train_dataloader, desc="Caching latents"):
+        cache_batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
+        cache_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_sampler=cache_batch_sampler,
+            collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
+            num_workers=args.dataloader_num_workers,
+        )
+        for batch in tqdm(cache_dataloader, desc="Caching latents"):
             with torch.no_grad():
+                indices = batch["indices"]
                 if args.cache_latents:
                     with offload_models(vae, device=accelerator.device, offload=args.offload):
                         batch["pixel_values"] = batch["pixel_values"].to(
                             accelerator.device, non_blocking=True, dtype=vae.dtype
                         )
-                        latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
+                        latents = vae.encode(batch["pixel_values"]).latent_dist.mode()
+                    if args.with_prior_preservation:
+                        instance_latents, class_latents = torch.chunk(latents, 2, dim=0)
+                    else:
+                        instance_latents = latents
+                    for i, idx in enumerate(indices):
+                        instance_latents_cache[idx] = instance_latents[i : i + 1]
+                        if args.with_prior_preservation:
+                            class_latents_cache[idx] = class_latents[i : i + 1]
                 if train_dataset.custom_instance_prompts:
                     if args.remote_text_encoder:
-                        prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
+                        prompt_embeds, text_ids = compute_remote_text_embeddings(batch["instance_prompts"])
                     elif args.fsdp_text_encoder:
-                        prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
+                        prompt_embeds, text_ids = compute_text_embeddings(
+                            batch["instance_prompts"], text_encoding_pipeline
+                        )
                     else:
                         with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
-                            prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
-                    prompt_embeds_cache.append(prompt_embeds)
-                    text_ids_cache.append(text_ids)
+                            prompt_embeds, text_ids = compute_text_embeddings(
+                                batch["instance_prompts"], text_encoding_pipeline
+                            )
+                    for i, idx in enumerate(indices):
+                        prompt_embeds_cache[idx] = prompt_embeds[i : i + 1]
+                        text_ids_cache[idx] = text_ids[i : i + 1]
 
     # move back to cpu before deleting to ensure memory is freed see: https://github.com/huggingface/diffusers/issues/11376#issue-3008144624
     if args.cache_latents:
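Note: this is the core fix. The old code appended cache entries in dataloader order and read them back by `step`, but the caching pass and the training loop each shuffle independently, so step *k* names different samples in each. Keying by dataset index makes the lookup order-independent, as this toy reconstruction (stand-in tensors, not real VAE outputs) shows:

```python
import torch

num_samples = 4
per_sample_latents = {i: torch.full((1, 3), float(i)) for i in range(num_samples)}  # stand-in "VAE outputs"

# Caching pass: one shuffled batch order fills the index-keyed cache...
cache = [None] * num_samples
for indices in [[2, 0], [3, 1]]:
    for idx in indices:
        cache[idx] = per_sample_latents[idx]

# ...training pass: a different shuffled order still reads the right entries.
for indices in [[1, 2], [0, 3]]:
    model_input = torch.cat([cache[idx] for idx in indices], dim=0)
    assert model_input[:, 0].tolist() == [float(i) for i in indices]
```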
@@ -1748,25 +1766,35 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
 
-        for step, batch in enumerate(train_dataloader):
+        for batch in train_dataloader:
             models_to_accumulate = [transformer]
+            indices = batch["indices"]
             prompts = batch["prompts"]
 
             with accelerator.accumulate(models_to_accumulate):
                 if train_dataset.custom_instance_prompts:
-                    prompt_embeds = prompt_embeds_cache[step]
-                    text_ids = text_ids_cache[step]
+                    prompt_embeds = torch.cat([prompt_embeds_cache[idx] for idx in indices], dim=0)
+                    text_ids = torch.cat([text_ids_cache[idx] for idx in indices], dim=0)
+                    if args.with_prior_preservation:
+                        prompt_embeds = torch.cat(
+                            [prompt_embeds, class_prompt_hidden_states.repeat(len(indices), 1, 1)], dim=0
+                        )
+                        text_ids = torch.cat([text_ids, class_text_ids.repeat(len(indices), 1, 1)], dim=0)
                 else:
                     # With prior preservation, prompt_embeds/text_ids already contain [instance, class] entries,
                     # while collate_fn orders batches as [inst1..instB, class1..classB]. Repeat each entry along
                     # dim 0 to preserve that grouping instead of interleaving [inst, class, inst, class, ...].
                     num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
-                    prompt_embeds = prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
-                    text_ids = text_ids.repeat_interleave(num_repeat_elements, dim=0)
+                    prompt_embeds = static_prompt_embeds.repeat_interleave(num_repeat_elements, dim=0)
+                    text_ids = static_text_ids.repeat_interleave(num_repeat_elements, dim=0)
 
                 # Convert images to latent space
                 if args.cache_latents:
-                    model_input = latents_cache[step].mode()
+                    model_input = torch.cat([instance_latents_cache[idx] for idx in indices], dim=0)
+                    if args.with_prior_preservation:
+                        model_input = torch.cat(
+                            [model_input, torch.cat([class_latents_cache[idx] for idx in indices], dim=0)], dim=0
+                        )
                 else:
                     with offload_models(vae, device=accelerator.device, offload=args.offload):
                         pixel_values = batch["pixel_values"].to(device=accelerator.device, dtype=vae.dtype)
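Note: with prior preservation, the cached per-sample instance entries must be re-joined with the batch-wide class entries in the `[instance..., class...]` layout that `collate_fn` produces for pixels. A hedged sketch of that assembly (shapes are illustrative):

```python
import torch

batch_size, seq_len, dim = 2, 4, 8
instance_embeds = torch.randn(batch_size, seq_len, dim)  # gathered from the per-index cache
class_embed = torch.randn(1, seq_len, dim)               # single shared class-prompt embedding

# Repeat the class entry once per batch element and append it after the instances,
# matching collate_fn's [inst1..instB, class1..classB] pixel ordering.
prompt_embeds = torch.cat([instance_embeds, class_embed.repeat(batch_size, 1, 1)], dim=0)
assert prompt_embeds.shape == (2 * batch_size, seq_len, dim)
assert torch.equal(prompt_embeds[batch_size], prompt_embeds[batch_size + 1])  # class rows are identical
```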