Skip to content

Commit c4c616a

Browse files
authored
Merge branch 'main' into torchao-dequantize
2 parents db6ec03 + 7b107d3 commit c4c616a

23 files changed

Lines changed: 2543 additions & 78 deletions

.github/workflows/pr_labeler.yml

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,8 @@ jobs:
2020
runs-on: ubuntu-latest
2121
steps:
2222
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
23+
with:
24+
ref: ${{ github.event.pull_request.base.sha }}
2325
- name: Check for missing tests
2426
id: check
2527
env:
@@ -34,11 +36,17 @@ jobs:
3436
env:
3537
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
3638
PR_NUMBER: ${{ github.event.pull_request.number }}
39+
REPO: ${{ github.repository }}
3740
run: |
41+
HAS_LABEL=$(gh api "repos/${REPO}/issues/${PR_NUMBER}/labels" --jq 'any(.[]; .name == "missing-tests")')
3842
if [ "${{ steps.check.outcome }}" = "failure" ]; then
39-
gh pr edit "$PR_NUMBER" --add-label "missing-tests"
43+
if [ "$HAS_LABEL" != "true" ]; then
44+
gh pr edit "$PR_NUMBER" --add-label "missing-tests"
45+
fi
4046
else
41-
gh pr edit "$PR_NUMBER" --remove-label "missing-tests" 2>/dev/null || true
47+
if [ "$HAS_LABEL" = "true" ]; then
48+
gh pr edit "$PR_NUMBER" --remove-label "missing-tests" 2>/dev/null || true
49+
fi
4250
fi
4351
4452
fixes-issue:
@@ -65,10 +73,15 @@ jobs:
6573
}
6674
}' \
6775
--jq '.data.repository.pullRequest.closingIssuesReferences.totalCount')
76+
HAS_LABEL=$(gh api "repos/${REPO}/issues/${PR_NUMBER}/labels" --jq 'any(.[]; .name == "fixes-issue")')
6877
if [ "${COUNT:-0}" -gt 0 ]; then
69-
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "fixes-issue"
78+
if [ "$HAS_LABEL" != "true" ]; then
79+
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "fixes-issue"
80+
fi
7081
else
71-
gh pr edit "$PR_NUMBER" --repo "$REPO" --remove-label "fixes-issue" 2>/dev/null || true
82+
if [ "$HAS_LABEL" = "true" ]; then
83+
gh pr edit "$PR_NUMBER" --repo "$REPO" --remove-label "fixes-issue" 2>/dev/null || true
84+
fi
7285
fi
7386
7487
size-label:
@@ -81,13 +94,19 @@ jobs:
8194
REPO: ${{ github.repository }}
8295
run: |
8396
DIFF_SIZE=$(gh api "repos/${REPO}/pulls/${PR_NUMBER}" --jq '.additions + .deletions')
84-
for label in size/S size/M size/L; do
85-
gh pr edit "$PR_NUMBER" --repo "$REPO" --remove-label "$label" 2>/dev/null || true
86-
done
8797
if [ "$DIFF_SIZE" -lt 50 ]; then
88-
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "size/S"
98+
CANDIDATE_LABEL="size/S"
8999
elif [ "$DIFF_SIZE" -lt 200 ]; then
90-
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "size/M"
100+
CANDIDATE_LABEL="size/M"
91101
else
92-
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "size/L"
102+
CANDIDATE_LABEL="size/L"
103+
fi
104+
CURRENT_LABELS=$(gh api "repos/${REPO}/issues/${PR_NUMBER}/labels" --jq '.[].name')
105+
for label in size/S size/M size/L; do
106+
if [ "$label" != "$CANDIDATE_LABEL" ] && echo "$CURRENT_LABELS" | grep -qx "$label"; then
107+
gh pr edit "$PR_NUMBER" --repo "$REPO" --remove-label "$label" 2>/dev/null || true
108+
fi
109+
done
110+
if ! echo "$CURRENT_LABELS" | grep -qx "$CANDIDATE_LABEL"; then
111+
gh pr edit "$PR_NUMBER" --repo "$REPO" --add-label "$CANDIDATE_LABEL"
93112
fi

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -974,7 +974,13 @@ def collate_fn(examples, with_prior_preservation=False):
974974

