Skip to content

Commit 5db2a70

Browse files
committed
add unit test
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent f85f109 commit 5db2a70

File tree

2 files changed

+99
-6
lines changed

2 files changed

+99
-6
lines changed

monai/networks/nets/autoencoderkl.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
680680
681681
Args:
682682
old_state_dict: state dict from the old AutoencoderKL model.
683+
verbose: if True, print diagnostic information about key mismatches.
683684
"""
684685

685686
new_state_dict = self.state_dict()
@@ -725,15 +726,17 @@ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
725726
new_state_dict[out_w] = old_state_dict.pop(proj_w)
726727
if proj_b in old_state_dict:
727728
new_state_dict[out_b] = old_state_dict.pop(proj_b)
728-
else:
729-
# weights pre-date proj_attn: initialise to identity / zero
730-
new_state_dict[out_w] = torch.eye(new_state_dict[out_w].shape[0])
731-
new_state_dict[out_b] = torch.zeros(new_state_dict[out_b].shape)
729+
else:
730+
new_state_dict[out_b] = torch.zeros(
731+
new_state_dict[out_b].shape,
732+
dtype=new_state_dict[out_b].dtype,
733+
device=new_state_dict[out_b].device,
734+
)
732735
elif proj_w in old_state_dict:
733-
# new model has no out_proj at all discard the legacy keys so they
736+
# new model has no out_proj at all - discard the legacy keys so they
734737
# don't surface as "unexpected keys" during load_state_dict
735738
old_state_dict.pop(proj_w)
736-
old_state_dict.pop(proj_b)
739+
old_state_dict.pop(proj_b, None)
737740

738741
# fix the upsample conv blocks which were renamed postconv
739742
for k in new_state_dict:

tests/networks/nets/test_autoencoderkl.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,96 @@ def test_compatibility_with_monai_generative(self):
327327

328328
net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)
329329

330+
@staticmethod
331+
def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict:
332+
old_sd: dict = {}
333+
for k, v in new_sd.items():
334+
if ".attn.to_q." in k:
335+
old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone()
336+
elif ".attn.to_k." in k:
337+
old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone()
338+
elif ".attn.to_v." in k:
339+
old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone()
340+
elif ".attn.out_proj." in k:
341+
if include_proj_attn:
342+
old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone()
343+
elif "postconv" in k:
344+
old_sd[k.replace("postconv", "conv")] = v.clone()
345+
else:
346+
old_sd[k] = v.clone()
347+
return old_sd
348+
349+
@skipUnless(has_einops, "Requires einops")
350+
def test_load_old_state_dict_proj_attn_copied_to_out_proj(self):
351+
params = {**self._MIGRATION_PARAMS, "include_fc": True}
352+
src = AutoencoderKL(**params).to(device)
353+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True)
354+
355+
# record the tensor values that were stored under proj_attn
356+
expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k}
357+
self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict – check model config")
358+
359+
dst = AutoencoderKL(**params).to(device)
360+
dst.load_old_state_dict(old_sd)
361+
362+
for new_key, expected_val in expected.items():
363+
torch.testing.assert_close(
364+
dst.state_dict()[new_key],
365+
expected_val.to(device),
366+
msg=f"Weight mismatch for {new_key}",
367+
)
368+
369+
@skipUnless(has_einops, "Requires einops")
370+
def test_load_old_state_dict_missing_proj_attn_initialises_identity(self):
371+
params = {**self._MIGRATION_PARAMS, "include_fc": True}
372+
src = AutoencoderKL(**params).to(device)
373+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
374+
375+
dst = AutoencoderKL(**params).to(device)
376+
dst.load_old_state_dict(old_sd)
377+
loaded = dst.state_dict()
378+
379+
out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k]
380+
out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k]
381+
self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found – check model config")
382+
383+
for k in out_proj_weights:
384+
n = loaded[k].shape[0]
385+
torch.testing.assert_close(
386+
loaded[k],
387+
torch.eye(n, device=device),
388+
msg=f"{k} should be an identity matrix",
389+
)
390+
for k in out_proj_biases:
391+
torch.testing.assert_close(
392+
loaded[k],
393+
torch.zeros_like(loaded[k]),
394+
msg=f"{k} should be all-zeros",
395+
)
396+
397+
@skipUnless(has_einops, "Requires einops")
398+
def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self):
399+
params = {**self._MIGRATION_PARAMS, "include_fc": False}
400+
src = AutoencoderKL(**params).to(device)
401+
old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
402+
403+
# inject synthetic proj_attn keys (mimic an old checkpoint)
404+
attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")]
405+
self.assertGreater(len(attn_blocks), 0, "No attention blocks found – check model config")
406+
for block in attn_blocks:
407+
ch = old_sd[f"{block}.to_q.weight"].shape[0]
408+
old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch)
409+
old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch)
410+
411+
dst = AutoencoderKL(**params).to(device)
412+
dst.load_old_state_dict(old_sd)
413+
414+
loaded = dst.state_dict()
415+
self.assertFalse(
416+
any("out_proj" in k for k in loaded),
417+
"out_proj should not exist in a model built with include_fc=False",
418+
)
419+
330420

331421
if __name__ == "__main__":
332422
unittest.main()

0 commit comments

Comments
 (0)