@@ -1188,13 +1188,8 @@ def main(args):
11881188 elif args .prior_generation_precision == "bf16" :
11891189 torch_dtype = torch .bfloat16
11901190
1191- # TODO: change
1192- # transformer = FluxTransformer2DModel.from_pretrained(
1193- # args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
1194- # )
1195- transformer = FluxTransformer2DModel .from_single_file (
1196- "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors" ,
1197- torch_dtype = torch .bfloat16 ,
1191+ transformer = FluxTransformer2DModel .from_pretrained (
1192+ args .pretrained_model_name_or_path , subfolder = "transformer" , revision = args .revision , variant = args .variant
11981193 )
11991194 pipeline = FluxKontextPipeline .from_pretrained (
12001195 args .pretrained_model_name_or_path ,
@@ -1270,11 +1265,8 @@ def main(args):
12701265 revision = args .revision ,
12711266 variant = args .variant ,
12721267 )
1273- # transformer = FluxTransformer2DModel.from_pretrained(
1274- # args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
1275- # )
1276- transformer = FluxTransformer2DModel .from_single_file (
1277- "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors" , torch_dtype = torch .bfloat16
1268+ transformer = FluxTransformer2DModel .from_pretrained (
1269+ args .pretrained_model_name_or_path , subfolder = "transformer" , revision = args .revision , variant = args .variant
12781270 )
12791271
12801272 # We only train the additional adapter LoRA layers
@@ -2018,13 +2010,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
20182010
20192011 # Final inference
20202012 # Load previous pipeline
2021- # TODO: change
2022- # transformer = FluxTransformer2DModel.from_pretrained(
2023- # args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
2024- # )
2025- transformer = FluxTransformer2DModel .from_single_file (
2026- "https://huggingface.co/diffusers/kontext-v2/blob/main/dev-opt-2-a-3.safetensors" ,
2027- torch_dtype = torch .bfloat16 ,
2013+ transformer = FluxTransformer2DModel .from_pretrained (
2014+ args .pretrained_model_name_or_path , subfolder = "transformer" , revision = args .revision , variant = args .variant
20282015 )
20292016 pipeline = FluxKontextPipeline .from_pretrained (
20302017 args .pretrained_model_name_or_path ,
0 commit comments