Skip to content

Commit 8f2052f

Browse files
committed
up
1 parent 49f0b1c commit 8f2052f

2 files changed

Lines changed: 4 additions & 5 deletions

File tree

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,13 +1978,11 @@ def get_alpha_scales(down_weight, alpha_key):
19781978
)
19791979

19801980
if any("head.head" in k for k in original_state_dict):
1981-
if any(
1982-
f"head.head.{lora_down_key}.weight" in k and f"head.head.{lora_up_key}.weight" in k
1983-
for k in original_state_dict
1984-
):
1981+
if any(f"head.head.{lora_down_key}.weight" in k for k in state_dict):
19851982
converted_state_dict["proj_out.lora_A.weight"] = original_state_dict.pop(
19861983
f"head.head.{lora_down_key}.weight"
19871984
)
1985+
if any(f"head.head.{lora_up_key}.weight" in k for k in state_dict):
19881986
converted_state_dict["proj_out.lora_B.weight"] = original_state_dict.pop(
19891987
f"head.head.{lora_up_key}.weight"
19901988
)
@@ -1995,7 +1993,7 @@ def get_alpha_scales(down_weight, alpha_key):
19951993
# This is my (sayakpaul) assumption that this particular key belongs to the down matrix.
19961994
# Since for this particular LoRA, we don't have the corresponding up matrix, I will use
19971995
# an identity.
1998-
if any("head.head.diff" in k for k in state_dict):
1996+
if any("head.head" in k and k.endswith(".diff") for k in state_dict):
19991997
if f"head.head.{lora_down_key}.weight" in state_dict:
20001998
logger.info(
20011999
f"The state dict seems to be have both `head.head.diff` and `head.head.{lora_down_key}.weight` keys, which is unexpected."

src/diffusers/loaders/lora_pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4074,6 +4074,7 @@ def load_lora_weights(
40744074
raise ValueError("Invalid LoRA checkpoint.")
40754075

40764076
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
4077+
print(f"{load_into_transformer_2=}")
40774078
if load_into_transformer_2:
40784079
if not hasattr(self, "transformer_2"):
40794080
raise AttributeError(

0 commit comments

Comments
 (0)