Skip to content

Commit 3cabf56

Browse files
committed
Format DreamBooth bucket sampler updates
1 parent 04c6304 commit 3cabf56

5 files changed

Lines changed: 25 additions & 5 deletions

examples/dreambooth/train_dreambooth_lora_flux2.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -975,7 +975,11 @@ def collate_fn(examples, with_prior_preservation=False):
975975

976976
class BucketBatchSampler(BatchSampler):
977977
def __init__(
978-
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
978+
self,
979+
dataset: DreamBoothDataset,
980+
batch_size: int,
981+
drop_last: bool = False,
982+
shuffle_batches_each_epoch: bool = True,
979983
):
980984
if not isinstance(batch_size, int) or batch_size <= 0:
981985
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

examples/dreambooth/train_dreambooth_lora_flux2_img2img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,11 @@ def collate_fn(examples):
973973

974974
class BucketBatchSampler(BatchSampler):
975975
def __init__(
976-
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
976+
self,
977+
dataset: DreamBoothDataset,
978+
batch_size: int,
979+
drop_last: bool = False,
980+
shuffle_batches_each_epoch: bool = True,
977981
):
978982
if not isinstance(batch_size, int) or batch_size <= 0:
979983
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

examples/dreambooth/train_dreambooth_lora_flux2_klein.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,11 @@ def collate_fn(examples, with_prior_preservation=False):
970970

971971
class BucketBatchSampler(BatchSampler):
972972
def __init__(
973-
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
973+
self,
974+
dataset: DreamBoothDataset,
975+
batch_size: int,
976+
drop_last: bool = False,
977+
shuffle_batches_each_epoch: bool = True,
974978
):
975979
if not isinstance(batch_size, int) or batch_size <= 0:
976980
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -969,7 +969,11 @@ def collate_fn(examples):
969969

970970
class BucketBatchSampler(BatchSampler):
971971
def __init__(
972-
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
972+
self,
973+
dataset: DreamBoothDataset,
974+
batch_size: int,
975+
drop_last: bool = False,
976+
shuffle_batches_each_epoch: bool = True,
973977
):
974978
if not isinstance(batch_size, int) or batch_size <= 0:
975979
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

examples/dreambooth/train_dreambooth_lora_z_image.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,11 @@ def collate_fn(examples, with_prior_preservation=False):
964964

965965
class BucketBatchSampler(BatchSampler):
966966
def __init__(
967-
self, dataset: DreamBoothDataset, batch_size: int, drop_last: bool = False, shuffle_batches_each_epoch: bool = True
967+
self,
968+
dataset: DreamBoothDataset,
969+
batch_size: int,
970+
drop_last: bool = False,
971+
shuffle_batches_each_epoch: bool = True,
968972
):
969973
if not isinstance(batch_size, int) or batch_size <= 0:
970974
raise ValueError("batch_size should be a positive integer value, but got batch_size={}".format(batch_size))

0 commit comments

Comments
 (0)