Skip to content

Commit 70ab180

Browse files
sayakpaulDN6
authored and committed
fix klein lora loading. (#13313)
1 parent 5cd2fa7 commit 70ab180

File tree

2 files changed

+193
-0
lines changed

2 files changed

+193
-0
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2440,6 +2440,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
24402440
return converted_state_dict
24412441

24422442

2443+
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
2444+
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
2445+
if sds_key + ".lora_down.weight" not in sds_sd:
2446+
return
2447+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
2448+
2449+
# scale weight by alpha and dim
2450+
rank = down_weight.shape[0]
2451+
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
2452+
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
2453+
scale = alpha / rank
2454+
2455+
scale_down = scale
2456+
scale_up = 1.0
2457+
while scale_down * 2 < scale_up:
2458+
scale_down *= 2
2459+
scale_up /= 2
2460+
2461+
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
2462+
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
2463+
2464+
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
2465+
if sds_key + ".lora_down.weight" not in sds_sd:
2466+
return
2467+
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
2468+
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
2469+
sd_lora_rank = down_weight.shape[0]
2470+
2471+
default_alpha = torch.tensor(
2472+
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
2473+
)
2474+
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
2475+
scale = alpha / sd_lora_rank
2476+
2477+
scale_down = scale
2478+
scale_up = 1.0
2479+
while scale_down * 2 < scale_up:
2480+
scale_down *= 2
2481+
scale_up /= 2
2482+
2483+
down_weight = down_weight * scale_down
2484+
up_weight = up_weight * scale_up
2485+
2486+
num_splits = len(ait_keys)
2487+
if dims is None:
2488+
dims = [up_weight.shape[0] // num_splits] * num_splits
2489+
else:
2490+
assert sum(dims) == up_weight.shape[0]
2491+
2492+
# check if upweight is sparse
2493+
is_sparse = False
2494+
if sd_lora_rank % num_splits == 0:
2495+
ait_rank = sd_lora_rank // num_splits
2496+
is_sparse = True
2497+
i = 0
2498+
for j in range(len(dims)):
2499+
for k in range(len(dims)):
2500+
if j == k:
2501+
continue
2502+
is_sparse = is_sparse and torch.all(
2503+
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
2504+
)
2505+
i += dims[j]
2506+
if is_sparse:
2507+
logger.info(f"weight is sparse: {sds_key}")
2508+
2509+
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
2510+
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
2511+
if not is_sparse:
2512+
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
2513+
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
2514+
else:
2515+
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
2516+
i = 0
2517+
for j in range(len(dims)):
2518+
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
2519+
i += dims[j]
2520+
2521+
# Detect number of blocks from keys
2522+
num_double_layers = 0
2523+
num_single_layers = 0
2524+
for key in state_dict.keys():
2525+
if key.startswith("lora_unet_double_blocks_"):
2526+
block_idx = int(key.split("_")[4])
2527+
num_double_layers = max(num_double_layers, block_idx + 1)
2528+
elif key.startswith("lora_unet_single_blocks_"):
2529+
block_idx = int(key.split("_")[4])
2530+
num_single_layers = max(num_single_layers, block_idx + 1)
2531+
2532+
ait_sd = {}
2533+
2534+
for i in range(num_double_layers):
2535+
# Attention projections
2536+
_convert_to_ai_toolkit(
2537+
state_dict,
2538+
ait_sd,
2539+
f"lora_unet_double_blocks_{i}_img_attn_proj",
2540+
f"transformer.transformer_blocks.{i}.attn.to_out.0",
2541+
)
2542+
_convert_to_ai_toolkit_cat(
2543+
state_dict,
2544+
ait_sd,
2545+
f"lora_unet_double_blocks_{i}_img_attn_qkv",
2546+
[
2547+
f"transformer.transformer_blocks.{i}.attn.to_q",
2548+
f"transformer.transformer_blocks.{i}.attn.to_k",
2549+
f"transformer.transformer_blocks.{i}.attn.to_v",
2550+
],
2551+
)
2552+
_convert_to_ai_toolkit(
2553+
state_dict,
2554+
ait_sd,
2555+
f"lora_unet_double_blocks_{i}_txt_attn_proj",
2556+
f"transformer.transformer_blocks.{i}.attn.to_add_out",
2557+
)
2558+
_convert_to_ai_toolkit_cat(
2559+
state_dict,
2560+
ait_sd,
2561+
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
2562+
[
2563+
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
2564+
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
2565+
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
2566+
],
2567+
)
2568+
# MLP layers (Flux2 uses ff.linear_in/linear_out)
2569+
_convert_to_ai_toolkit(
2570+
state_dict,
2571+
ait_sd,
2572+
f"lora_unet_double_blocks_{i}_img_mlp_0",
2573+
f"transformer.transformer_blocks.{i}.ff.linear_in",
2574+
)
2575+
_convert_to_ai_toolkit(
2576+
state_dict,
2577+
ait_sd,
2578+
f"lora_unet_double_blocks_{i}_img_mlp_2",
2579+
f"transformer.transformer_blocks.{i}.ff.linear_out",
2580+
)
2581+
_convert_to_ai_toolkit(
2582+
state_dict,
2583+
ait_sd,
2584+
f"lora_unet_double_blocks_{i}_txt_mlp_0",
2585+
f"transformer.transformer_blocks.{i}.ff_context.linear_in",
2586+
)
2587+
_convert_to_ai_toolkit(
2588+
state_dict,
2589+
ait_sd,
2590+
f"lora_unet_double_blocks_{i}_txt_mlp_2",
2591+
f"transformer.transformer_blocks.{i}.ff_context.linear_out",
2592+
)
2593+
2594+
for i in range(num_single_layers):
2595+
# Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
2596+
_convert_to_ai_toolkit(
2597+
state_dict,
2598+
ait_sd,
2599+
f"lora_unet_single_blocks_{i}_linear1",
2600+
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
2601+
)
2602+
# Single blocks: linear2 -> attn.to_out
2603+
_convert_to_ai_toolkit(
2604+
state_dict,
2605+
ait_sd,
2606+
f"lora_unet_single_blocks_{i}_linear2",
2607+
f"transformer.single_transformer_blocks.{i}.attn.to_out",
2608+
)
2609+
2610+
# Handle optional extra keys
2611+
extra_mappings = {
2612+
"lora_unet_img_in": "transformer.x_embedder",
2613+
"lora_unet_txt_in": "transformer.context_embedder",
2614+
"lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
2615+
"lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
2616+
"lora_unet_final_layer_linear": "transformer.proj_out",
2617+
}
2618+
for sds_key, ait_key in extra_mappings.items():
2619+
_convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
2620+
2621+
remaining_keys = list(state_dict.keys())
2622+
if remaining_keys:
2623+
logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
2624+
2625+
return ait_sd
2626+
2627+
24432628
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
24442629
"""
24452630
Convert non-diffusers ZImage LoRA state dict to diffusers format.

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
_convert_bfl_flux_control_lora_to_diffusers,
4444
_convert_fal_kontext_lora_to_diffusers,
4545
_convert_hunyuan_video_lora_to_diffusers,
46+
_convert_kohya_flux2_lora_to_diffusers,
4647
_convert_kohya_flux_lora_to_diffusers,
4748
_convert_musubi_wan_lora_to_diffusers,
4849
_convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ def lora_state_dict(
56735674
logger.warning(warn_msg)
56745675
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
56755676

5677+
is_kohya = any(".lora_down.weight" in k for k in state_dict)
5678+
if is_kohya:
5679+
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
5680+
# Kohya already takes care of scaling the LoRA parameters with alpha.
5681+
out = (state_dict, metadata) if return_lora_metadata else state_dict
5682+
return out
5683+
56765684
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
56775685
if is_peft_format:
56785686
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}

0 commit comments

Comments
 (0)