975975

976976
class BucketBatchSampler(BatchSampler):
977-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
977+
def __init__(
978+
self,
979+
dataset: DreamBoothDataset,
980+
batch_size: int,
981+
drop_last: bool = False,
982+
shuffle_batches_each_epoch: bool = True,
983+
):
978984
if not isinstance(batch_size, int) or batch_size <= 0:
979985
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
980986
if not isinstance(drop_last, bool):
@@ -983,6 +989,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
983989
self.dataset = dataset
984990
self.batch_size = batch_size
985991
self.drop_last = drop_last
992+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
986993

987994
# Group indices by bucket
988995
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -1004,9 +1011,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
10041011
self.batches.append(batch)
10051012
self.sampler_len += 1 # Count the number of batches
10061013

1014+
if not self.shuffle_batches_each_epoch:
1015+
# Shuffle the precomputed batches once to mix buckets while keeping
1016+
# the order stable across epochs for step-indexed caches.
1017+
random.shuffle(self.batches)
1018+
10071019
def __iter__(self):
1008-
# Shuffle the order of the batches each epoch
1009-
random.shuffle(self.batches)
1020+
if self.shuffle_batches_each_epoch:
1021+
random.shuffle(self.batches)
10101022
for batch in self.batches:
10111023
yield batch
10121024

