Commit 153fcbc

fix klein lora loading. (#13313)
1 parent da6718f · commit 153fcbc

File tree

2 files changed: +193 -0 lines changed


src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 185 additions & 0 deletions
@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     return converted_state_dict
 
 
+def _convert_kohya_flux2_lora_to_diffusers(state_dict):
+    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+
+        # scale weight by alpha and dim
+        rank = down_weight.shape[0]
+        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
+        scale = alpha / rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
+        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
+
+    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
+        if sds_key + ".lora_down.weight" not in sds_sd:
+            return
+        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
+        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
+        sd_lora_rank = down_weight.shape[0]
+
+        default_alpha = torch.tensor(
+            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
+        )
+        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
+        scale = alpha / sd_lora_rank
+
+        scale_down = scale
+        scale_up = 1.0
+        while scale_down * 2 < scale_up:
+            scale_down *= 2
+            scale_up /= 2
+
+        down_weight = down_weight * scale_down
+        up_weight = up_weight * scale_up
+
+        num_splits = len(ait_keys)
+        if dims is None:
+            dims = [up_weight.shape[0] // num_splits] * num_splits
+        else:
+            assert sum(dims) == up_weight.shape[0]
+
+        # check if upweight is sparse
+        is_sparse = False
+        if sd_lora_rank % num_splits == 0:
+            ait_rank = sd_lora_rank // num_splits
+            is_sparse = True
+            i = 0
+            for j in range(len(dims)):
+                for k in range(len(dims)):
+                    if j == k:
+                        continue
+                    is_sparse = is_sparse and torch.all(
+                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
+                    )
+                i += dims[j]
+            if is_sparse:
+                logger.info(f"weight is sparse: {sds_key}")
+
+        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
+        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
+        if not is_sparse:
+            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
+            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
+        else:
+            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
+            i = 0
+            for j in range(len(dims)):
+                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
+                i += dims[j]
+
+    # Detect number of blocks from keys
+    num_double_layers = 0
+    num_single_layers = 0
+    for key in state_dict.keys():
+        if key.startswith("lora_unet_double_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_double_layers = max(num_double_layers, block_idx + 1)
+        elif key.startswith("lora_unet_single_blocks_"):
+            block_idx = int(key.split("_")[4])
+            num_single_layers = max(num_single_layers, block_idx + 1)
+
+    ait_sd = {}
+
+    for i in range(num_double_layers):
+        # Attention projections
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_out.0",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.to_q",
+                f"transformer.transformer_blocks.{i}.attn.to_k",
+                f"transformer.transformer_blocks.{i}.attn.to_v",
+            ],
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_proj",
+            f"transformer.transformer_blocks.{i}.attn.to_add_out",
+        )
+        _convert_to_ai_toolkit_cat(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
+            [
+                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
+                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
+            ],
+        )
+        # MLP layers (Flux2 uses ff.linear_in/linear_out)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_img_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff.linear_out",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_0",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
+        )
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_double_blocks_{i}_txt_mlp_2",
+            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
+        )
+
+    for i in range(num_single_layers):
+        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear1",
+            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
+        )
+        # Single blocks: linear2 -> attn.to_out
+        _convert_to_ai_toolkit(
+            state_dict,
+            ait_sd,
+            f"lora_unet_single_blocks_{i}_linear2",
+            f"transformer.single_transformer_blocks.{i}.attn.to_out",
+        )
+
+    # Handle optional extra keys
+    extra_mappings = {
+        "lora_unet_img_in": "transformer.x_embedder",
+        "lora_unet_txt_in": "transformer.context_embedder",
+        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
+        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
+        "lora_unet_final_layer_linear": "transformer.proj_out",
+    }
+    for sds_key, ait_key in extra_mappings.items():
+        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
+
+    remaining_keys = list(state_dict.keys())
+    if remaining_keys:
+        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
+
+    return ait_sd
+
+
 def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
     """
     Convert non-diffusers ZImage LoRA state dict to diffusers format.

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 0 deletions
@@ -43,6 +43,7 @@
     _convert_bfl_flux_control_lora_to_diffusers,
     _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
+    _convert_kohya_flux2_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
     _convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ def lora_state_dict(
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
+        is_kohya = any(".lora_down.weight" in k for k in state_dict)
+        if is_kohya:
+            state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
+            # Kohya already takes care of scaling the LoRA parameters with alpha.
+            out = (state_dict, metadata) if return_lora_metadata else state_dict
+            return out
+
         is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
         if is_peft_format:
             state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
