@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23312331 temp_state_dict [new_key ] = v
23322332 original_state_dict = temp_state_dict
23332333
2334+ # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
2335+ # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
2336+ alpha_keys = [k for k in original_state_dict if k .endswith (".alpha" )]
2337+ for alpha_key in alpha_keys :
2338+ alpha = original_state_dict .pop (alpha_key ).item ()
2339+ module_path = alpha_key [: - len (".alpha" )]
2340+ lora_a_key = f"{ module_path } .lora_A.weight"
2341+ if lora_a_key in original_state_dict :
2342+ rank = original_state_dict [lora_a_key ].shape [0 ]
2343+ scale = alpha / rank
2344+ original_state_dict [lora_a_key ] = original_state_dict [lora_a_key ] * scale
2345+
23342346 num_double_layers = 0
23352347 num_single_layers = 0
23362348 for key in original_state_dict .keys ():
@@ -2443,6 +2455,152 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
24432455 return converted_state_dict
24442456
24452457
2458+ def _convert_non_diffusers_flux2_lokr_to_diffusers (state_dict ):
2459+ """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
2460+
2461+ Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
2462+ `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
2463+ QKV projections before injecting the adapter.
2464+ """
2465+ converted_state_dict = {}
2466+
2467+ prefix = "diffusion_model."
2468+ original_state_dict = {k [len (prefix ) :] if k .startswith (prefix ) else k : v for k , v in state_dict .items ()}
2469+
2470+ num_double_layers = 0
2471+ num_single_layers = 0
2472+ for key in original_state_dict :
2473+ if key .startswith ("single_blocks." ):
2474+ num_single_layers = max (num_single_layers , int (key .split ("." )[1 ]) + 1 )
2475+ elif key .startswith ("double_blocks." ):
2476+ num_double_layers = max (num_double_layers , int (key .split ("." )[1 ]) + 1 )
2477+
2478+ lokr_suffixes = ("lokr_w1" , "lokr_w1_a" , "lokr_w1_b" , "lokr_w2" , "lokr_w2_a" , "lokr_w2_b" , "lokr_t2" )
2479+
2480+ def _remap_lokr_module (bfl_path , diff_path ):
2481+ """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
2482+ alpha_key = f"{ bfl_path } .alpha"
2483+ alpha = original_state_dict .pop (alpha_key ).item () if alpha_key in original_state_dict else None
2484+
2485+ for suffix in lokr_suffixes :
2486+ src_key = f"{ bfl_path } .{ suffix } "
2487+ if src_key not in original_state_dict :
2488+ continue
2489+
2490+ weight = original_state_dict .pop (src_key )
2491+
2492+ # Bake alpha/rank scaling into the first w1 tensor encountered for this module.
2493+ # After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
2494+ if alpha is not None and suffix in ("lokr_w1" , "lokr_w1_a" ):
2495+ w2a_key = f"{ bfl_path } .lokr_w2_a"
2496+ w1a_key = f"{ bfl_path } .lokr_w1_a"
2497+ if w2a_key in original_state_dict :
2498+ r_eff = original_state_dict [w2a_key ].shape [1 ]
2499+ elif w1a_key in original_state_dict :
2500+ r_eff = original_state_dict [w1a_key ].shape [1 ]
2501+ else :
2502+ r_eff = alpha
2503+ scale = alpha / r_eff
2504+ weight = weight * scale
2505+ alpha = None # only bake once per module
2506+
2507+ converted_state_dict [f"{ diff_path } .{ suffix } " ] = weight
2508+
2509+ # --- Single blocks ---
2510+ for sl in range (num_single_layers ):
2511+ _remap_lokr_module (f"single_blocks.{ sl } .linear1" , f"single_transformer_blocks.{ sl } .attn.to_qkv_mlp_proj" )
2512+ _remap_lokr_module (f"single_blocks.{ sl } .linear2" , f"single_transformer_blocks.{ sl } .attn.to_out" )
2513+
2514+ # --- Double blocks ---
2515+ for dl in range (num_double_layers ):
2516+ tb = f"transformer_blocks.{ dl } "
2517+ db = f"double_blocks.{ dl } "
2518+
2519+ # QKV → fused to_qkv / to_added_qkv (model must be fused before injection)
2520+ _remap_lokr_module (f"{ db } .img_attn.qkv" , f"{ tb } .attn.to_qkv" )
2521+ _remap_lokr_module (f"{ db } .txt_attn.qkv" , f"{ tb } .attn.to_added_qkv" )
2522+
2523+ # Projections
2524+ _remap_lokr_module (f"{ db } .img_attn.proj" , f"{ tb } .attn.to_out.0" )
2525+ _remap_lokr_module (f"{ db } .txt_attn.proj" , f"{ tb } .attn.to_add_out" )
2526+
2527+ # MLPs
2528+ _remap_lokr_module (f"{ db } .img_mlp.0" , f"{ tb } .ff.linear_in" )
2529+ _remap_lokr_module (f"{ db } .img_mlp.2" , f"{ tb } .ff.linear_out" )
2530+ _remap_lokr_module (f"{ db } .txt_mlp.0" , f"{ tb } .ff_context.linear_in" )
2531+ _remap_lokr_module (f"{ db } .txt_mlp.2" , f"{ tb } .ff_context.linear_out" )
2532+
2533+ # --- Extra mappings (embedders, modulation, final layer) ---
2534+ extra_mappings = {
2535+ "img_in" : "x_embedder" ,
2536+ "txt_in" : "context_embedder" ,
2537+ "time_in.in_layer" : "time_guidance_embed.timestep_embedder.linear_1" ,
2538+ "time_in.out_layer" : "time_guidance_embed.timestep_embedder.linear_2" ,
2539+ "final_layer.linear" : "proj_out" ,
2540+ "final_layer.adaLN_modulation.1" : "norm_out.linear" ,
2541+ "single_stream_modulation.lin" : "single_stream_modulation.linear" ,
2542+ "double_stream_modulation_img.lin" : "double_stream_modulation_img.linear" ,
2543+ "double_stream_modulation_txt.lin" : "double_stream_modulation_txt.linear" ,
2544+ }
2545+ for bfl_key , diff_key in extra_mappings .items ():
2546+ _remap_lokr_module (bfl_key , diff_key )
2547+
2548+ if len (original_state_dict ) > 0 :
2549+ raise ValueError (f"`original_state_dict` should be empty at this point but has { original_state_dict .keys ()= } ." )
2550+
2551+ for key in list (converted_state_dict .keys ()):
2552+ converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
2553+
2554+ return converted_state_dict
2555+
2556+
2557+ def _refuse_flux2_lora_state_dict (state_dict ):
2558+ """Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys.
2559+
2560+ When the model's QKV projections are fused, incoming LoRA keys targeting separate
2561+ to_q/to_k/to_v must be re-fused to match. Inverse of the QKV split performed by
2562+ ``_convert_non_diffusers_flux2_lora_to_diffusers``.
2563+ """
2564+ converted = {}
2565+ remaining = dict (state_dict )
2566+
2567+ # Detect double block indices from keys
2568+ num_double_layers = 0
2569+ for key in remaining :
2570+ if ".transformer_blocks." in key :
2571+ parts = key .split ("." )
2572+ idx = parts .index ("transformer_blocks" ) + 1
2573+ if idx < len (parts ) and parts [idx ].isdigit ():
2574+ num_double_layers = max (num_double_layers , int (parts [idx ]) + 1 )
2575+
2576+ # Fuse Q/K/V for image stream and text stream
2577+ qkv_groups = [
2578+ (["to_q" , "to_k" , "to_v" ], "to_qkv" ),
2579+ (["add_q_proj" , "add_k_proj" , "add_v_proj" ], "to_added_qkv" ),
2580+ ]
2581+
2582+ for dl in range (num_double_layers ):
2583+ attn_prefix = f"transformer.transformer_blocks.{ dl } .attn"
2584+ for separate_keys , fused_name in qkv_groups :
2585+ for lora_key in ("lora_A" , "lora_B" ):
2586+ src_keys = [f"{ attn_prefix } .{ sk } .{ lora_key } .weight" for sk in separate_keys ]
2587+ if not all (k in remaining for k in src_keys ):
2588+ continue
2589+
2590+ weights = [remaining .pop (k ) for k in src_keys ]
2591+ dst_key = f"{ attn_prefix } .{ fused_name } .{ lora_key } .weight"
2592+ if lora_key == "lora_A" :
2593+ # lora_A was replicated during split - all three are identical, take the first
2594+ converted [dst_key ] = weights [0 ]
2595+ else :
2596+ # lora_B was chunked along dim=0 - concatenate back
2597+ converted [dst_key ] = torch .cat (weights , dim = 0 )
2598+
2599+ # Pass through all non-QKV keys unchanged
2600+ converted .update (remaining )
2601+ return converted
2602+
2603+
24462604def _convert_non_diffusers_z_image_lora_to_diffusers (state_dict ):
24472605 """
24482606 Convert non-diffusers ZImage LoRA state dict to diffusers format.
@@ -2600,14 +2758,14 @@ def get_alpha_scales(down_weight, alpha_key):
26002758
26012759 base = k [: - len (lora_dot_down_key )]
26022760
2603- # Skip combined "qkv" projection — individual to.q/k/v keys are also present.
2761+ # Skip combined "qkv" projection - individual to.q/k/v keys are also present.
26042762 if base .endswith (".qkv" ):
26052763 state_dict .pop (k )
26062764 state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
26072765 state_dict .pop (base + ".alpha" , None )
26082766 continue
26092767
2610- # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
2768+ # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection.
26112769 if re .search (r"\.out$" , base ) and ".to_out" not in base :
26122770 state_dict .pop (k )
26132771 state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
0 commit comments