Skip to content

Commit 57f10d5

Browse files
committed
Address bucket sampler cache variable naming review
1 parent fff6c8c commit 57f10d5

5 files changed

Lines changed: 10 additions & 10 deletions

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,7 +1480,7 @@ def load_model_hook(models, input_dir):
14801480
center_crop=args.center_crop,
14811481
buckets=buckets,
14821482
)
1483-
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1483+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14841484
batch_sampler = BucketBatchSampler(
14851485
train_dataset,
14861486
batch_size=args.train_batch_size,
@@ -1603,7 +1603,7 @@ def _encode_single(prompt: str):
16031603
# if cache_latents is set to True, we encode images to latents and store them.
16041604
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
16051605
# we encode them in advance as well.
1606-
if has_step_indexed_caches:
1606+
if precompute_latents:
16071607
prompt_embeds_cache = []
16081608
text_ids_cache = []
16091609
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,7 +1427,7 @@ def load_model_hook(models, input_dir):
14271427
center_crop=args.center_crop,
14281428
buckets=buckets,
14291429
)
1430-
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1430+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14311431
batch_sampler = BucketBatchSampler(
14321432
train_dataset,
14331433
batch_size=args.train_batch_size,
@@ -1536,7 +1536,7 @@ def _encode_single(prompt: str):
15361536
# if cache_latents is set to True, we encode images to latents and store them.
15371537
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15381538
# we encode them in advance as well.
1539-
if has_step_indexed_caches:
1539+
if precompute_latents:
15401540
prompt_embeds_cache = []
15411541
text_ids_cache = []
15421542
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,7 +1473,7 @@ def load_model_hook(models, input_dir):
14731473
center_crop=args.center_crop,
14741474
buckets=buckets,
14751475
)
1476-
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1476+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14771477
batch_sampler = BucketBatchSampler(
14781478
train_dataset,
14791479
batch_size=args.train_batch_size,
@@ -1546,7 +1546,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15461546
# if cache_latents is set to True, we encode images to latents and store them.
15471547
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15481548
# we encode them in advance as well.
1549-
if has_step_indexed_caches:
1549+
if precompute_latents:
15501550
prompt_embeds_cache = []
15511551
text_ids_cache = []
15521552
latents_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,7 @@ def load_model_hook(models, input_dir):
14211421
center_crop=args.center_crop,
14221422
buckets=buckets,
14231423
)
1424-
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1424+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14251425
batch_sampler = BucketBatchSampler(
14261426
train_dataset,
14271427
batch_size=args.train_batch_size,
@@ -1487,7 +1487,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
14871487
# if cache_latents is set to True, we encode images to latents and store them.
14881488
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
14891489
# we encode them in advance as well.
1490-
if has_step_indexed_caches:
1490+
if precompute_latents:
14911491
prompt_embeds_cache = []
14921492
text_ids_cache = []
14931493
latents_cache = []

examples/dreambooth/train_dreambooth_lora_z_image.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,7 @@ def load_model_hook(models, input_dir):
14611461
center_crop=args.center_crop,
14621462
buckets=buckets,
14631463
)
1464-
has_step_indexed_caches = args.cache_latents or train_dataset.custom_instance_prompts
1464+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14651465
batch_sampler = BucketBatchSampler(
14661466
train_dataset,
14671467
batch_size=args.train_batch_size,
@@ -1527,7 +1527,7 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15271527
# if cache_latents is set to True, we encode images to latents and store them.
15281528
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15291529
# we encode them in advance as well.
1530-
if has_step_indexed_caches:
1530+
if precompute_latents:
15311531
prompt_embeds_cache = []
15321532
latents_cache = []
15331533
for batch in tqdm(train_dataloader, desc="Caching latents"):

0 commit comments

Comments
 (0)