@@ -2331,6 +2331,18 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
23312331 temp_state_dict [new_key ] = v
23322332 original_state_dict = temp_state_dict
23332333
2334+ # Bake alpha/rank scaling into lora_A weights so .alpha keys are consumed.
2335+ # Matches the pattern used by _convert_kohya_flux_lora_to_diffusers for Flux1.
2336+ alpha_keys = [k for k in original_state_dict if k .endswith (".alpha" )]
2337+ for alpha_key in alpha_keys :
2338+ alpha = original_state_dict .pop (alpha_key ).item ()
2339+ module_path = alpha_key [: - len (".alpha" )]
2340+ lora_a_key = f"{ module_path } .lora_A.weight"
2341+ if lora_a_key in original_state_dict :
2342+ rank = original_state_dict [lora_a_key ].shape [0 ]
2343+ scale = alpha / rank
2344+ original_state_dict [lora_a_key ] = original_state_dict [lora_a_key ] * scale
2345+
23342346 num_double_layers = 0
23352347 num_single_layers = 0
23362348 for key in original_state_dict .keys ():
@@ -2628,6 +2640,105 @@ def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
26282640 return ait_sd
26292641
26302642
2643+ def _convert_non_diffusers_flux2_lokr_to_diffusers (state_dict ):
2644+ """Convert non-diffusers Flux2 LoKR state dict (kohya/LyCORIS format) to peft-compatible diffusers format.
2645+
2646+ Uses fuse-first QKV mapping: BFL fused `img_attn.qkv` maps to diffusers `attn.to_qkv` (created by
2647+ `fuse_projections()`), avoiding lossy Kronecker factor splitting. The caller must fuse the model's
2648+ QKV projections before injecting the adapter.
2649+ """
2650+ converted_state_dict = {}
2651+
2652+ prefix = "diffusion_model."
2653+ original_state_dict = {k [len (prefix ) :] if k .startswith (prefix ) else k : v for k , v in state_dict .items ()}
2654+
2655+ num_double_layers = 0
2656+ num_single_layers = 0
2657+ for key in original_state_dict :
2658+ if key .startswith ("single_blocks." ):
2659+ num_single_layers = max (num_single_layers , int (key .split ("." )[1 ]) + 1 )
2660+ elif key .startswith ("double_blocks." ):
2661+ num_double_layers = max (num_double_layers , int (key .split ("." )[1 ]) + 1 )
2662+
2663+ lokr_suffixes = ("lokr_w1" , "lokr_w1_a" , "lokr_w1_b" , "lokr_w2" , "lokr_w2_a" , "lokr_w2_b" , "lokr_t2" )
2664+
2665+ def _remap_lokr_module (bfl_path , diff_path ):
2666+ """Pop all lokr keys for a BFL module, bake alpha scaling, and store under diffusers path."""
2667+ alpha_key = f"{ bfl_path } .alpha"
2668+ alpha = original_state_dict .pop (alpha_key ).item () if alpha_key in original_state_dict else None
2669+
2670+ for suffix in lokr_suffixes :
2671+ src_key = f"{ bfl_path } .{ suffix } "
2672+ if src_key not in original_state_dict :
2673+ continue
2674+
2675+ weight = original_state_dict .pop (src_key )
2676+
2677+ # Bake alpha/rank scaling into the first w1 tensor encountered for this module.
2678+ # After baking, peft's config uses alpha=r so its runtime scaling is 1.0.
2679+ if alpha is not None and suffix in ("lokr_w1" , "lokr_w1_a" ):
2680+ w2a_key = f"{ bfl_path } .lokr_w2_a"
2681+ w1a_key = f"{ bfl_path } .lokr_w1_a"
2682+ if w2a_key in original_state_dict :
2683+ r_eff = original_state_dict [w2a_key ].shape [1 ]
2684+ elif w1a_key in original_state_dict :
2685+ r_eff = original_state_dict [w1a_key ].shape [1 ]
2686+ else :
2687+ r_eff = alpha
2688+ scale = alpha / r_eff
2689+ weight = weight * scale
2690+ alpha = None # only bake once per module
2691+
2692+ converted_state_dict [f"{ diff_path } .{ suffix } " ] = weight
2693+
2694+ # --- Single blocks ---
2695+ for sl in range (num_single_layers ):
2696+ _remap_lokr_module (f"single_blocks.{ sl } .linear1" , f"single_transformer_blocks.{ sl } .attn.to_qkv_mlp_proj" )
2697+ _remap_lokr_module (f"single_blocks.{ sl } .linear2" , f"single_transformer_blocks.{ sl } .attn.to_out" )
2698+
2699+ # --- Double blocks ---
2700+ for dl in range (num_double_layers ):
2701+ tb = f"transformer_blocks.{ dl } "
2702+ db = f"double_blocks.{ dl } "
2703+
2704+ # QKV -> fused to_qkv / to_added_qkv (model must be fused before injection)
2705+ _remap_lokr_module (f"{ db } .img_attn.qkv" , f"{ tb } .attn.to_qkv" )
2706+ _remap_lokr_module (f"{ db } .txt_attn.qkv" , f"{ tb } .attn.to_added_qkv" )
2707+
2708+ # Projections
2709+ _remap_lokr_module (f"{ db } .img_attn.proj" , f"{ tb } .attn.to_out.0" )
2710+ _remap_lokr_module (f"{ db } .txt_attn.proj" , f"{ tb } .attn.to_add_out" )
2711+
2712+ # MLPs
2713+ _remap_lokr_module (f"{ db } .img_mlp.0" , f"{ tb } .ff.linear_in" )
2714+ _remap_lokr_module (f"{ db } .img_mlp.2" , f"{ tb } .ff.linear_out" )
2715+ _remap_lokr_module (f"{ db } .txt_mlp.0" , f"{ tb } .ff_context.linear_in" )
2716+ _remap_lokr_module (f"{ db } .txt_mlp.2" , f"{ tb } .ff_context.linear_out" )
2717+
2718+ # --- Extra mappings (embedders, modulation, final layer) ---
2719+ extra_mappings = {
2720+ "img_in" : "x_embedder" ,
2721+ "txt_in" : "context_embedder" ,
2722+ "time_in.in_layer" : "time_guidance_embed.timestep_embedder.linear_1" ,
2723+ "time_in.out_layer" : "time_guidance_embed.timestep_embedder.linear_2" ,
2724+ "final_layer.linear" : "proj_out" ,
2725+ "final_layer.adaLN_modulation.1" : "norm_out.linear" ,
2726+ "single_stream_modulation.lin" : "single_stream_modulation.linear" ,
2727+ "double_stream_modulation_img.lin" : "double_stream_modulation_img.linear" ,
2728+ "double_stream_modulation_txt.lin" : "double_stream_modulation_txt.linear" ,
2729+ }
2730+ for bfl_key , diff_key in extra_mappings .items ():
2731+ _remap_lokr_module (bfl_key , diff_key )
2732+
2733+ if len (original_state_dict ) > 0 :
2734+ raise ValueError (f"`original_state_dict` should be empty at this point but has { original_state_dict .keys ()= } ." )
2735+
2736+ for key in list (converted_state_dict .keys ()):
2737+ converted_state_dict [f"transformer.{ key } " ] = converted_state_dict .pop (key )
2738+
2739+ return converted_state_dict
2740+
2741+
26312742def _convert_non_diffusers_z_image_lora_to_diffusers (state_dict ):
26322743 """
26332744 Convert non-diffusers ZImage LoRA state dict to diffusers format.
@@ -2785,14 +2896,14 @@ def get_alpha_scales(down_weight, alpha_key):
27852896
27862897 base = k [: - len (lora_dot_down_key )]
27872898
2788- # Skip combined "qkv" projection — individual to.q/k/v keys are also present.
2899+ # Skip combined "qkv" projection - individual to.q/k/v keys are also present.
27892900 if base .endswith (".qkv" ):
27902901 state_dict .pop (k )
27912902 state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
27922903 state_dict .pop (base + ".alpha" , None )
27932904 continue
27942905
2795- # Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
2906+ # Skip bare "out.lora.*" - "to_out.0.lora.*" covers the same projection.
27962907 if re .search (r"\.out$" , base ) and ".to_out" not in base :
27972908 state_dict .pop (k )
27982909 state_dict .pop (k .replace (lora_dot_down_key , lora_dot_up_key ), None )
0 commit comments