Skip to content

Commit f85f109

Browse files
committed
fix(autoencoderkl): handle proj_attn→out_proj key mapping in load_old_state_dict
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent daaedaa commit f85f109

File tree

1 file changed

+19
-7
lines changed

1 file changed: +19 additions, -7 deletions

monai/networks/nets/autoencoderkl.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -715,13 +715,25 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715715
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
716716
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
717717

718-
# old version did not have a projection so set these to the identity
719-
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
720-
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
721-
)
722-
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
723-
new_state_dict[f"{block}.attn.out_proj.bias"].shape
724-
)
718+
out_w = f"{block}.attn.out_proj.weight"
719+
out_b = f"{block}.attn.out_proj.bias"
720+
proj_w = f"{block}.proj_attn.weight"
721+
proj_b = f"{block}.proj_attn.bias"
722+
723+
if out_w in new_state_dict:
724+
if proj_w in old_state_dict:
725+
new_state_dict[out_w] = old_state_dict.pop(proj_w)
726+
if proj_b in old_state_dict:
727+
new_state_dict[out_b] = old_state_dict.pop(proj_b)
728+
else:
729+
# weights pre-date proj_attn: initialise to identity / zero
730+
new_state_dict[out_w] = torch.eye(new_state_dict[out_w].shape[0])
731+
new_state_dict[out_b] = torch.zeros(new_state_dict[out_b].shape)
732+
elif proj_w in old_state_dict:
733+
# new model has no out_proj at all – discard the legacy keys so they
734+
# don't surface as "unexpected keys" during load_state_dict
735+
old_state_dict.pop(proj_w)
736+
old_state_dict.pop(proj_b)
725737

726738
# fix the upsample conv blocks which were renamed postconv
727739
for k in new_state_dict:

0 commit comments