@@ -1468,7 +1480,13 @@ def load_model_hook(models, input_dir):
14681480
center_crop=args.center_crop,
14691481
buckets=buckets,
14701482
)
1471-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1483+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1484+
batch_sampler = BucketBatchSampler(
1485+
train_dataset,
1486+
batch_size=args.train_batch_size,
1487+
drop_last=True,
1488+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1489+
)
14721490
train_dataloader = torch.utils.data.DataLoader(
14731491
train_dataset,
14741492
batch_sampler=batch_sampler,
@@ -1585,7 +1603,6 @@ def _encode_single(prompt: str):
15851603
# if cache_latents is set to True, we encode images to latents and store them.
15861604
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15871605
# we encode them in advance as well.
1588-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
15891606
if precompute_latents:
15901607
prompt_embeds_cache = []
15911608
text_ids_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -972,7 +972,13 @@ def collate_fn(examples):
972972

973973

974974
class BucketBatchSampler(BatchSampler):
975-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
975+
def __init__(
976+
self,
977+
dataset: DreamBoothDataset,
978+
batch_size: int,
979+
drop_last: bool = False,
980+
shuffle_batches_each_epoch: bool = True,
981+
):
976982
if not isinstance(batch_size, int) or batch_size <= 0:
977983
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
978984
if not isinstance(drop_last, bool):
@@ -981,6 +987,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
981987
self.dataset = dataset
982988
self.batch_size = batch_size
983989
self.drop_last = drop_last
990+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
984991

985992
# Group indices by bucket
986993
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -1002,9 +1009,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
10021009
self.batches.append(batch)
10031010
self.sampler_len += 1 # Count the number of batches
10041011

1012+
if not self.shuffle_batches_each_epoch:
1013+
# Shuffle the precomputed batches once to mix buckets while keeping
1014+
# the order stable across epochs for step-indexed caches.
1015+
random.shuffle(self.batches)
1016+
10051017
def __iter__(self):
1006-
# Shuffle the order of the batches each epoch
1007-
random.shuffle(self.batches)
1018+
if self.shuffle_batches_each_epoch:
1019+
random.shuffle(self.batches)
10081020
for batch in self.batches:
10091021
yield batch
10101022

@@ -1415,7 +1427,13 @@ def load_model_hook(models, input_dir):
14151427
center_crop=args.center_crop,
14161428
buckets=buckets,
14171429
)
1418-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1430+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1431+
batch_sampler = BucketBatchSampler(
1432+
train_dataset,
1433+
batch_size=args.train_batch_size,
1434+
drop_last=True,
1435+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1436+
)
14191437
train_dataloader = torch.utils.data.DataLoader(
14201438
train_dataset,
14211439
batch_sampler=batch_sampler,
@@ -1518,7 +1536,6 @@ def _encode_single(prompt: str):
15181536
# if cache_latents is set to True, we encode images to latents and store them.
15191537
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15201538
# we encode them in advance as well.
1521-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
15221539
if precompute_latents:
15231540
prompt_embeds_cache = []
15241541
text_ids_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -969,7 +969,13 @@ def collate_fn(examples, with_prior_preservation=False):
969969

970970

971971
class BucketBatchSampler(BatchSampler):
972-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
972+
def __init__(
973+
self,
974+
dataset: DreamBoothDataset,
975+
batch_size: int,
976+
drop_last: bool = False,
977+
shuffle_batches_each_epoch: bool = True,
978+
):
973979
if not isinstance(batch_size, int) or batch_size <= 0:
974980
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
975981
if not isinstance(drop_last, bool):
@@ -978,6 +984,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
978984
self.dataset = dataset
979985
self.batch_size = batch_size
980986
self.drop_last = drop_last
987+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
981988

982989
# Group indices by bucket
983990
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -999,9 +1006,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9991006
self.batches.append(batch)
10001007
self.sampler_len += 1 # Count the number of batches
10011008

1009+
if not self.shuffle_batches_each_epoch:
1010+
# Shuffle the precomputed batches once to mix buckets while keeping
1011+
# the order stable across epochs for step-indexed caches.
1012+
random.shuffle(self.batches)
1013+
10021014
def __iter__(self):
1003-
# Shuffle the order of the batches each epoch
1004-
random.shuffle(self.batches)
1015+
if self.shuffle_batches_each_epoch:
1016+
random.shuffle(self.batches)
10051017
for batch in self.batches:
10061018
yield batch
10071019

@@ -1461,7 +1473,13 @@ def load_model_hook(models, input_dir):
14611473
center_crop=args.center_crop,
14621474
buckets=buckets,
14631475
)
1464-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1476+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1477+
batch_sampler = BucketBatchSampler(
1478+
train_dataset,
1479+
batch_size=args.train_batch_size,
1480+
drop_last=True,
1481+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1482+
)
14651483
train_dataloader = torch.utils.data.DataLoader(
14661484
train_dataset,
14671485
batch_sampler=batch_sampler,
@@ -1528,7 +1546,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15281546
# if cache_latents is set to True, we encode images to latents and store them.
15291547
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15301548
# we encode them in advance as well.
1531-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
15321549
if precompute_latents:
15331550
prompt_embeds_cache = []
15341551
text_ids_cache = []

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -968,7 +968,13 @@ def collate_fn(examples):
968968

969969

970970
class BucketBatchSampler(BatchSampler):
971-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
971+
def __init__(
972+
self,
973+
dataset: DreamBoothDataset,
974+
batch_size: int,
975+
drop_last: bool = False,
976+
shuffle_batches_each_epoch: bool = True,
977+
):
972978
if not isinstance(batch_size, int) or batch_size <= 0:
973979
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
974980
if not isinstance(drop_last, bool):
@@ -977,6 +983,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
977983
self.dataset = dataset
978984
self.batch_size = batch_size
979985
self.drop_last = drop_last
986+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
980987

981988
# Group indices by bucket
982989
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -998,9 +1005,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9981005
self.batches.append(batch)
9991006
self.sampler_len += 1 # Count the number of batches
10001007

1008+
if not self.shuffle_batches_each_epoch:
1009+
# Shuffle the precomputed batches once to mix buckets while keeping
1010+
# the order stable across epochs for step-indexed caches.
1011+
random.shuffle(self.batches)
1012+
10011013
def __iter__(self):
1002-
# Shuffle the order of the batches each epoch
1003-
random.shuffle(self.batches)
1014+
if self.shuffle_batches_each_epoch:
1015+
random.shuffle(self.batches)
10041016
for batch in self.batches:
10051017
yield batch
10061018

@@ -1409,7 +1421,13 @@ def load_model_hook(models, input_dir):
14091421
center_crop=args.center_crop,
14101422
buckets=buckets,
14111423
)
1412-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1424+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1425+
batch_sampler = BucketBatchSampler(
1426+
train_dataset,
1427+
batch_size=args.train_batch_size,
1428+
drop_last=True,
1429+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1430+
)
14131431
train_dataloader = torch.utils.data.DataLoader(
14141432
train_dataset,
14151433
batch_sampler=batch_sampler,
@@ -1469,7 +1487,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
14691487
# if cache_latents is set to True, we encode images to latents and store them.
14701488
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
14711489
# we encode them in advance as well.
1472-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
14731490
if precompute_latents:
14741491
prompt_embeds_cache = []
14751492
text_ids_cache = []

examples/dreambooth/train_dreambooth_lora_z_image.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -963,7 +963,13 @@ def collate_fn(examples, with_prior_preservation=False):
963963

964964

965965
class BucketBatchSampler(BatchSampler):
966-
def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False):
966+
def __init__(
967+
self,
968+
dataset: DreamBoothDataset,
969+
batch_size: int,
970+
drop_last: bool = False,
971+
shuffle_batches_each_epoch: bool = True,
972+
):
967973
if not isinstance(batch_size, int) or batch_size <= 0:
968974
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))
969975
if not isinstance(drop_last, bool):
@@ -972,6 +978,7 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
972978
self.dataset = dataset
973979
self.batch_size = batch_size
974980
self.drop_last = drop_last
981+
self.shuffle_batches_each_epoch = shuffle_batches_each_epoch
975982

