@@ -2321,6 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23212321 prefix = "diffusion_model."
23222322 original_state_dict = {k [len (prefix ) :]: v for k , v in state_dict .items ()}
23232323
2324+ has_lora_down_up = any ("lora_down" in k or "lora_up" in k for k in original_state_dict .keys ())
2325+ if has_lora_down_up :
2326+ temp_state_dict = {}
2327+ for k , v in original_state_dict .items ():
2328+ new_key = k .replace ("lora_down" , "lora_A" ).replace ("lora_up" , "lora_B" )
2329+ temp_state_dict [new_key ] = v
2330+ original_state_dict = temp_state_dict
2331+
23242332 num_double_layers = 0
23252333 num_single_layers = 0
23262334 for key in original_state_dict .keys ():
@@ -2337,13 +2345,15 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23372345 attn_prefix = f"single_transformer_blocks.{ sl } .attn"
23382346
23392347 for lora_key in lora_keys :
2340- converted_state_dict [f"{ attn_prefix } .to_qkv_mlp_proj.{ lora_key } .weight" ] = original_state_dict .pop (
2341- f"{ single_block_prefix } .linear1.{ lora_key } .weight"
2342- )
2348+ linear1_key = f"{ single_block_prefix } .linear1.{ lora_key } .weight"
2349+ if linear1_key in original_state_dict :
2350+ converted_state_dict [f"{ attn_prefix } .to_qkv_mlp_proj.{ lora_key } .weight" ] = original_state_dict .pop (
2351+ linear1_key
2352+ )
23432353
2344- converted_state_dict [ f"{ attn_prefix } .to_out .{ lora_key } .weight" ] = original_state_dict . pop (
2345- f" { single_block_prefix } .linear2. { lora_key } .weight"
2346- )
2354+ linear2_key = f"{ single_block_prefix } .linear2 .{ lora_key } .weight"
2355+ if linear2_key in original_state_dict :
2356+ converted_state_dict [ f" { attn_prefix } .to_out. { lora_key } .weight" ] = original_state_dict . pop ( linear2_key )
23472357
23482358 for dl in range (num_double_layers ):
23492359 transformer_block_prefix = f"transformer_blocks.{ dl } "
@@ -2352,6 +2362,10 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23522362 for attn_type in attn_types :
23532363 attn_prefix = f"{ transformer_block_prefix } .attn"
23542364 qkv_key = f"double_blocks.{ dl } .{ attn_type } .qkv.{ lora_key } .weight"
2365+
2366+ if qkv_key not in original_state_dict :
2367+ continue
2368+
23552369 fused_qkv_weight = original_state_dict .pop (qkv_key )
23562370
23572371 if lora_key == "lora_A" :
@@ -2383,8 +2397,9 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23832397 for org_proj , diff_proj in proj_mappings :
23842398 for lora_key in lora_keys :
23852399 original_key = f"double_blocks.{ dl } .{ org_proj } .{ lora_key } .weight"
2386- diffusers_key = f"{ transformer_block_prefix } .{ diff_proj } .{ lora_key } .weight"
2387- converted_state_dict [diffusers_key ] = original_state_dict .pop (original_key )
2400+ if original_key in original_state_dict :
2401+ diffusers_key = f"{ transformer_block_prefix } .{ diff_proj } .{ lora_key } .weight"
2402+ converted_state_dict [diffusers_key ] = original_state_dict .pop (original_key )
23882403
23892404 mlp_mappings = [
23902405 ("img_mlp.0" , "ff.linear_in" ),
@@ -2395,8 +2410,27 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23952410 for org_mlp , diff_mlp in mlp_mappings :
23962411 for lora_key in lora_keys :
23972412 original_key = f"double_blocks.{ dl } .{ org_mlp } .{ lora_key } .weight"
2398- diffusers_key = f"{ transformer_block_prefix } .{ diff_mlp } .{ lora_key } .weight"
2399- converted_state_dict [diffusers_key ] = original_state_dict .pop (original_key )
2413+ if original_key in original_state_dict :
2414+ diffusers_key = f"{ transformer_block_prefix } .{ diff_mlp } .{ lora_key } .weight"
2415+ converted_state_dict [diffusers_key ] = original_state_dict .pop (original_key )
2416+
2417+ extra_mappings = {
2418+ "img_in" : "x_embedder" ,
2419+ "txt_in" : "context_embedder" ,
2420+ "time_in.in_layer" : "time_guidance_embed.timestep_embedder.linear_1" ,
2421+ "time_in.out_layer" : "time_guidance_embed.timestep_embedder.linear_2" ,
2422+ "final_layer.linear" : "proj_out" ,
2423+ "final_layer.adaLN_modulation.1" : "norm_out.linear" ,
2424+ "single_stream_modulation.lin" : "single_stream_modulation.linear" ,
2425+ "double_stream_modulation_img.lin" : "double_stream_modulation_img.linear" ,
2426+ "double_stream_modulation_txt.lin" : "double_stream_modulation_txt.linear" ,
2427+ }
2428+
2429+ for org_key , diff_key in extra_mappings .items ():
2430+ for lora_key in lora_keys :
2431+ original_key = f"{ org_key } .{ lora_key } .weight"
2432+ if original_key in original_state_dict :
2433+ converted_state_dict [f"{ diff_key } .{ lora_key } .weight" ] = original_state_dict .pop (original_key )
24002434
24012435 if len (original_state_dict ) > 0 :
24022436 raise ValueError (f"`original_state_dict` should be empty at this point but has { original_state_dict .keys ()= } ." )
0 commit comments