@@ -715,13 +715,25 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
715715 new_state_dict [f"{ block } .attn.to_k.bias" ] = old_state_dict .pop (f"{ block } .to_k.bias" )
716716 new_state_dict [f"{ block } .attn.to_v.bias" ] = old_state_dict .pop (f"{ block } .to_v.bias" )
717717
718- # old version did not have a projection so set these to the identity
719- new_state_dict [f"{ block } .attn.out_proj.weight" ] = torch .eye (
720- new_state_dict [f"{ block } .attn.out_proj.weight" ].shape [0 ]
721- )
722- new_state_dict [f"{ block } .attn.out_proj.bias" ] = torch .zeros (
723- new_state_dict [f"{ block } .attn.out_proj.bias" ].shape
724- )
718+ out_w = f"{ block } .attn.out_proj.weight"
719+ out_b = f"{ block } .attn.out_proj.bias"
720+ proj_w = f"{ block } .proj_attn.weight"
721+ proj_b = f"{ block } .proj_attn.bias"
722+
723+ if out_w in new_state_dict :
724+ if proj_w in old_state_dict :
725+ new_state_dict [out_w ] = old_state_dict .pop (proj_w )
726+ if proj_b in old_state_dict :
727+ new_state_dict [out_b ] = old_state_dict .pop (proj_b )
728+ else :
729+ # weights pre-date proj_attn: initialise to identity / zero
730+ new_state_dict [out_w ] = torch .eye (new_state_dict [out_w ].shape [0 ])
731+ new_state_dict [out_b ] = torch .zeros (new_state_dict [out_b ].shape )
732+ elif proj_w in old_state_dict :
733+ # new model has no out_proj at all – discard the legacy keys so they
734+ # don't surface as "unexpected keys" during load_state_dict
735+ old_state_dict .pop (proj_w )
736+ old_state_dict .pop (proj_b )
725737
726738 # fix the upsample conv blocks which were renamed postconv
727739 for k in new_state_dict :
0 commit comments