Skip to content

Commit d0de091

Browse files
committed
fixes
1 parent 902574c commit d0de091

1 file changed

Lines changed: 6 additions & 19 deletions

File tree

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,13 +1188,8 @@ def main(args):
     elif args.prior_generation_precision == "bf16":
         torch_dtype = torch.bfloat16

-    # TODO: change
-    # transformer = FluxTransformer2DModel.from_pretrained(
-    #     args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
-    # )
-    transformer = FluxTransformer2DModel.from_single_file(
-        "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors",
-        torch_dtype=torch.bfloat16,
+    transformer = FluxTransformer2DModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
     )
     pipeline = FluxKontextPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
@@ -1270,11 +1265,8 @@ def main(args):
         revision=args.revision,
         variant=args.variant,
     )
-    # transformer = FluxTransformer2DModel.from_pretrained(
-    #     args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
-    # )
-    transformer = FluxTransformer2DModel.from_single_file(
-        "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors", torch_dtype=torch.bfloat16
+    transformer = FluxTransformer2DModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
     )

     # We only train the additional adapter LoRA layers
@@ -2018,13 +2010,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):

     # Final inference
     # Load previous pipeline
-    # TODO: change
-    # transformer = FluxTransformer2DModel.from_pretrained(
-    #     args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
-    # )
-    transformer = FluxTransformer2DModel.from_single_file(
-        "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors",
-        torch_dtype=torch.bfloat16,
+    transformer = FluxTransformer2DModel.from_pretrained(
+        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
     )
     pipeline = FluxKontextPipeline.from_pretrained(
         args.pretrained_model_name_or_path,

0 commit comments

Comments
 (0)