
Commit 4d00980

[lora] fix non-diffusers lora key handling for flux2 (#13119)
1 parent 5bf248d commit 4d00980
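
For context, this converter runs when a LoRA in the original (non-diffusers) key convention is loaded into a Flux2 pipeline. A hedged usage sketch; the pipeline class, model ID, and LoRA repo below are illustrative assumptions, not taken from this commit:

import torch
from diffusers import Flux2Pipeline  # assumed Flux2 pipeline class

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
)
# A LoRA whose keys carry the "diffusion_model." prefix and use
# lora_down/lora_up naming now converts instead of raising KeyError.
pipe.load_lora_weights("some-user/flux2-lora", weight_name="lora.safetensors")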

1 file changed: +44 -10 lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 44 additions & 10 deletions
@@ -2321,6 +2321,14 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
     prefix = "diffusion_model."
     original_state_dict = {k[len(prefix) :]: v for k, v in state_dict.items()}
 
+    has_lora_down_up = any("lora_down" in k or "lora_up" in k for k in original_state_dict.keys())
+    if has_lora_down_up:
+        temp_state_dict = {}
+        for k, v in original_state_dict.items():
+            new_key = k.replace("lora_down", "lora_A").replace("lora_up", "lora_B")
+            temp_state_dict[new_key] = v
+        original_state_dict = temp_state_dict
+
     num_double_layers = 0
     num_single_layers = 0
     for key in original_state_dict.keys():
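
The new pass normalizes Kohya-style lora_down/lora_up names to the PEFT-style lora_A/lora_B names the rest of the converter expects. A minimal, self-contained sketch of that renaming on a toy state dict (illustrative keys and shapes only):

import torch

state_dict = {
    "double_blocks.0.img_mlp.0.lora_down.weight": torch.zeros(4, 64),
    "double_blocks.0.img_mlp.0.lora_up.weight": torch.zeros(64, 4),
}
# Same substitution the commit applies before any block-wise conversion.
normalized = {
    k.replace("lora_down", "lora_A").replace("lora_up", "lora_B"): v
    for k, v in state_dict.items()
}
print(sorted(normalized))
# ['double_blocks.0.img_mlp.0.lora_A.weight', 'double_blocks.0.img_mlp.0.lora_B.weight']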
@@ -2337,13 +2345,15 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
         attn_prefix = f"single_transformer_blocks.{sl}.attn"
 
         for lora_key in lora_keys:
-            converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
-                f"{single_block_prefix}.linear1.{lora_key}.weight"
-            )
+            linear1_key = f"{single_block_prefix}.linear1.{lora_key}.weight"
+            if linear1_key in original_state_dict:
+                converted_state_dict[f"{attn_prefix}.to_qkv_mlp_proj.{lora_key}.weight"] = original_state_dict.pop(
+                    linear1_key
+                )
 
-            converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(
-                f"{single_block_prefix}.linear2.{lora_key}.weight"
-            )
+            linear2_key = f"{single_block_prefix}.linear2.{lora_key}.weight"
+            if linear2_key in original_state_dict:
+                converted_state_dict[f"{attn_prefix}.to_out.{lora_key}.weight"] = original_state_dict.pop(linear2_key)
 
     for dl in range(num_double_layers):
         transformer_block_prefix = f"transformer_blocks.{dl}"
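
The guard-before-pop pattern introduced here makes the converter tolerant of partial LoRAs that only train some modules: an absent key is skipped rather than raising KeyError. The same idea as a tiny hypothetical helper (not part of the commit):

def pop_if_present(src: dict, dst: dict, src_key: str, dst_key: str) -> None:
    # Move a tensor from src to dst under its diffusers name, skipping absent keys.
    if src_key in src:
        dst[dst_key] = src.pop(src_key)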
@@ -2352,6 +2362,10 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
         for attn_type in attn_types:
             attn_prefix = f"{transformer_block_prefix}.attn"
             qkv_key = f"double_blocks.{dl}.{attn_type}.qkv.{lora_key}.weight"
+
+            if qkv_key not in original_state_dict:
+                continue
+
             fused_qkv_weight = original_state_dict.pop(qkv_key)
 
             if lora_key == "lora_A":
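
The splitting code after this guard is unchanged and not shown in the hunk. For orientation, a hedged sketch of how a fused QKV LoRA pair is commonly decomposed (assuming equal q/k/v widths; an illustration, not the exact code in this file): the lora_A (down) matrix acts on the shared input and can be reused for q, k, and v, while the lora_B (up) matrix is chunked row-wise into per-projection pieces.

import torch

rank, inner_dim = 4, 96
fused_lora_A = torch.zeros(rank, inner_dim)      # shared down-projection
fused_lora_B = torch.zeros(3 * inner_dim, rank)  # stacked q/k/v up-projection

# lora_A is common to q/k/v; lora_B's output rows split into the three projections.
to_q_B, to_k_B, to_v_B = torch.chunk(fused_lora_B, 3, dim=0)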
@@ -2383,8 +2397,9 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
         for org_proj, diff_proj in proj_mappings:
             for lora_key in lora_keys:
                 original_key = f"double_blocks.{dl}.{org_proj}.{lora_key}.weight"
-                diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
-                converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+                if original_key in original_state_dict:
+                    diffusers_key = f"{transformer_block_prefix}.{diff_proj}.{lora_key}.weight"
+                    converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
 
         mlp_mappings = [
             ("img_mlp.0", "ff.linear_in"),
@@ -2395,8 +2410,27 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
         for org_mlp, diff_mlp in mlp_mappings:
             for lora_key in lora_keys:
                 original_key = f"double_blocks.{dl}.{org_mlp}.{lora_key}.weight"
-                diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
-                converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+                if original_key in original_state_dict:
+                    diffusers_key = f"{transformer_block_prefix}.{diff_mlp}.{lora_key}.weight"
+                    converted_state_dict[diffusers_key] = original_state_dict.pop(original_key)
+
+    extra_mappings = {
+        "img_in": "x_embedder",
+        "txt_in": "context_embedder",
+        "time_in.in_layer": "time_guidance_embed.timestep_embedder.linear_1",
+        "time_in.out_layer": "time_guidance_embed.timestep_embedder.linear_2",
+        "final_layer.linear": "proj_out",
+        "final_layer.adaLN_modulation.1": "norm_out.linear",
+        "single_stream_modulation.lin": "single_stream_modulation.linear",
+        "double_stream_modulation_img.lin": "double_stream_modulation_img.linear",
+        "double_stream_modulation_txt.lin": "double_stream_modulation_txt.linear",
+    }
+
+    for org_key, diff_key in extra_mappings.items():
+        for lora_key in lora_keys:
+            original_key = f"{org_key}.{lora_key}.weight"
+            if original_key in original_state_dict:
+                converted_state_dict[f"{diff_key}.{lora_key}.weight"] = original_state_dict.pop(original_key)
 
     if len(original_state_dict) > 0:
         raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
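
The new extra_mappings table covers the non-block weights (embedders, modulation layers, final projection) that previously fell through to the leftover-keys error above. A toy run of one mapping to show the effect (placeholder strings instead of tensors):

extra_mappings = {"final_layer.linear": "proj_out"}
original = {
    "final_layer.linear.lora_A.weight": "tensor-A",
    "final_layer.linear.lora_B.weight": "tensor-B",
}
converted = {}
for org_key, diff_key in extra_mappings.items():
    for lora_key in ("lora_A", "lora_B"):
        original_key = f"{org_key}.{lora_key}.weight"
        if original_key in original:
            converted[f"{diff_key}.{lora_key}.weight"] = original.pop(original_key)

print(converted)  # {'proj_out.lora_A.weight': 'tensor-A', 'proj_out.lora_B.weight': 'tensor-B'}
# `original` is now empty, so the final leftover-keys check passes.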
