@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
         temp_state_dict[new_key] = v
     original_state_dict = temp_state_dict

+    # Bake alpha/rank scaling into lora_A weights (matches Flux1 kohya alpha handling).
+    # Without this, .alpha keys remain unconsumed and trigger a validation error.
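+    # For example, alpha=16 with rank=32 bakes a scale of 0.5 into lora_A, so no separate .alpha entry is needed.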
+    alpha_keys = [k for k in original_state_dict if k.endswith(".alpha")]
+    for alpha_key in alpha_keys:
+        alpha = original_state_dict.pop(alpha_key).item()
+        module_path = alpha_key[: -len(".alpha")]
+        lora_a_key = f"{module_path}.lora_A.weight"
+        if lora_a_key in original_state_dict:
+            rank = original_state_dict[lora_a_key].shape[0]
+            scale = alpha / rank
+            original_state_dict[lora_a_key] = original_state_dict[lora_a_key] * scale
+
     num_double_layers = 0
     num_single_layers = 0
     for key in original_state_dict.keys():
@@ -2443,6 +2455,166 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict


+def _convert_non_diffusers_flux2_lokr_to_diffusers(state_dict):
+    """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
+
+    Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
+    `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
+    QKV projections before injecting the adapter.
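+
+    For example, `diffusion_model.double_blocks.0.img_attn.qkv.lokr_w1` in the source checkpoint ends up as
+    `transformer.transformer_blocks.0.attn.to_qkv.lokr_w1` in the returned state dict.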
+    """
+    converted_state_dict = {}
+
+    prefix = "diffusion_model."
+    original_state_dict = {k[len(prefix) :] if k.startswith(prefix) else k: v for k, v in state_dict.items()}
+
+    # Log key patterns for debugging
+    suffixes = set()
+    for k in original_state_dict:
+        for s in ("lokr_w1_a", "lokr_w1_b", "lokr_w2_a", "lokr_w2_b", "lokr_w1", "lokr_w2", "lokr_t2", "alpha"):
+            if k.endswith(f".{s}"):
+                suffixes.add(s)
+                break
+    logger.warning(
+        f"[LoKR DEBUG] LoKR converter: {len(original_state_dict)} keys, weight types: {sorted(suffixes)}"
+    )
+
+    num_double_layers = 0
+    num_single_layers = 0
+    for key in original_state_dict:
+        if key.startswith("single_blocks."):
+            num_single_layers = max(num_single_layers, int(key.split(".")[1]) + 1)
+        elif key.startswith("double_blocks."):
+            num_double_layers = max(num_double_layers, int(key.split(".")[1]) + 1)
+
+    lokr_suffixes = ("lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2")
+
+    def _remap_lokr_module(bfl_path, diff_path):
+        """Pop all lokr keys for a BFL module, bake alpha scaling, and store under the diffusers path."""
+        alpha_key = f"{bfl_path}.alpha"
+        alpha = original_state_dict.pop(alpha_key).item() if alpha_key in original_state_dict else None
+
+        for suffix in lokr_suffixes:
+            src_key = f"{bfl_path}.{suffix}"
+            if src_key not in original_state_dict:
+                continue
+
+            weight = original_state_dict.pop(src_key)
+
+            # Bake alpha/rank scaling into the first w1 tensor encountered for this module.
+            # After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
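+            # For example, alpha=8.0 with r_eff=16 scales that first w1 factor by 0.5; the remaining
+            # factors for the module are stored unscaled.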
+            if alpha is not None and suffix in ("lokr_w1", "lokr_w1_a"):
+                w2a_key = f"{bfl_path}.lokr_w2_a"
+                w1a_key = f"{bfl_path}.lokr_w1_a"
+                if w2a_key in original_state_dict:
+                    r_eff = original_state_dict[w2a_key].shape[1]
+                elif w1a_key in original_state_dict:
+                    r_eff = original_state_dict[w1a_key].shape[1]
+                else:
+                    r_eff = alpha
+                scale = alpha / r_eff
+                logger.warning(
+                    f"[LoKR DEBUG] Alpha bake: {bfl_path} alpha={alpha:.1f} r={r_eff} scale={scale:.4f}"
+                )
+                weight = weight * scale
+                alpha = None  # only bake once per module
+
+            converted_state_dict[f"{diff_path}.{suffix}"] = weight
+
+    # --- Single blocks ---
+    for sl in range(num_single_layers):
+        _remap_lokr_module(f"single_blocks.{sl}.linear1", f"single_transformer_blocks.{sl}.attn.to_qkv_mlp_proj")
+        _remap_lokr_module(f"single_blocks.{sl}.linear2", f"single_transformer_blocks.{sl}.attn.to_out")
+
+    # --- Double blocks ---
+    for dl in range(num_double_layers):
+        tb = f"transformer_blocks.{dl}"
+        db = f"double_blocks.{dl}"
+
+        # QKV → fused to_qkv / to_added_qkv (model must be fused before injection)
+        _remap_lokr_module(f"{db}.img_attn.qkv", f"{tb}.attn.to_qkv")
+        _remap_lokr_module(f"{db}.txt_attn.qkv", f"{tb}.attn.to_added_qkv")
+
+        # Projections
+        _remap_lokr_module(f"{db}.img_attn.proj", f"{tb}.attn.to_out.0")
+        _remap_lokr_module(f"{db}.txt_attn.proj", f"{tb}.attn.to_add_out")
+
+        # MLPs
+        _remap_lokr_module(f"{db}.img_mlp.0", f"{tb}.ff.linear_in")
+        _remap_lokr_module(f"{db}.img_mlp.2", f"{tb}.ff.linear_out")
+        _remap_lokr_module(f"{db}.txt_mlp.0", f"{tb}.ff_context.linear_in")
+        _remap_lokr_module(f"{db}.txt_mlp.2", f"{tb}.ff_context.linear_out")
+
+    # --- Extra mappings (embedders, modulation, final layer) ---
+    extra_mappings = {
+        "img_in": "x_embedder",
+        "txt_in": "context_embedder",
+        "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+        "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+        "final_layer.linear": "proj_out",
+        "final_layer.adaLN_modulation.1": "norm_out.linear",
+        "single_stream_modulation.lin": "single_stream_modulation.linear",
+        "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+        "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+    }
+    for bfl_key, diff_key in extra_mappings.items():
+        _remap_lokr_module(bfl_key, diff_key)
+
+    if len(original_state_dict) > 0:
+        raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
+
+    for key in list(converted_state_dict.keys()):
+        converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
+
+    return converted_state_dict
+
+
+def _refuse_flux2_lora_state_dict(state_dict):
+    """Re-fuse separate Q/K/V LoRA keys into fused to_qkv/to_added_qkv keys.
+
+    When the model's QKV projections are fused (e.g., from a prior LoKR load), incoming LoRA keys
+    that target separate to_q/to_k/to_v must be re-fused to match. This is the exact inverse of the
+    QKV split in ``_convert_non_diffusers_flux2_lora_to_diffusers``.
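+
+    For example, the three `to_q`/`to_k`/`to_v` `lora_B.weight` tensors of a double block are concatenated along
+    dim 0 into a single `to_qkv.lora_B.weight`, while their identical `lora_A.weight` copies collapse to one tensor.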
+    """
+    converted = {}
+    remaining = dict(state_dict)
+
+    # Detect double block indices from keys
+    num_double_layers = 0
+    for key in remaining:
+        if ".transformer_blocks." in key:
+            parts = key.split(".")
+            idx = parts.index("transformer_blocks") + 1
+            if idx < len(parts) and parts[idx].isdigit():
+                num_double_layers = max(num_double_layers, int(parts[idx]) + 1)
+
+    # Fuse Q/K/V for image stream and text stream
+    qkv_groups = [
+        (["to_q", "to_k", "to_v"], "to_qkv"),
+        (["add_q_proj", "add_k_proj", "add_v_proj"], "to_added_qkv"),
+    ]
+
+    for dl in range(num_double_layers):
+        attn_prefix = f"transformer.transformer_blocks.{dl}.attn"
+        for separate_keys, fused_name in qkv_groups:
+            for lora_key in ("lora_A", "lora_B"):
+                src_keys = [f"{attn_prefix}.{sk}.{lora_key}.weight" for sk in separate_keys]
+                if not all(k in remaining for k in src_keys):
+                    continue
+
+                weights = [remaining.pop(k) for k in src_keys]
+                dst_key = f"{attn_prefix}.{fused_name}.{lora_key}.weight"
+                if lora_key == "lora_A":
+                    # lora_A was replicated during the split; all three copies are identical, so take the first
+                    converted[dst_key] = weights[0]
+                else:
+                    # lora_B was chunked along dim=0, so concatenate the chunks back
+                    converted[dst_key] = torch.cat(weights, dim=0)
+
+    # Pass through all non-QKV keys unchanged
+    converted.update(remaining)
+    return converted
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.