@@ -2440,6 +2440,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict


+def _convert_kohya_flux2_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / rank
+
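+        # split the combined scale between the A and B factors so neither gets
+        # an extreme multiplier; e.g. scale = 0.25 yields scale_down = 0.5 and
+        # scale_up = 0.5 (the product always equals the original scale)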
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        # scale weight by alpha and dim (use .item() so scale is a Python
+        # float, matching _convert_to_ai_toolkit above)
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / sd_lora_rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
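+        # kohya may train the fused qkv projection with a block-diagonal up
+        # weight: the rows for q only read from the first ait_rank columns,
+        # the rows for k from the next ait_rank columns, and so on; in that
+        # case each split can keep its own rank-ait_rank pair of factors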
+        # check if up_weight is sparse (block diagonal)
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            # dense case: every split shares the full-rank down weight and the
+            # up weight is simply split row-wise
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            # sparse case: split the down weight as well and keep only the
+            # non-zero diagonal block of the up weight for each split
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    # Detect number of blocks from keys
+    num_double_layers = 0
+    num_single_layers = 0
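+    # e.g. "lora_unet_double_blocks_7_img_attn_qkv.lora_down.weight" carries
+    # the block index at position 4 after splitting on "_"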
+    for key in state_dict.keys():
+        if key.startswith("lora_unet_double_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_double_layers = max(num_double_layers, block_idx + 1)
+        elif key.startswith("lora_unet_single_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_single_layers = max(num_single_layers, block_idx + 1)
+
+    ait_sd = {}
+
+    for i in range(num_double_layers):
+        # Attention projections
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_out.0",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.to_q",
+                f"transformer.transformer_blocks.{i}.attn.to_k",
+                f"transformer.transformer_blocks.{i}.attn.to_v",
+            ],
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_add_out",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+            ],
+        )
+        # MLP layers (Flux2 uses ff.linear_in/linear_out)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff.linear_out",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
+        )
+
+    for i in range(num_single_layers):
+        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear1",
+            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+        )
+        # Single blocks: linear2 -> attn.to_out
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear2",
+            f"transformer.single_transformer_blocks.{i}.attn.to_out",
+        )
+
+    # Handle optional extra keys
+    extra_mappings = {
+        "lora_unet_img_in": "transformer.x_embedder",
+        "lora_unet_txt_in": "transformer.context_embedder",
+        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
+        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
+        "lora_unet_final_layer_linear": "transformer.proj_out",
+    }
+    for sds_key, ait_key in extra_mappings.items():
+        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
+
+    remaining_keys = list(state_dict.keys())
+    if remaining_keys:
+        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
+
+    return ait_sd
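+
+
+# A minimal dispatch sketch (hypothetical, not part of this diff): loaders
+# typically detect the kohya format by its key prefix before converting, e.g.
+#
+#     if any(k.startswith("lora_unet_") for k in state_dict):
+#         state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)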
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.