Allow bucket reshuffling with DreamBooth caches#13712
azolotenkov wants to merge 2 commits into huggingface:main from
Conversation
Pull request overview
Enables epoch-wise reshuffling of DreamBooth bucketed batches in the Flux2 DreamBooth LoRA example scripts while keeping cached latents and (custom-caption) prompt/text embeddings correctly aligned by switching caches from step-indexing to dataset-sample indexing.
Changes:
- Add a per-sample `index` to dataset items and propagate it through `collate_fn` so caches can be keyed by sample index rather than dataloader step.
- Rework latent/prompt-embedding caching to precompute via a non-dropping cache dataloader and store per-sample cached tensors.
- Update `BucketBatchSampler` to reshuffle indices/batches on each `__iter__()` call (epoch reshuffle) while keeping `__len__` stable.
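The reshuffling behavior described above can be illustrated with a minimal, self-contained sketch (plain Python, no torch; the class body and bucket layout here are hypothetical stand-ins, not the PR's actual implementation): batches are regrouped and reshuffled on every `__iter__()` call, while `__len__` stays constant across epochs.

```python
import random

class BucketBatchSampler:
    """Hypothetical sketch: yields same-bucket batches, reshuffled per epoch."""

    def __init__(self, bucket_indices, batch_size, seed=0):
        # bucket_indices: mapping of bucket (e.g. a resolution) -> sample indices
        self.bucket_indices = bucket_indices
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __len__(self):
        # Stable across epochs: number of full batches, dropping incomplete ones.
        return sum(len(v) // self.batch_size for v in self.bucket_indices.values())

    def __iter__(self):
        batches = []
        for idxs in self.bucket_indices.values():
            idxs = list(idxs)
            self.rng.shuffle(idxs)  # reshuffle samples within each bucket
            for i in range(0, len(idxs) - self.batch_size + 1, self.batch_size):
                batches.append(idxs[i:i + self.batch_size])
        self.rng.shuffle(batches)  # reshuffle batch order across buckets
        yield from batches

# Two hypothetical resolution buckets with 4 and 3 samples.
buckets = {(512, 512): [0, 1, 2, 3], (768, 512): [4, 5, 6]}
sampler = BucketBatchSampler(buckets, batch_size=2)
epoch1 = list(sampler)  # each __iter__() call produces a fresh shuffle
epoch2 = list(sampler)
print(len(sampler), epoch1, epoch2)
```

Every batch stays within a single bucket (so shapes match), but the batch order and within-bucket grouping differ between epochs, which is exactly why cache lookups can no longer rely on step order.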
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| examples/dreambooth/train_dreambooth_lora_flux2.py | Switch cache alignment to sample-indexing and re-enable epoch reshuffling for bucket batches. |
| examples/dreambooth/train_dreambooth_lora_flux2_klein.py | Same cache-by-index + epoch reshuffle changes for the Klein variant of the training script. |
Force-pushed from 6cf74c0 to b022936
@claude can you do a review here?
Claude finished @sayakpaul's task in 5m 3s. Review: Allow bucket reshuffling with DreamBooth caches
Overall: The approach is sound. Switching from step-indexed to sample-indexed caching correctly decouples cache lookups from batch ordering, enabling epoch reshuffling without misaligned cache reads. The non-dropping cache dataloader (…). No blocking issues found. A few observations and suggestions below.

Correctness trace: I traced the full flow and confirmed the design holds:
Suggestions (non-blocking)
Summary: The core design (sample-indexed caches + …
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Allow DreamBooth bucket batches to reshuffle each epoch while keeping cached latents and custom-caption prompt embeddings aligned.
After #13353, bucket batches with cached latents/custom captions were kept in stable step order because caches were indexed by dataloader step. This fixes the underlying limitation by indexing cached latents and prompt embeddings by dataset sample index instead. The training dataloader can then reshuffle bucket batches each epoch without reading the wrong cached tensors.
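The indexing change can be sketched in miniature (plain Python; `dataset_item`, `collate_fn`, and `encode` are hypothetical stand-ins for the script's dataset `__getitem__`, collate function, and VAE encode step, not the actual code): each example carries its own dataset index, the collate function propagates it, and cache lookups key on that index so any reshuffled batch order still reads the right tensors.

```python
latent_cache = {}

def dataset_item(i):
    # Stand-in for __getitem__: each example carries its dataset index.
    return {"index": i, "pixel_values": f"pixels_{i}"}

def collate_fn(examples):
    # Propagate per-sample indices alongside the batched inputs.
    return {
        "indices": [e["index"] for e in examples],
        "pixel_values": [e["pixel_values"] for e in examples],
    }

def encode(pixels):
    # Stand-in for the VAE encode step.
    return f"latent_of_{pixels}"

# Precompute pass: visit every sample once, keyed by sample index.
for i in range(5):
    batch = collate_fn([dataset_item(i)])
    for idx, px in zip(batch["indices"], batch["pixel_values"]):
        latent_cache[idx] = encode(px)

# Training pass: batches may arrive in any reshuffled order.
shuffled_batch = collate_fn([dataset_item(3), dataset_item(1)])
latents = [latent_cache[idx] for idx in shuffled_batch["indices"]]
print(latents)  # latents for samples 3 and 1, regardless of step order
```

Under step-indexed caching, by contrast, the lookup key would be the dataloader step number, which only matches the right samples when the batch order never changes.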
The cache precompute pass now uses a non-dropping cache dataloader, so every sample that can appear in a later reshuffled training epoch has a cache entry.
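Why the cache pass must not drop its last partial batch can be shown with a small sketch (plain Python mimicking `drop_last`; the helper and sample counts are illustrative, not the script's code): with 5 samples and batch size 2, a dropping dataloader never visits sample 4, yet a reshuffled training epoch can still schedule it, producing a cache miss.

```python
def batched(indices, batch_size, drop_last):
    # Mimic a dataloader's batching, optionally dropping the last partial batch.
    batches = [indices[i:i + batch_size]
               for i in range(0, len(indices), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()
    return batches

samples = list(range(5))  # 5 samples, batch_size 2 -> one partial batch

dropped = {i for b in batched(samples, 2, drop_last=True) for i in b}
kept = {i for b in batched(samples, 2, drop_last=False) for i in b}

print(sorted(set(samples) - dropped))  # [4]: uncached under drop_last=True
print(sorted(set(samples) - kept))     # []: every sample gets a cache entry
```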
This also avoids mutating static prompt embeddings inside the training loop. Each step now derives repeated prompt/text embeddings from the original static tensors, which keeps prior-preservation runs with multiple steps stable.
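The "derive, don't mutate" point can be sketched abstractly (plain Python with lists standing in for embedding tensors; the variable names are illustrative): repeating the static embeddings in place compounds across steps, while deriving a fresh repeated copy from the untouched original yields the same batch every step.

```python
static_embeds = ["prompt_embed"]  # stand-in for the cached static embedding
batch_size = 2

# Problematic pattern: in-place repetition compounds across steps.
mutated = list(static_embeds)
for _ in range(3):          # three training steps
    mutated *= batch_size   # grows the buffer each step: 1 -> 2 -> 4 -> 8

# Safe pattern: derive each step's batch from the untouched original.
for _ in range(3):
    step_embeds = static_embeds * batch_size  # always length batch_size

print(len(mutated), len(step_embeds), len(static_embeds))
```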
Tested:
- Klein smoke tests with `hf-internal-testing/tiny-flux2-klein`: five `--cache_latents` runs (details truncated), including one crossing an epoch boundary with `max_train_steps=7`.
- Flux2 smoke tests with `hf-internal-testing/tiny-flux2` using the standard tiny-model settings:
  - `train_batch_size=1`, `max_train_steps=2`
  - `train_batch_size=2`, `max_train_steps=2`
  - `--cache_latents`, `train_batch_size=2`, `max_train_steps=3`

Before submitting
Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@sayakpaul