@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict
 
 
+def _convert_kohya_flux2_lora_to_diffusers(state_dict):
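+    # Convert a Kohya/sd-scripts Flux2 LoRA state dict ("lora_unet_*" keys with
+    # lora_down/lora_up/alpha tensors) to the diffusers layout ("transformer.*"
+    # keys with lora_A/lora_B), folding the alpha / rank scale into the weights.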
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / rank
+
+        scale_down = scale
+        scale_up = 1.0
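+        # Split `scale` between the two factors in powers of two: the loop keeps
+        # scale_down * scale_up equal to `scale` while balancing the magnitudes,
+        # so neither converted matrix is scaled by an extreme value.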
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / sd_lora_rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check whether the up weight is block-diagonal, i.e. each output split
+        # only reads its own rank slice
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
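+        # Dense case: every split shares the full down projection and only the up
+        # weight is sliced. Block-diagonal case: both projections split cleanly,
+        # so each output gets its own rank slice.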
+        if not is_sparse:
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
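+    # Example: "lora_unet_double_blocks_0_img_attn_qkv.lora_down.weight" becomes
+    # "transformer.transformer_blocks.0.attn.to_{q,k,v}.lora_A.weight".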
+    # Detect number of blocks from keys
+    num_double_layers = 0
+    num_single_layers = 0
+    for key in state_dict.keys():
+        if key.startswith("lora_unet_double_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_double_layers = max(num_double_layers, block_idx + 1)
+        elif key.startswith("lora_unet_single_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_single_layers = max(num_single_layers, block_idx + 1)
+
+    ait_sd = {}
+
+    for i in range(num_double_layers):
+        # Attention projections
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_out.0",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.to_q",
+                f"transformer.transformer_blocks.{i}.attn.to_k",
+                f"transformer.transformer_blocks.{i}.attn.to_v",
+            ],
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_add_out",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+            ],
+        )
+        # MLP layers (Flux2 uses ff.linear_in/linear_out)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff.linear_out",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
+        )
+
+    for i in range(num_single_layers):
+        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear1",
+            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+        )
+        # Single blocks: linear2 -> attn.to_out
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear2",
+            f"transformer.single_transformer_blocks.{i}.attn.to_out",
+        )
+
+    # Handle optional extra keys
+    extra_mappings = {
+        "lora_unet_img_in": "transformer.x_embedder",
+        "lora_unet_txt_in": "transformer.context_embedder",
+        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
+        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
+        "lora_unet_final_layer_linear": "transformer.proj_out",
+    }
+    for sds_key, ait_key in extra_mappings.items():
+        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
+
+    remaining_keys = list(state_dict.keys())
+    if remaining_keys:
+        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
+
+    return ait_sd
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.