Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions monai/networks/nets/autoencoderkl.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,13 +715,25 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")

# old version did not have a projection so set these to the identity
new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
)
new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
new_state_dict[f"{block}.attn.out_proj.bias"].shape
)
out_w = f"{block}.attn.out_proj.weight"
out_b = f"{block}.attn.out_proj.bias"
proj_w = f"{block}.proj_attn.weight"
proj_b = f"{block}.proj_attn.bias"

if out_w in new_state_dict:
if proj_w in old_state_dict:
new_state_dict[out_w] = old_state_dict.pop(proj_w)
if proj_b in old_state_dict:
new_state_dict[out_b] = old_state_dict.pop(proj_b)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
else:
# weights pre-date proj_attn: initialise to identity / zero
new_state_dict[out_w] = torch.eye(new_state_dict[out_w].shape[0])
new_state_dict[out_b] = torch.zeros(new_state_dict[out_b].shape)
Comment thread
ytl0623 marked this conversation as resolved.
Outdated
elif proj_w in old_state_dict:
# new model has no out_proj at all – discard the legacy keys so they
# don't surface as "unexpected keys" during load_state_dict
old_state_dict.pop(proj_w)
old_state_dict.pop(proj_b)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

# fix the upsample conv blocks which were renamed postconv
for k in new_state_dict:
Expand Down
Loading