976983
# Group indices by bucket
977984
self.bucket_indices = [[] for _ in range(len(self.dataset.buckets))]
@@ -993,9 +1000,14 @@ def __init__(self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool
9931000
self.batches.append(batch)
9941001
self.sampler_len += 1 # Count the number of batches
9951002

1003+
if not self.shuffle_batches_each_epoch:
1004+
# Shuffle the precomputed batches once to mix buckets while keeping
1005+
# the order stable across epochs for step-indexed caches.
1006+
random.shuffle(self.batches)
1007+
9961008
def __iter__(self):
997-
# Shuffle the order of the batches each epoch
998-
random.shuffle(self.batches)
1009+
if self.shuffle_batches_each_epoch:
1010+
random.shuffle(self.batches)
9991011
for batch in self.batches:
10001012
yield batch
10011013

@@ -1449,7 +1461,13 @@ def load_model_hook(models, input_dir):
14491461
center_crop=args.center_crop,
14501462
buckets=buckets,
14511463
)
1452-
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=True)
1464+
has_step_indexed_caches = precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
1465+
batch_sampler = BucketBatchSampler(
1466+
train_dataset,
1467+
batch_size=args.train_batch_size,
1468+
drop_last=True,
1469+
shuffle_batches_each_epoch=not has_step_indexed_caches,
1470+
)
14531471
train_dataloader = torch.utils.data.DataLoader(
14541472
train_dataset,
14551473
batch_sampler=batch_sampler,
@@ -1509,7 +1527,6 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
15091527
# if cache_latents is set to True, we encode images to latents and store them.
15101528
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
15111529
# we encode them in advance as well.
1512-
precompute_latents = args.cache_latents or train_dataset.custom_instance_prompts
15131530
if precompute_latents:
15141531
prompt_embeds_cache = []
15151532
latents_cache = []

0 commit comments

Comments
 (0)